| core.custom_jvp_generic (Generic passthrough for custom JVP calls) |
➖ |
custom_jvp_square ✅
custom_jvp_square_f64 ✅ |
0.7.1 |
| core.custom_vjp_generic (Generic passthrough for custom VJP calls) |
➖ |
custom_vjp_square ✅
custom_vjp_square_f64 ✅ |
0.7.1 |
| core.dim_as_value |
Cast Gather Reshape Shape |
dim_as_value_dynamic ✅
dim_as_value_dynamic_f64 ✅
dim_as_value ✅
dim_as_value_f64 ✅ |
0.5.0 |
| core.jit_inline |
➖ |
jit_identity ✅
jit_identity_f64 ✅ |
0.9.0 |
| dm_pix.depth_to_space |
DepthToSpace |
depth_to_space_nhwc ✅ |
0.12.0 |
| eqx.conv |
Conv |
eqx_conv2d_nchw ✅
eqx_conv2d_batched_nchw ✅ |
0.10.0 |
| eqx.dropout |
Dropout Not |
eqx_dropout_inference_mode ✅
eqx_dropout_inference_mode_f64 ✅
eqx_dropout_training_mode ✅
eqx_dropout_training_mode_f64 ✅
eqx_dropout_dynamic_inference ✅
eqx_dropout_dynamic_inference_f64 ✅
eqx_dropout_batched_inference_dynamic ✅
eqx_dropout_batched_inference_dynamic_f64 ✅
eqx_dropout_batched_inference ✅
eqx_dropout_batched_inference_f64 ✅ |
0.8.0 |
| eqx.identity |
Identity |
eqx_identity_static ✅
eqx_identity_static_f64 ✅
eqx_identity_symbolic_batch_dynamic ✅
eqx_identity_symbolic_batch_dynamic_f64 ✅
eqx_identity_symbolic_batch ✅
eqx_identity_symbolic_batch_f64 ✅ |
0.8.0 |
| eqx.layer_norm |
LayerNormalization |
layer_norm ✅
layer_norm_f64 ✅
layer_norm_multiaxis ✅
layer_norm_multiaxis_f64 ✅
batched_layer_norm_dynamic ✅
batched_layer_norm_dynamic_f64 ✅
batched_layer_norm ✅
batched_layer_norm_f64 ✅
layer_norm_no_bias_no_scale ✅
layer_norm_no_bias_no_scale_f64 ✅ |
0.8.0 |
| eqx.linear |
Gemm Reshape |
eqx_linear_symbolic_batch_dynamic ✅
eqx_linear_symbolic_batch_dynamic_f64 ✅
eqx_linear_symbolic_batch ✅
eqx_linear_symbolic_batch_f64 ✅
eqx_linear_no_bias_symbolic_batch_dynamic ✅
eqx_linear_no_bias_symbolic_batch_dynamic_f64 ✅
eqx_linear_no_bias_symbolic_batch ✅
eqx_linear_no_bias_symbolic_batch_f64 ✅
eqx_linear_no_bias_vector ✅
eqx_linear_no_bias_vector_f64 ✅
eqx_linear_high_rank ✅
eqx_linear_high_rank_f64 ✅
eqx_linear_vector ✅
eqx_linear_vector_f64 ✅ |
0.8.0 |
| eqx.multihead_attention |
Gemm MatMul Softmax |
eqx_multihead_attention ✅
eqx_multihead_attention_core_dynamic ✅
eqx_multihead_attention_core_static ✅ |
0.10.0 |
| eqx.rms_norm |
RMSNormalization |
rms_norm_vector ✅
rms_norm_vector_f64 ✅
rms_norm_no_affine ✅
rms_norm_no_affine_f64 ✅ |
0.10.2 |
| eqx.rotary_positional_embedding |
Add Concat Mul RotaryEmbedding |
eqx_rotary_positional_embedding ✅
eqx_rotary_positional_embedding_heads ✅ |
0.10.0 |
| jax_image.resize |
Resize |
resize_linear ✅
resize_nearest ✅ |
0.10.0 |
| jnp.add |
Add |
add ✅
add_f64 ✅
jnp_add_vector ✅
jnp_add_vector_f64 ✅
jnp_add_broadcast ✅
jnp_add_broadcast_f64 ✅
add_vmap_batching ✅
add_vmap_batching_f64 ✅ |
0.8.0 |
| jnp.arange |
Range |
arange_data_dependent_indices ✅
arange_stop_only_concrete_input_val ✅
arange_stop_only_concrete_input_val_f64 ✅
arange_start_stop_concrete_input_val ✅
arange_start_stop_concrete_input_val_f64 ✅
arange_start_stop_step_concrete_input_val ✅
arange_start_stop_step_concrete_input_val_f64 ✅
arange_float_concrete_input_val ✅
arange_float_concrete_input_val_f64 ✅
arange_static_stop_only_int ✅
arange_static_stop_only_int_f64 ✅
arange_static_stop_only_float ✅
arange_static_stop_only_float_f64 ✅
arange_static_start_stop_int ✅
arange_static_start_stop_int_f64 ✅
arange_static_start_stop_step_int ✅
arange_static_start_stop_step_int_f64 ✅
arange_static_empty_result_pos_step ✅
arange_static_empty_result_pos_step_f64 ✅
arange_static_empty_result_neg_step ✅
arange_static_empty_result_neg_step_f64 ✅
arange_static_negative_step ✅
arange_static_negative_step_f64 ✅
arange_static_float_step_explicit_dtype ✅
arange_static_float_step_explicit_dtype_f64 ✅
arange_static_float_step_inferred_dtype ✅
arange_static_float_step_inferred_dtype_f64 ✅
arange_static_stop_zero ✅
arange_static_stop_zero_f64 ✅
arange_static_start_equals_stop ✅
arange_static_start_equals_stop_f64 ✅
arange_static_large_numbers_int ✅
arange_static_large_numbers_int_f64 ✅ |
0.5.2 |
| jnp.clip |
Clip Max Min |
clip_i32_scalar_bounds ✅
clip_i32_scalar_bounds_f64 ✅
clip_f32_scalar_bounds_no_upcast_f64_mode ✅
clip_only_upper ✅
clip_only_upper_f64 ✅
clip_only_lower ✅
clip_only_lower_f64 ✅
clip_broadcast_bounds ✅ |
0.8.0 |
| jnp.concatenate |
Concat |
concatenate_basic ✅
concatenate_basic_f64 ✅
concatenate_mixed_dtypes ✅
concatenate_mixed_dtypes_f64 ✅
concatenate_with_explicit_dtype ✅
concatenate_with_explicit_dtype_casts_inputs ✅
concatenate_abstract_middle_dim_dynamic ✅
concatenate_abstract_middle_dim_dynamic_f64 ✅
concatenate_abstract_middle_dim ✅
concatenate_abstract_middle_dim_f64 ✅
concatenate_tile_and_symbolic_dynamic ✅
concatenate_tile_and_symbolic_dynamic_f64 ✅
concatenate_tile_and_symbolic ✅
concatenate_tile_and_symbolic_f64 ✅ |
0.8.0 |
| jnp.conj |
Identity |
jnp_conj_real ✅
jnp_conj_real_f64 ✅
jnp_conj_complex64 ✅
jnp_conj_complex64_f64 ✅
conj_vmap_batching ✅
conj_vmap_batching_f64 ✅ |
0.10.1 |
| jnp.cumsum |
CumSum |
jnp_cumsum_axis1 ✅
jnp_cumsum_axis1_f64 ✅
jnp_cumsum_reverse_dtype ✅
cumsum_axis2_i32 ✅
cumsum_axis2_i32_f64 ✅
cumsum_axis2_reverse_i32 ✅
cumsum_axis2_reverse_i32_f64 ✅
cumsum_vmap_batching ✅
cumsum_vmap_batching_f64 ✅ |
0.8.0 |
| jnp.einsum |
Einsum |
einsum_vector_dot ✅
einsum_vector_dot_f64 ✅
einsum_matrix_vector ✅
einsum_matrix_vector_f64 ✅
einsum_matrix_matrix_dynamic ✅
einsum_matrix_matrix_dynamic_f64 ✅
einsum_matrix_matrix ✅
einsum_matrix_matrix_f64 ✅
einsum_transpose ✅
einsum_transpose_f64 ✅
einsum_batch_transpose_dynamic ✅
einsum_batch_transpose_dynamic_f64 ✅
einsum_batch_transpose ✅
einsum_batch_transpose_f64 ✅
einsum_diag ✅
einsum_diag_f64 ✅
einsum_sum_reduce ✅
einsum_sum_reduce_f64 ✅
einsum_multi_operand ✅
einsum_multi_operand_f64 ✅
einsum_attention_logits_orig_dynamic ✅
einsum_attention_logits_orig_dynamic_f64 ✅
einsum_attention_logits_orig ✅
einsum_attention_logits_orig_f64 ✅
einsum_attention_output_orig_dynamic ✅
einsum_attention_output_orig_dynamic_f64 ✅
einsum_attention_output_orig ✅
einsum_attention_output_orig_f64 ✅
einsum_attention_logits_batched_dynamic ✅
einsum_attention_logits_batched_dynamic_f64 ✅
einsum_attention_logits_batched ✅
einsum_attention_logits_batched_f64 ✅
einsum_attention_output_batched_dynamic ✅
einsum_attention_output_batched_dynamic_f64 ✅
einsum_attention_output_batched ✅
einsum_attention_output_batched_f64 ✅
einsum_ellipsis_rank_mismatch ✅
einsum_ellipsis_rank_mismatch_f64 ✅
einsum_attention_logits_batched_rank_mismatch ✅
einsum_attention_logits_batched_rank_mismatch_f64 ✅ |
0.1.0 |
| jnp.fft |
DFT |
jnp_fft_complex64 ✅
jnp_fft_complex128 ✅ |
0.10.1 |
| jnp.ifft |
DFT |
jnp_ifft_complex64 ✅ |
0.10.1 |
| jnp.linspace |
Add Range |
linspace_static_basic ✅
linspace_static_basic_f64 ✅
linspace_static_endpoint_false ✅
linspace_static_endpoint_false_f64 ✅
linspace_static_int_inputs_default_dtype ✅
linspace_static_int_inputs_default_dtype_f64 ✅
linspace_basic_f32 ✅
linspace_basic_f32_f64 ✅
linspace_endpoint_false_i32 ✅
linspace_endpoint_false_i32_f64 ✅
linspace_num_zero ✅
linspace_num_zero_f64 ✅
linspace_num_one ✅
linspace_num_one_f64 ✅
linspace_static_num_0 ✅
linspace_static_num_0_f64 ✅
linspace_static_num_1 ✅
linspace_static_num_1_f64 ✅
linspace_vmap_batching ✅
linspace_vmap_batching_f64 ✅ |
0.5.2 |
| jnp.matmul |
MatMul |
matmul_1d ✅
matmul_1d_f64 ✅
matmul_1d_2d ✅
matmul_1d_2d_f64 ✅
matmul_2d ✅
matmul_2d_f64 ✅
matmul_2d_1d ✅
matmul_2d_1d_f64 ✅
matmul_3d ✅
matmul_3d_f64 ✅
matmul_dynamic_dynamic ✅
matmul_dynamic_dynamic_f64 ✅
matmul_dynamic ✅
matmul_dynamic_f64 ✅
matmul_dynamic_a_dynamic ✅
matmul_dynamic_a_dynamic_f64 ✅
matmul_dynamic_a ✅
matmul_dynamic_a_f64 ✅
matmul_complex64 ✅
matmul_complex64_f64 ✅
matmul_vmap_batching ✅
matmul_vmap_batching_f64 ✅ |
0.1.0 |
| jnp.mean |
ReduceMean |
basic_mean ✅
basic_mean_f64 ✅
mean_with_axis ✅
mean_with_axis_f64 ✅
mean_with_keepdims ✅
mean_with_keepdims_f64 ✅
jnp_mean_basic ✅
jnp_mean_basic_f64 ✅
jnp_mean_axis ✅
jnp_mean_axis_f64 ✅ |
0.12.0 |
| jnp.outer |
Mul |
outer_vector ✅
outer_vector_f64 ✅
outer ✅
outer_f64 ✅
outer_vmap_batching ✅
outer_vmap_batching_f64 ✅ |
0.10.0 |
| jnp.pow |
Pow |
jnp_pow_vector ✅
jnp_pow_vector_f64 ✅
pow_jnp_pow ✅
pow_jnp_pow_f64 ✅
pow_vmap_batching ✅
pow_vmap_batching_f64 ✅ |
0.8.0 |
| jnp.power |
Pow |
jnp_power_vector ✅
jnp_power_vector_f64 ✅
pow_jnp_power ✅
pow_jnp_power_f64 ✅
power_vmap_batching ✅
power_vmap_batching_f64 ✅ |
0.8.0 |
| jnp.prod |
ReduceProd |
basic_prod ✅
basic_prod_f64 ✅
prod_with_axis ✅
prod_with_axis_f64 ✅
prod_with_keepdims ✅
prod_with_keepdims_f64 ✅
jnp_prod_basic ✅
jnp_prod_basic_f64 ✅
jnp_prod_axis ✅
jnp_prod_axis_f64 ✅
jnp_prod_keepdims ✅
jnp_prod_keepdims_f64 ✅
prod_vmap_batching ✅
prod_vmap_batching_f64 ✅ |
0.8.0 |
| jnp.reshape |
Reshape |
reshape_1 ✅
reshape_1_f64 ✅
reshape_2 ✅
reshape_2_f64 ✅
reshape_3 ✅
reshape_3_f64 ✅
reshape_4_dynamic ✅
reshape_4_dynamic_f64 ✅
reshape_4 ✅
reshape_4_f64 ✅
reshape_to_scalar ✅
reshape_to_scalar_f64 ✅
reshape_from_scalar ✅
reshape_from_scalar_f64 ✅
reshape_cnn_dynamic ✅
reshape_cnn_dynamic_f64 ✅
reshape_cnn ✅
reshape_cnn_f64 ✅
reshape_valid_flatten_trailing ✅
reshape_valid_flatten_trailing_f64 ✅
reshape_with_target_shape_from_symbolic_dim_computation ✅
reshape_with_target_shape_from_symbolic_dim_computation_f64 ✅
reshape_basic ✅
reshape_basic_f64 ✅
reshape_infer ✅
reshape_infer_f64 ✅
reshape_symbolic_flatten_dynamic ✅
reshape_symbolic_flatten_dynamic_f64 ✅
reshape_symbolic_flatten ✅
reshape_symbolic_flatten_f64 ✅
reshape_vmap_batching_issue_144 ✅
reshape_vmap_batching_issue_144_f64 ✅ |
0.1.0 |
| jnp.rfft |
DFT |
jnp_rfft_float32 ✅ |
0.10.1 |
| jnp.select |
Where |
select_simple ✅
select_simple_f64 ✅
select_broadcast ✅
select_broadcast_f64 ✅
select_gpt2_attention_mask_dynamic ✅
select_gpt2_attention_mask_dynamic_f64 ✅
select_gpt2_attention_mask ✅
select_gpt2_attention_mask_f64 ✅
select_basic ✅
select_basic_f64 ✅
select_vmap_batching ✅
select_vmap_batching_f64 ✅ |
0.7.1 |
| jnp.shape |
Shape |
shape_basic ✅
shape_basic_f64 ✅
shape_dynamic_dynamic ✅
shape_dynamic_dynamic_f64 ✅
shape_dynamic ✅
shape_dynamic_f64 ✅
shape_vmap_batching ✅
shape_vmap_batching_f64 ✅ |
0.4.0 |
| jnp.sort |
Sort |
sort_1d ✅
sort_1d_f64 ✅
sort_2d_axis0 ✅
sort_2d_axis0_f64 ✅
sort_basic ✅
sort_basic_f64 ✅
sort_vmap_batching ✅
sort_vmap_batching_f64 ✅ |
0.5.2 |
| jnp.split |
Split |
split_by_sections ✅
split_by_sections_f64 ✅
split_by_indices ✅
split_by_indices_f64 ✅
split_by_indices_symbolic_dynamic ✅
split_by_indices_symbolic_dynamic_f64 ✅
split_by_indices_symbolic ✅
split_by_indices_symbolic_f64 ✅
split_sections ✅
split_sections_f64 ✅
split_indices_numpy ✅
split_indices_numpy_f64 ✅ |
0.7.2 |
| jnp.squeeze |
Squeeze |
squeeze_single_dim ✅
squeeze_single_dim_f64 ✅
squeeze_multiple_dims ✅
squeeze_multiple_dims_f64 ✅
squeeze_vit_output ✅
squeeze_vit_output_f64 ✅
squeeze_dynamic_batch_dynamic ✅
squeeze_dynamic_batch_dynamic_f64 ✅
squeeze_dynamic_batch ✅
squeeze_dynamic_batch_f64 ✅
squeeze_all_dims ✅
squeeze_all_dims_f64 ✅
squeeze_negative_axis ✅
squeeze_negative_axis_f64 ✅
squeeze_negative_axis_tuple ✅
squeeze_negative_axis_tuple_f64 ✅
squeeze_dynamic_and_negative_axis_dynamic ✅
squeeze_dynamic_and_negative_axis_dynamic_f64 ✅
squeeze_dynamic_and_negative_axis ✅
squeeze_dynamic_and_negative_axis_f64 ✅
squeeze_vmap_batching ✅
squeeze_vmap_batching_f64 ✅ |
0.1.0 |
| jnp.stack |
Concat Unsqueeze |
stack_axis_0 ✅
stack_axis_0_f64 ✅
stack_axis_1 ✅
stack_axis_1_f64 ✅
stack_negative_axis ✅
stack_negative_axis_f64 ✅
stack_scalars ✅
stack_scalars_f64 ✅
jnp_stack_axis0 ✅
jnp_stack_axis0_f64 ✅
jnp_stack_axis1 ✅
jnp_stack_axis1_f64 ✅
jnp_stack_negative_axis ✅
jnp_stack_negative_axis_f64 ✅
jnp_stack_scalars ✅
jnp_stack_scalars_f64 ✅ |
0.8.0 |
| jnp.take |
Gather |
take_data_dependent_indices ✅
take_basic_axis1 ✅
take_basic_axis1_f64 ✅
take_vmap_batching ✅
take_vmap_batching_f64 ✅ |
0.7.0 |
| jnp.tile |
Tile |
tile_repeats ✅
tile_repeats_f64 ✅
tile_a ✅
tile_a_f64 ✅
tile_b ✅
tile_b_f64 ✅
tile_c ✅
tile_c_f64 ✅
tile_d ✅
tile_d_f64 ✅
tile_dynamic_input_static ✅
tile_dynamic_input_static_f64 ✅
tile_dynamic_input_dynamic ✅
tile_dynamic_input_dynamic_f64 ✅
tile_dynamic_input ✅
tile_dynamic_input_f64 ✅
tile_pad ✅
tile_pad_f64 ✅
tile_param_symbolic_dynamic ✅
tile_param_symbolic_dynamic_f64 ✅
tile_param_symbolic ✅
tile_param_symbolic_f64 ✅
tile_with_symbolic_repeats_static ✅
tile_with_symbolic_repeats_static_f64 ✅
tile_with_symbolic_repeats_dynamic ✅
tile_with_symbolic_repeats_dynamic_f64 ✅
tile_with_symbolic_repeats ✅
tile_with_symbolic_repeats_f64 ✅
jnp_tile_basic ✅
jnp_tile_basic_f64 ✅
jnp_tile_scalar_repeats ✅
jnp_tile_scalar_repeats_f64 ✅
jnp_tile_pad_rank ✅
jnp_tile_pad_rank_f64 ✅
jnp_tile_symbolic_dynamic ✅
jnp_tile_symbolic_dynamic_f64 ✅
jnp_tile_symbolic ✅
jnp_tile_symbolic_f64 ✅
tile_vmap_batching ✅
tile_vmap_batching_f64 ✅ |
0.8.0 |
| jnp.transpose |
Transpose |
transpose_basic ✅
transpose_basic_f64 ✅
transpose_reverse_default ✅
transpose_reverse_default_f64 ✅
transpose_high_dim ✅
transpose_high_dim_f64 ✅
transpose_3d ✅
transpose_3d_f64 ✅
transpose_4d ✅
transpose_4d_f64 ✅
transpose_no_axes ✅
transpose_no_axes_f64 ✅
transpose_reverse ✅
transpose_reverse_f64 ✅
transpose_square_matrix ✅
transpose_square_matrix_f64 ✅
transpose_vmap_batching ✅
transpose_vmap_batching_f64 ✅ |
0.1.0 |
| jnp.unstack |
Split Squeeze |
unstack_axis_0 ✅
unstack_axis_0_f64 ✅
unstack_axis_1 ✅
unstack_axis_1_f64 ✅
unstack_negative_axis ✅
unstack_negative_axis_f64 ✅
unstack_vmap_batching ✅
unstack_vmap_batching_f64 ✅ |
0.7.1 |
| jnp.where |
Where |
where_simple ✅
where_simple_f64 ✅
where_broadcast ✅
where_broadcast_f64 ✅
where_gpt_mask_scores_literal_else_dynamic ✅
where_gpt_mask_scores_literal_else_dynamic_f64 ✅
where_gpt_mask_scores_literal_else ✅
where_gpt_mask_scores_literal_else_f64 ✅
where_multidim_condition_scalar_branches_broadcast ✅
where_multidim_condition_scalar_branches_broadcast_f64 ✅
where_A ✅
where_A_f64 ✅
where_B ✅
where_B_f64 ✅
where_gpt_mask_scores_scalar_else_dynamic ✅
where_gpt_mask_scores_scalar_else_dynamic_f64 ✅
where_gpt_mask_scores_scalar_else ✅
where_gpt_mask_scores_scalar_else_f64 ✅
where_int_condition_cast ✅
where_int_condition_cast_f64 ✅
where_literal_else_pyfloat ✅
where_literal_else_pyfloat_f64 ✅
where_jax_int_literals_broadcast_f64_mode ✅
where_dtype_mismatch_f64_vs_i32_promote ✅
jnp_where_basic ✅
jnp_where_basic_f64 ✅
jnp_where_broadcast ✅
jnp_where_broadcast_f64 ✅
jnp_where_scalar_else ✅
jnp_where_scalar_else_f64 ✅
where_vmap_batching ✅
where_vmap_batching_f64 ✅ |
0.8.0 |
| lax.abs |
Abs |
abs ✅
abs_f64 ✅ |
0.5.0 |
| lax.add |
Add |
add ✅
add_f64 ✅
add_const ✅
add_const_f64 ✅
add_complex64 ✅
add_complex64_f64 ✅ |
0.2.0 |
| lax.add_any |
Add |
add_any_via_jvp_on_mul ✅
add_any_via_jvp_on_mul_f64 ✅ |
0.8.0 |
| lax.and |
And BitwiseAnd |
and_bool ✅
and_bool_f64 ✅
and_int ✅
and_int_f64 ✅ |
0.6.5 |
| lax.argmax |
ArgMax |
argmax_float_axis0 ✅
argmax_float_axis0_f64 ✅
argmax_float_axis1 ✅
argmax_float_axis1_f64 ✅
argmax_boolean_input_axis0_specific_values ✅
argmax_boolean_input_axis0_specific_values_f64 ✅
argmax_boolean_input_axis1_specific_values ✅
argmax_boolean_input_axis1_specific_values_f64 ✅
argmax_boolean_random_input_axis0 ✅
argmax_boolean_random_input_axis0_f64 ✅ |
0.2.0 |
| lax.argmin |
ArgMin |
argmin_test1 ✅
argmin_test1_f64 ✅
argmin_test2 ✅
argmin_test2_f64 ✅ |
0.2.0 |
| lax.bitcast_convert_type |
Bitcast |
bitcast_scalar_f32_to_i32 ✅
bitcast_scalar_f32_to_i32_f64 ✅
bitcast_tensor_i32_to_f32 ✅
bitcast_tensor_i32_to_f32_f64 ✅ |
0.7.2 |
| lax.bitwise_not |
BitwiseNot Not |
bitwise_not_bool ✅
bitwise_not_bool_f64 ✅
bitwise_not_i32 ✅
bitwise_not_i32_f64 ✅ |
0.7.5 |
| lax.broadcast_in_dim |
Expand Identity Reshape |
broadcast_in_dim ✅
broadcast_in_dim_f64 ✅
broadcast_in_dim_2d_to_3d ✅
broadcast_in_dim_2d_to_3d_f64 ✅
broadcast_in_dim_scalar ✅
broadcast_in_dim_scalar_f64 ✅
broadcast_in_dim_batch_dynamic ✅
broadcast_in_dim_batch_dynamic_f64 ✅
broadcast_in_dim_batch ✅
broadcast_in_dim_batch_f64 ✅
broadcast_in_dim_dynamic_B_dynamic ✅
broadcast_in_dim_dynamic_B_dynamic_f64 ✅
broadcast_in_dim_dynamic_B ✅
broadcast_in_dim_dynamic_B_f64 ✅ |
0.2.0 |
| lax.clamp |
Max Min |
clamp_i32_scalar_bounds ✅
clamp_i32_scalar_bounds_f64 ✅
clamp_scalar_float_bounds_match_x ✅
clamp_scalar_float_bounds_match_x_f64 ✅
clamp_vector_bounds_match ✅
clamp_pyint_bounds_promote_to_x_dtype ✅
clamp_pyint_bounds_promote_to_x_dtype_f64 ✅ |
0.7.5 |
| lax.concatenate |
Cast Concat |
concatenate ✅
concatenate_f64 ✅
concatenate_axis1_dynamic ✅
concatenate_axis1_dynamic_f64 ✅
concatenate_axis1 ✅
concatenate_axis1_f64 ✅
concatenate_axis0 ✅
concatenate_axis0_f64 ✅
concatenate_3d ✅
concatenate_3d_f64 ✅
concatenate_internal_int32_then_cast_to_f32_zeroarg ✅ |
0.2.0 |
| lax.cond |
If |
cond_scalar ✅
cond_scalar_f64 ✅
cond_multiple_operands_in_tuple ✅
cond_multiple_operands_in_tuple_f64 ✅
cond_my_new_complex_scenario ✅
cond_my_new_complex_scenario_f64 ✅
cond_nested_conditional ✅
cond_nested_conditional_f64 ✅
cond_variables ✅
cond_variables_f64 ✅
cond_internal_constant_f64 ✅
cond_passthrough_identity ✅
cond_passthrough_identity_f64 ✅
cond_with_scatter ✅
cond_with_scatter_f64 ✅ |
0.5.1 |
| lax.conj |
Identity |
conj_real ✅
conj_real_f64 ✅
conj_complex64 ✅
conj_complex64_f64 ✅ |
0.10.1 |
| lax.conv |
Conv |
conv ✅
conv2 ✅
conv_nchw ✅
conv_nhwc ✅
conv_general_dilated_nhwc_output ✅
conv_complex64 ✅
conv_complex64_nhwc ✅
conv_complex128_grouped ✅ |
0.2.0 |
| lax.convert_element_type |
Cast |
convert_element_type ✅
convert_element_type_f64 ✅ |
0.2.0 |
| lax.copy (Handles the JAX primitive lax.copy_p. The public API jax.lax.copy is deprecated, but the primitive still appears in transformed jaxprs.) |
Identity |
copy_float32_array ✅
copy_int64_scalar ✅ |
|
| lax.cos |
Cos |
cos ✅
cos_f64 ✅ |
0.4.4 |
| lax.cosh |
Cosh |
cosh ✅
cosh_f64 ✅ |
0.4.4 |
| lax.cumsum |
CumSum |
cumsum_i32_axis2 ✅
cumsum_i32_axis2_f64 ✅
cumsum_f32_axism1_reverse ✅
cumsum_f32_axism1_reverse_f64 ✅ |
0.7.4 |
| lax.device_put |
Identity |
device_put_array ✅
device_put_array_f64 ✅
device_put_scalar ✅
device_put_scalar_f64 ✅ |
0.4.0 |
| lax.div |
Div |
div ✅
div_f64 ✅
div_const ✅
div_const_f64 ✅
div_complex64 ✅
div_complex64_f64 ✅ |
0.2.0 |
| lax.dot_general |
Einsum MatMul/Gemm |
dot_contract_nm ✅
dot_contract_nm_f64 ✅
dot_contract_min ✅
dot_contract_min_f64 ✅
dot_general ✅
dot_general_f64 ✅
dot_general_lhs1_rhs1 ✅
dot_general_lhs1_rhs1_f64 ✅
dot_double_contract ✅
dot_double_contract_f64 ✅
dot_batched_double_contract ✅
dot_batched_double_contract_f64 ✅
dot_highrank_batch ✅
dot_highrank_batch_f64 ✅
dot_contract_inner_lhs_with_middle_rhs ✅
dot_contract_inner_lhs_with_middle_rhs_f64 ✅
dot_outer_product ✅
dot_outer_product_f64 ✅
dot_full_contract_scalar ✅
dot_full_contract_scalar_f64 ✅
dot_general_complex_matmul ✅
dot_general_complex_matmul_f64 ✅ |
0.2.0 |
| lax.dynamic_slice |
Slice |
dynamic_slice_test1 ✅
dynamic_slice_test1_f64 ✅
dynamic_slice_2d ✅
dynamic_slice_2d_f64 ✅
dynamic_slice_3d ✅
dynamic_slice_3d_f64 ✅
dynamic_slice_vit_like_dynamic ✅
dynamic_slice_vit_like_dynamic_f64 ✅
dynamic_slice_vit_like ✅
dynamic_slice_vit_like_f64 ✅ |
0.1.0 |
| lax.dynamic_update_slice |
ScatterND |
dus_1d_scalar_update ✅
dus_1d_scalar_update_f64 ✅
dus_1d_block_update ✅
dus_1d_block_update_f64 ✅
dus_2d_block_update ✅
dus_2d_block_update_f64 ✅
dus_3d_block_update ✅
dus_3d_block_update_f64 ✅
dus_4d_block_update ✅
dus_4d_block_update_f64 ✅ |
0.8.1 |
| lax.eq |
Equal |
eq ✅
eq_f64 ✅ |
0.2.0 |
| lax.erf |
Erf |
erf ✅ |
0.4.4 |
| lax.exp |
Exp |
exp ✅
exp_f64 ✅ |
0.2.0 |
| lax.fft |
DFT |
fft_complex64_1d ✅
fft_complex64_len8 ✅
fft_complex64_batch ✅
fft_complex128_1d ✅
fft_complex128_len8 ✅
ifft_complex64_1d ✅
ifft_complex64_len8 ✅
ifft_complex128_batch ✅
rfft_real32_1d ✅
rfft_real64_len8 ✅
irfft_complex64_1d ✅
irfft_complex128_len8 ✅ |
0.10.1 |
| lax.fori_loop |
Loop |
fori_loop_counter ✅
fori_loop_counter_f64 ✅
fori_loop_zero ✅
fori_loop_zero_f64 ✅
fori_loop_vector ✅
fori_loop_vector_f64 ✅
fori_loop_example ✅
fori_loop_example_f64 ✅
fori_loop_test ✅
fori_loop_test_f64 ✅ |
0.5.1 |
| lax.gather |
GatherND |
gather_trig_where_pipeline_f64_indices_i64 ✅
gather_trig_where_pipeline_f64_indices_i32 ✅
gather_f64_data_i64_indices_output_is_f64 ✅
gather_f64_data_i32_indices_cast_and_output_is_f64 ✅
gather_static ✅
gather_static_f64 ✅
gather_dynamic_batch_simple_index_dynamic ✅
gather_dynamic_batch_simple_index_dynamic_f64 ✅
gather_dynamic_batch_simple_index ✅
gather_dynamic_batch_simple_index_f64 ✅ |
0.2.0 |
| lax.greater_equal |
GreaterOrEqual |
greater_equal ✅
greater_equal_f64 ✅ |
0.7.5 |
| lax.gt |
Greater |
gt ✅
gt_f64 ✅ |
0.2.0 |
| lax.imag |
Mul |
imag_complex64_input ✅
imag_complex64_input_f64 ✅ |
0.10.2 |
| lax.integer_pow |
Pow |
integer_pow ✅
integer_pow_f64 ✅ |
0.2.0 |
| lax.iota |
Range |
iota_int32 ✅
iota_int32_f64 ✅
iota_float32 ✅
iota_float32_f64 ✅
broadcasted_iota ✅
broadcasted_iota_f64 ✅ |
0.5.0 |
| lax.less_equal |
LessOrEqual |
less_equal ✅
less_equal_f64 ✅ |
0.7.5 |
| lax.log |
Log |
log ✅
log_f64 ✅ |
0.2.0 |
| lax.log1p |
Add Log |
log1p ✅
log1p_f64 ✅ |
0.11.0 |
| lax.logistic |
Sigmoid |
lax_logistic_basic ✅
lax_logistic_basic_f64 ✅ |
0.7.2 |
| lax.lt |
Less |
lt ✅
lt_f64 ✅ |
0.2.0 |
| lax.max |
Max |
max ✅
max_f64 ✅ |
0.2.0 |
| lax.min |
Min |
min_test1 ✅
min_test1_f64 ✅ |
0.1.0 |
| lax.mul |
Mul |
mul_test1 ✅
mul_test1_f64 ✅
mul_test2 ✅
mul_test2_f64 ✅
mul_pyfloat_promotes_to_array_dtype_f64 ✅
mul_scalar_broadcast_promote_to_f64 ✅
mul_complex128 ✅
mul_complex64 ✅ |
0.1.0 |
| lax.ne |
Equal Not |
ne ✅
ne_f64 ✅ |
0.2.0 |
| lax.neg |
Neg |
neg ✅
neg_f64 ✅ |
0.2.0 |
| lax.or |
BitwiseOr Or |
or_bool_vec ✅
or_bool_vec_f64 ✅
or_int_vec ✅
or_int_vec_f64 ✅ |
0.7.2 |
| lax.pad |
Pad |
pad_const_1d ✅
pad_const_1d_f64 ✅
pad_const_2d ✅
pad_const_2d_f64 ✅
pad_const_2d_cval ✅
pad_const_2d_cval_f64 ✅
pad_inside_scan_smoke_f64 ✅
pad_inside_nested_scan_smoke_f64 ✅ |
0.8.0 |
| lax.pjit |
➖ |
pjit_inline_mul ✅
pjit_inline_mul_f64 ✅
pjit_inline_tuple ✅
pjit_inline_tuple_f64 ✅ |
0.1.0 |
| lax.pow |
Pow |
pow_basic ✅
pow_basic_f64 ✅
pow_lax ✅
pow_lax_f64 ✅ |
0.8.2 |
| lax.real |
Identity |
real_complex64_input ✅
real_complex64_input_f64 ✅ |
0.10.2 |
| lax.reduce_and |
ReduceMin |
reduce_and_all_true ✅
reduce_and_all_true_f64 ✅
reduce_and_one_false ✅
reduce_and_one_false_f64 ✅
reduce_and_keepdims ✅
reduce_and_keepdims_f64 ✅ |
0.6.1 |
| lax.reduce_max |
ReduceMax |
reduce_max ✅
reduce_max_f64 ✅
reduce_max_allaxes ✅
reduce_max_allaxes_f64 ✅
reduce_max_axes_input ✅
reduce_max_axes_input_f64 ✅
reduce_max_keepdims ✅
reduce_max_keepdims_f64 ✅ |
0.2.0 |
| lax.reduce_min |
ReduceMin |
reduce_min ✅
reduce_min_f64 ✅
reduce_min_allaxes ✅
reduce_min_allaxes_f64 ✅
reduce_min_keepdims ✅
reduce_min_keepdims_f64 ✅ |
0.2.0 |
| lax.reduce_or |
ReduceMax |
reduce_or_all_false ✅
reduce_or_all_false_f64 ✅
reduce_or_one_true ✅
reduce_or_one_true_f64 ✅
reduce_or_keepdims ✅
reduce_or_keepdims_f64 ✅ |
0.6.1 |
| lax.reduce_prod |
ReduceProd |
reduce_prod ✅
reduce_prod_f64 ✅
reduce_prod_allaxes ✅
reduce_prod_allaxes_f64 ✅
reduce_prod_dtype ✅
reduce_prod_dtype_f64 ✅
reduce_prod_dtype_f64 ✅
reduce_prod_dtype_f64_f64 ✅
reduce_prod_keepdims ✅
reduce_prod_keepdims_f64 ✅ |
0.6.1 |
| lax.reduce_sum |
ReduceSum |
reduce_sum ✅
reduce_sum_f64 ✅
reduce_sum_allaxes ✅
reduce_sum_allaxes_f64 ✅
reduce_sum_dtype ✅
reduce_sum_dtype_f64 ✅
reduce_sum_dtype_f64 ✅
reduce_sum_dtype_f64_f64 ✅
reduce_sum_keepdims ✅
reduce_sum_keepdims_f64 ✅ |
0.2.0 |
| lax.reduce_window_sum |
Conv |
reduce_window_sum_valid ✅
reduce_window_sum_same_padding ✅
reduce_window_sum_stride_dilate ✅
reduce_window_sum_int32 ✅
reduce_window_sum_base_dilation ✅ |
0.10.1 |
| lax.reduce_xor |
Mod ReduceSum |
reduce_xor_all_false ✅
reduce_xor_all_false_f64 ✅
reduce_xor_one_true ✅
reduce_xor_one_true_f64 ✅
reduce_xor_two_true ✅
reduce_xor_two_true_f64 ✅
reduce_xor_keepdims ✅
reduce_xor_keepdims_f64 ✅ |
0.6.1 |
| lax.rem |
Div Mod |
rem_int ✅
rem_int_f64 ✅
rem_float ✅
rem_float_f64 ✅
rem_int_neg ✅
rem_int_neg_f64 ✅
rem_float_neg ✅
rem_float_neg_f64 ✅ |
0.6.5 |
| lax.remat2 |
➖ |
remat2_scalar_sin_chain ✅
remat2_scalar_sin_chain_f64 ✅
remat2_tuple_passthrough ✅
remat2_tuple_passthrough_f64 ✅ |
0.6.5 |
| lax.reshape |
Reshape |
reshape_after_transpose_folds_const_shape ✅
reshape_after_transpose_folds_const_shape_f64 ✅
reshape_flatten_trailing_folds_const_shape ✅
reshape_flatten_trailing_folds_const_shape_f64 ✅
reshape ✅
reshape_f64 ✅
reshape_valid_squeeze_middle_dim_from_problematic_source ✅
reshape_valid_squeeze_middle_dim_from_problematic_source_f64 ✅
reshape_valid_flatten_trailing ✅
reshape_valid_flatten_trailing_f64 ✅
reshape_with_target_shape_from_symbolic_dim_computation ✅
reshape_with_target_shape_from_symbolic_dim_computation_f64 ✅
reshape_with_inferred_dimension_from_input_dynamic_dynamic ✅
reshape_with_inferred_dimension_from_input_dynamic_dynamic_f64 ✅
reshape_with_inferred_dimension_from_input_dynamic ✅
reshape_with_inferred_dimension_from_input_dynamic_f64 ✅
reshape_with_inferred_dimension_from_input ✅
reshape_with_inferred_dimension_from_input_f64 ✅
reshape_merge_symbolic_with_static_and_check_name_dynamic ✅
reshape_merge_symbolic_with_static_and_check_name ✅ |
0.2.0 |
| lax.rev |
Flip |
rev_vector ✅
rev_vector_f64 ✅
rev_matrix_axes01 ✅
rev_matrix_axes01_f64 ✅ |
0.7.5 |
| lax.rsqrt |
Div Sqrt |
rsqrt ✅
rsqrt_f64 ✅ |
0.10.2 |
| lax.scan |
Scan |
scan_identity_slice_helper ✅
scan_identity_slice_helper_f64 ✅
scan_cumsum ✅
scan_cumsum_f64 ✅
scan_carry_only ✅
scan_carry_only_f64 ✅
scan_multiple_sequences ✅
scan_multiple_sequences_f64 ✅
scan_multiple_carry ✅
scan_multiple_carry_f64 ✅
scan_matrix_carry_multidim_xs ✅
scan_matrix_carry_multidim_xs_f64 ✅
scan_no_xs ✅
scan_no_xs_f64 ✅
scan_fn ✅
scan_fn_f64 ✅
scan_jit_no_xs ✅
scan_jit_no_xs_f64 ✅
scan_captured_scalar ✅
scan_captured_scalar_f64 ✅
scan_rank0_sequence_vectorized ✅
scan_rank0_sequence_vectorized_f64 ✅
scan_two_diff_lengths ✅
scan_two_diff_lengths_f64 ✅
scan_two_diff_lengths_broadcast ✅
scan_two_diff_lengths_broadcast_f64 ✅
scan_two_diff_lengths_with_broadcast ✅
scan_nested_len_mismatch ✅
scan_nested_len_mismatch_f64 ✅
scan_captured_scalar_with_xs ✅
scan_captured_vector_with_xs_f64 ✅ |
0.5.1 |
| lax.scatter |
ScatterND |
scatter_set_axis0 ✅
scatter_set_axis0_f64 ✅
scatter_set_middle ✅
scatter_set_middle_f64 ✅
scatter_set_single ✅
scatter_set_single_f64 ✅
scatter_set_vector ✅
scatter_set_vector_f64 ✅
scatter_correct_axis_determination ✅
scatter_correct_axis_determination_f64 ✅
scatter_updates_slice_needed_axis0 ✅
scatter_updates_slice_needed_axis0_f64 ✅
scatter_from_user_warning_shapes_valid_jax ✅
scatter_from_user_warning_shapes_valid_jax_f64 ✅
scatter_user_error_scenario_precise ✅
scatter_user_error_scenario_precise_f64 ✅
scatter_window_update_f64 ✅
scatter_window_update_depth3_shapes_ok ✅
scatter_static_slice_set_f64 ✅
scatter_depth2_fp64_type_mismatch ✅
scatter_clip_2d_window_at_edge ✅
scatter_simple_2d_window_out_of_bounds ✅
scatter_depth2_mixed_dtypes_fp_mismatch_f64 ✅
scatter_depth2_mixed_dtypes_fp_mismatch ✅ |
0.4.4 |
| lax.scatter_add |
ScatterND(reduction='add') |
scatter_add_vector ✅
scatter_add_vector_f64 ✅
scatter_add_scalar ✅
scatter_add_scalar_f64 ✅
scatter_add_simple_1d ✅
scatter_add_simple_1d_f64 ✅
scatter_add_batch_updates_1d_operand ✅
scatter_add_batch_updates_1d_operand_f64 ✅
scatter_add_window_2d_operand_1d_indices ✅
scatter_add_window_2d_operand_1d_indices_f64 ✅
scatter_add_mismatched_window_dims_from_user_report ✅
scatter_add_mismatched_window_dims_from_user_report2 ✅
scatter_add_mismatched_window_dims_from_user_report3 ✅
scatter_add_fluids_pattern_updates_5_4_1_1 ✅
scatter_add_in_cond_float64 ✅
scatter_add_fp64_dtype_mismatch ✅
scatter_add_depth2_depth2_helper_regression ✅
scatter_depth2_fp64_type_mismatch ✅ |
0.5.3 |
| lax.scatter_max |
ScatterND(reduction='max') |
scatter_max_simple_1d ✅
scatter_max_simple_1d_f64 ✅
scatter_max_batch_updates_1d_operand ✅
scatter_max_batch_updates_1d_operand_f64 ✅
scatter_max_window_2d_operand_1d_indices ✅
scatter_max_window_2d_operand_1d_indices_f64 ✅
scatter_max_fp64_dtype_path_check ✅
scatter_max_depth2_helper_regression_fp64 ✅ |
0.7.5 |
| lax.scatter_min |
ScatterND(reduction='min') |
scatter_min_simple_1d ✅
scatter_min_simple_1d_f64 ✅
scatter_min_batch_updates_1d_operand ✅
scatter_min_batch_updates_1d_operand_f64 ✅
scatter_min_window_2d_operand_1d_indices ✅
scatter_min_window_2d_operand_1d_indices_f64 ✅
scatter_min_fp64_dtype_path_check ✅
scatter_min_depth2_helper_regression_fp64 ✅ |
0.7.5 |
| lax.scatter_mul |
ScatterND(reduction='mul') |
scatter_mul_simple_1d ✅
scatter_mul_simple_1d_f64 ✅
scatter_mul_batch_updates_1d_operand ✅
scatter_mul_batch_updates_1d_operand_f64 ✅
scatter_mul_window_2d_operand_1d_indices ✅
scatter_mul_window_2d_operand_1d_indices_f64 ✅
scatter_mul_mismatched_window_dims_from_user_report ✅
scatter_mul_mismatched_window_dims_from_user_report2 ✅
scatter_mul_mismatched_window_dims_from_user_report3 ✅
scatter_mul_fluids_pattern_updates_5_4_1_1 ✅
scatter_mul_in_cond_float64 ✅ |
0.6.4 |
| lax.select |
Where |
select_simple ✅
select_simple_f64 ✅
select_basic ✅
select_basic_f64 ✅
select_mask_scores_tensor_else_dynamic ✅
select_mask_scores_tensor_else_dynamic_f64 ✅
select_mask_scores_tensor_else ✅
select_mask_scores_tensor_else_f64 ✅ |
0.7.1 |
| lax.select_n |
Equal Where |
select_n_bool_predicate_two_cases_float ✅
select_n_bool_predicate_two_cases_float_f64 ✅
select_n_bool_predicate_two_cases_int ✅
select_n_bool_predicate_two_cases_int_f64 ✅
select_n_bool_predicate_scalar_broadcast ✅
select_n_bool_predicate_scalar_broadcast_f64 ✅
select_n_int_indices_three_cases ✅
select_n_int_indices_three_cases_f64 ✅
select_n_int_indices_four_cases ✅
select_n_int_indices_four_cases_f64 ✅ |
0.2.0 |
| lax.shard_map |
➖ |
shard_map_inline_add ✅
shard_map_inline_add_f64 ✅ |
0.10.2 |
| lax.shift_right_logical |
BitShift |
shift_right_logical_vec ✅
shift_right_logical_vec_f64 ✅
shift_right_logical_scalar ✅
shift_right_logical_scalar_f64 ✅ |
0.7.2 |
| lax.sign |
Sign |
sign ✅
sign_f64 ✅ |
0.5.0 |
| lax.sin |
Sin |
sin ✅
sin_f64 ✅ |
0.4.4 |
| lax.sinh |
Sinh |
sinh ✅
sinh_f64 ✅ |
0.4.4 |
| lax.slice |
Slice |
slice_test1 ✅
slice_test1_f64 ✅
slice_3d_none_strides ✅
slice_3d_none_strides_f64 ✅
slice_scan_axis_drop ✅
slice_scan_axis_drop_f64 ✅ |
0.1.0 |
| lax.sort |
TopK |
sort_1d ✅
sort_1d_f64 ✅
sort_2d ✅
sort_2d_f64 ✅ |
0.2.0 |
| lax.split |
Split |
lax_split_equal_parts ✅
lax_split_equal_parts_f64 ✅
lax_split_unequal_parts ✅
lax_split_unequal_parts_f64 ✅ |
0.7.2 |
| lax.sqrt |
Sqrt |
sqrt ✅
sqrt_f64 ✅ |
0.2.0 |
| lax.square |
Mul |
square ✅
square_f64 ✅ |
0.2.0 |
| lax.squeeze |
Squeeze |
squeeze_single_axis ✅
squeeze_single_axis_f64 ✅
squeeze_all_unit_dims_default ✅
squeeze_all_unit_dims_default_f64 ✅
lax_squeeze_specific_axis_0 ✅
lax_squeeze_specific_axis_0_f64 ✅
lax_squeeze_multiple_axes ✅
lax_squeeze_multiple_axes_f64 ✅
lax_squeeze_no_op_empty_dims ✅
lax_squeeze_no_op_empty_dims_f64 ✅
lax_squeeze_problem_case_input_squeeze_only_axis_0 ✅
lax_squeeze_problem_case_input_squeeze_only_axis_0_f64 ✅
lax_squeeze_problem_case_input_squeeze_axes_0_2 ✅
lax_squeeze_problem_case_input_squeeze_axes_0_2_f64 ✅
lax_squeeze_problem_case_input_squeeze_all_dims_explicitly ✅
lax_squeeze_problem_case_input_squeeze_all_dims_explicitly_f64 ✅ |
0.2.0 |
| lax.stop_gradient |
Identity |
stop_gradient ✅
stop_gradient_f64 ✅
stop_gradient_basic ✅
stop_gradient_basic_f64 ✅ |
0.2.0 |
| lax.sub |
Sub |
sub_test1 ✅
sub_test1_f64 ✅
sub_test2 ✅
sub_test2_f64 ✅
sub_const ✅
sub_const_f64 ✅
sub_complex64 ✅
sub_complex64_f64 ✅ |
0.1.0 |
| lax.tanh |
Tanh |
tanh ✅
tanh_f64 ✅ |
0.2.0 |
| lax.top_k |
TopK |
top_k_last_axis ✅
top_k_last_axis_f64 ✅
top_k_matrix ✅
top_k_matrix_f64 ✅ |
0.10.2 |
| lax.transpose |
Transpose |
transpose_basic ✅
transpose_basic_f64 ✅
transpose_square_matrix ✅
transpose_square_matrix_f64 ✅
transpose_3d ✅
transpose_3d_f64 ✅
transpose_4d ✅
transpose_4d_f64 ✅
transpose_reverse ✅
transpose_reverse_f64 ✅
transpose_no_axes ✅
transpose_no_axes_f64 ✅
transpose_nhwc_to_nchw ✅
transpose_nhwc_to_nchw_f64 ✅ |
0.2.0 |
| lax.while_loop |
Loop |
while_scalar_counter ✅
while_scalar_counter_f64 ✅
while_tuple_state ✅
while_tuple_state_f64 ✅
while_loop_counter ✅
while_loop_counter_f64 ✅
while_loop_vector ✅
while_loop_vector_f64 ✅
while_loop_f64 ✅
while_loop_multi_state_f32 ✅
while_loop_multi_state_f64 ✅
while_loop_with_closure ✅
while_loop_with_closure_f64 ✅
while_loop_basic ✅
while_loop_two_state ✅
while_loop_captured_tracer ✅
while_loop_with_scalar_state ✅
while_loop_renamed_passthrough ✅
while_loop_closure_topo ✅
while_loop_mixed_rank ✅
while_loop_tracer_passthrough ✅
while_loop_no_loop_output_reused_as_input ✅
while_loop_4d_and_scalar_state ✅
while_loop_4d_and_scalar_state_f64 ✅
while_loop_cnn_scalar_state_bug ✅
while_loop_cnn_scalar_state_bug_f64 ✅
while_loop_nnx_repro ✅
while_loop_nnx_repro_f64 ✅ |
0.5.1 |
| linen.activation |
➖ |
activation_glu_basic ✅
activation_glu_basic_f64 ✅
activation_hard_sigmoid_basic ✅
activation_hard_sigmoid_basic_f64 ✅
activation_hard_silu_basic ✅
activation_hard_silu_basic_f64 ✅
activation_hard_swish_basic ✅
activation_hard_swish_basic_f64 ✅
activation_hard_tanh_basic ✅
activation_hard_tanh_basic_f64 ✅
activation_log_sigmoid_basic ✅
activation_log_sigmoid_basic_f64 ✅
activation_log_softmax_basic ✅
activation_log_softmax_basic_f64 ✅
activation_relu6_basic ✅
activation_relu6_basic_f64 ✅
activation_silu_basic ✅
activation_silu_basic_f64 ✅
activation_swish_basic ✅
activation_swish_basic_f64 ✅
activation_tanh_basic ✅
activation_tanh_basic_f64 ✅
activation_normalize_basic ✅
activation_normalize_basic_f64 ✅
activation_one_hot_basic ✅
activation_one_hot_basic_f64 ✅ |
0.11.0 |
| linen.avg_pool |
AveragePool Transpose |
avg_pool_dynamic ✅
avg_pool ✅
avg_pool_same_padding_dynamic ✅
avg_pool_same_padding ✅
avg_pool_default_padding_dynamic ✅
avg_pool_default_padding ✅
avg_pool_stride1_dynamic ✅
avg_pool_stride1 ✅
avg_pool_win3x3_stride2_dynamic ✅
avg_pool_win3x3_stride2 ✅
avg_pool_stride_none_dynamic ✅
avg_pool_stride_none ✅
avg_pool_count_include_pad_false_dynamic ✅
avg_pool_count_include_pad_false ✅ |
0.11.0 |
| linen.batch_norm |
BatchNormalization |
batch_norm_no_bias_no_scale_dynamic ✅
batch_norm_no_bias_no_scale ✅
batch_norm_bias_no_scale_dynamic ✅
batch_norm_bias_no_scale ✅
batch_norm_no_bias_scale_dynamic ✅
batch_norm_no_bias_scale ✅
batch_norm_bias_scale_dynamic ✅
batch_norm_bias_scale ✅
batch_norm_3d_dynamic ✅
batch_norm_3d ✅
batch_norm_4d_dynamic ✅
batch_norm_4d ✅
batch_norm_4d_no_bias_no_scale_dynamic ✅
batch_norm_4d_no_bias_no_scale ✅ |
0.11.0 |
| linen.bidirectional |
Concat Loop |
bidirectional_basic_dynamic ✅
bidirectional_basic_dynamic_f64 ✅
bidirectional_basic ✅
bidirectional_basic_f64 ✅ |
0.11.0 |
| linen.conv |
CastLike Conv Reshape Transpose |
conv_basic_dynamic ✅
conv_basic ✅
conv_no_bias ✅
conv_stride ✅ |
0.11.0 |
| linen.conv_local |
Conv Gemm Reshape Transpose |
conv_local_valid ✅
conv_local_same ✅ |
0.11.0 |
| linen.conv_lstm_cell |
Add Conv Gemm Mul Sigmoid Tanh |
conv_lstm_cell_basic_dynamic ✅
conv_lstm_cell_basic ✅ |
0.11.0 |
| linen.conv_transpose |
ConvTranspose |
conv_transpose_basic_dynamic ✅
conv_transpose_basic ✅
conv_transpose_valid_stride ✅ |
0.11.0 |
| linen.dense |
Gemm |
dense_basic_dynamic ✅
dense_basic_dynamic_f64 ✅
dense_basic ✅
dense_basic_f64 ✅
dense_high_rank_dynamic_dynamic ✅
dense_high_rank_dynamic_dynamic_f64 ✅
dense_high_rank_static ✅
dense_high_rank_static_f64 ✅
dense_high_rank_no_bias ✅
dense_high_rank_no_bias_f64 ✅
dense_no_bias_dynamic ✅
dense_no_bias_dynamic_f64 ✅
dense_no_bias ✅
dense_no_bias_f64 ✅ |
0.11.0 |
| linen.dense_general |
CastLike Concat Gemm Reshape Shape Slice |
dense_general_basic_dynamic ✅
dense_general_basic_dynamic_f64 ✅
dense_general_basic ✅
dense_general_basic_f64 ✅
dense_general_multi_out ✅
dense_general_multi_out_f64 ✅
dense_general_contract_last_two ✅
dense_general_contract_last_two_f64 ✅
dense_general_dynamic_batch_dynamic ✅
dense_general_dynamic_batch_dynamic_f64 ✅
dense_general_no_bias ✅
dense_general_no_bias_f64 ✅ |
0.11.0 |
| linen.dot_product_attention |
MatMul Mul Softmax Transpose Where |
dot_product_attention_basic ✅
dot_product_attention_basic_f64 ✅ |
0.11.0 |
| linen.dot_product_attention_weights |
Add Div MatMul Mul ReduceSum Softmax Where |
dot_product_attention_weights_basic ✅
dot_product_attention_weights_basic_f64 ✅ |
0.11.0 |
| linen.dropout |
Dropout |
dropout_init_params_dynamic ✅
dropout_init_params_dynamic_f64 ✅
dropout_init_params ✅
dropout_init_params_f64 ✅
dropout_call_params_dynamic ✅
dropout_call_params_dynamic_f64 ✅
dropout_call_params ✅
dropout_call_params_f64 ✅ |
0.11.0 |
| linen.einsum |
Add Einsum Reshape |
einsum_with_bias ✅
einsum_with_bias_f64 ✅
einsum_no_bias ✅
einsum_no_bias_f64 ✅ |
0.11.0 |
| linen.embed |
Gather |
token_embedding_dynamic ✅
token_embedding_dynamic_f64 ✅
token_embedding ✅
token_embedding_f64 ✅
positional_embedding_dynamic ✅
positional_embedding_dynamic_f64 ✅
positional_embedding ✅
positional_embedding_f64 ✅ |
0.11.0 |
| linen.group_norm |
GroupNormalization |
group_norm_rank4 ✅
group_norm_rank2_dynamic ✅
group_norm_rank2 ✅
group_norm_no_bias_no_scale_dynamic ✅
group_norm_no_bias_no_scale ✅ |
0.11.0 |
| linen.gru_cell |
Add Gemm Mul Sigmoid Tanh |
gru_cell_basic_dynamic ✅
gru_cell_basic_dynamic_f64 ✅
gru_cell_basic ✅
gru_cell_basic_f64 ✅ |
0.11.0 |
| linen.instance_norm |
GroupNormalization |
instance_norm_rank4_dynamic ✅
instance_norm_rank4 ✅
instance_norm_rank2_dynamic ✅
instance_norm_rank2 ✅ |
0.11.0 |
| linen.layer_norm |
LayerNormalization |
layer_norm_dynamic ✅
layer_norm ✅
layer_norm_no_bias_no_scale_dynamic ✅
layer_norm_no_bias_no_scale ✅
layer_norm_multiaxis_dynamic ✅
layer_norm_multiaxis ✅
layer_norm_default_epsilon_dynamic ✅
layer_norm_default_epsilon ✅ |
0.11.0 |
| linen.lstm_cell |
Add Gemm Mul Sigmoid Tanh |
lstm_cell_basic_dynamic ✅
lstm_cell_basic_dynamic_f64 ✅
lstm_cell_basic ✅
lstm_cell_basic_f64 ✅ |
0.11.0 |
| linen.make_attention_mask |
Cast Mul |
make_attention_mask_basic ✅
make_attention_mask_basic_f64 ✅ |
0.11.0 |
| linen.make_causal_mask |
Cast Less |
make_causal_mask_basic ✅
make_causal_mask_basic_f64 ✅ |
0.11.0 |
| linen.max_pool |
MaxPool |
max_pool ✅
max_pool_same_padding ✅
max_pool_basic ✅
max_pool_same_dynamic ✅
max_pool_same ✅ |
0.11.0 |
| linen.mgu_cell |
Add Gemm Mul Sigmoid Tanh |
mgu_cell_basic_dynamic ✅
mgu_cell_basic_dynamic_f64 ✅
mgu_cell_basic ✅
mgu_cell_basic_f64 ✅ |
0.11.0 |
| linen.min_pool |
MaxPool Neg |
min_pool ✅
min_pool_same_padding ✅
min_pool_basic ✅
min_pool_same_dynamic ✅
min_pool_same ✅ |
0.11.0 |
| linen.multi_head_attention |
Add Gemm MatMul Mul Reshape Softmax Transpose |
multi_head_attention_basic_dynamic ✅
multi_head_attention_basic_dynamic_f64 ✅
multi_head_attention_basic ✅
multi_head_attention_basic_f64 ✅
multi_head_attention_no_bias_dynamic ✅
multi_head_attention_no_bias_dynamic_f64 ✅
multi_head_attention_no_bias ✅
multi_head_attention_no_bias_f64 ✅ |
0.11.0 |
| linen.multi_head_dot_product_attention |
Add Gemm MatMul Mul Reshape Softmax Transpose |
multi_head_dot_product_attention_basic_dynamic ✅
multi_head_dot_product_attention_basic_dynamic_f64 ✅
multi_head_dot_product_attention_basic ✅
multi_head_dot_product_attention_basic_f64 ✅
multi_head_dot_product_attention_no_bias_dynamic ✅
multi_head_dot_product_attention_no_bias_dynamic_f64 ✅
multi_head_dot_product_attention_no_bias ✅
multi_head_dot_product_attention_no_bias_f64 ✅ |
0.11.0 |
| linen.optimized_lstm_cell |
Add Gemm Mul Sigmoid Tanh |
optimized_lstm_cell_basic_dynamic ✅
optimized_lstm_cell_basic_dynamic_f64 ✅
optimized_lstm_cell_basic ✅
optimized_lstm_cell_basic_f64 ✅ |
0.11.0 |
| linen.pool |
AveragePool MaxPool Mul Neg Transpose |
pool_max_basic_dynamic ✅
pool_max_basic ✅
pool_min_basic_dynamic ✅
pool_min_basic ✅
pool_sum_basic_dynamic ✅
pool_sum_basic ✅ |
0.11.0 |
| linen.rms_norm |
RMSNormalization |
rms_norm_basic ✅
rms_norm_use_scale_false ✅
rms_norm_4d_dynamic_dynamic ✅
rms_norm_4d_dynamic ✅ |
0.11.0 |
| linen.rnn |
Loop |
rnn_basic_dynamic ✅
rnn_basic_dynamic_f64 ✅
rnn_basic ✅
rnn_basic_f64 ✅ |
0.11.0 |
| linen.self_attention |
Add Gemm MatMul Mul Reshape Softmax Transpose |
self_attention_basic_dynamic ✅
self_attention_basic_dynamic_f64 ✅
self_attention_basic ✅
self_attention_basic_f64 ✅
self_attention_no_bias_dynamic ✅
self_attention_no_bias_dynamic_f64 ✅
self_attention_no_bias ✅
self_attention_no_bias_f64 ✅ |
0.11.0 |
| linen.simple_cell |
Add Gemm Mul Sigmoid Tanh |
simple_cell_basic_dynamic ✅
simple_cell_basic_dynamic_f64 ✅
simple_cell_basic ✅
simple_cell_basic_f64 ✅ |
0.11.0 |
| linen.spectral_norm |
Div MatMul |
spectral_norm_dense_dynamic ✅
spectral_norm_dense ✅ |
0.11.0 |
| linen.weight_norm |
Div Mul |
weight_norm_dense_dynamic ✅
weight_norm_dense ✅ |
0.11.0 |
| nn.celu |
Celu |
jaxnn_celu ✅
jaxnn_celu_1 ✅
jaxnn_celu_alpha_default ✅
jaxnn_celu_alpha_custom_dynamic ✅
jaxnn_celu_alpha_custom ✅ |
0.7.1 |
| nn.dot_product_attention |
Add MatMul Mul Softmax Transpose Where |
dpa_basic ✅
dpa_basic_f64 ✅
dpa_positional_bias_mask ✅
dpa_positional_bias_mask_f64 ✅
dpa_diff_heads_embed ✅
dpa_diff_heads_embed_f64 ✅
dpa_batch4_seq16 ✅
dpa_batch4_seq16_f64 ✅
dpa_float64 ✅
dpa_heads1_embed4 ✅
dpa_heads1_embed4_f64 ✅
dpa_heads8_embed8 ✅
dpa_heads8_embed8_f64 ✅
dpa_batch1_seq2 ✅
dpa_batch1_seq2_f64 ✅
dpa_batch8_seq4 ✅
dpa_batch8_seq4_f64 ✅
dpa_axis1 ✅
dpa_axis1_f64 ✅
dpa_with_tensor_mask ✅
dpa_with_tensor_mask_f64 ✅
dpa_tiny_mask_all_valid ✅
dpa_tiny_mask_all_valid_f64 ✅
dpa_tiny_mask_mixed ✅
dpa_tiny_mask_mixed_f64 ✅
dpa_one_false ✅
dpa_one_false_f64 ✅
dpa_mostly_false ✅
dpa_mostly_false_f64 ✅
dpa_with_causal_mask ✅
dpa_with_causal_mask_f64 ✅
dpa_with_padding_mask ✅
dpa_with_padding_mask_f64 ✅
dpa_with_local_window_mask ✅
dpa_with_local_window_mask_f64 ✅
dpa_mask_none ✅ |
0.8.0 |
| nn.elu |
Elu |
jaxnn_elu ✅
jaxnn_elu_1 ✅
jaxnn_elu_default ✅
jaxnn_elu_custom_alpha_dynamic ✅
jaxnn_elu_custom_alpha ✅ |
0.7.1 |
| nn.gelu |
Gelu |
jaxnn_gelu ✅
jaxnn_gelu_1 ✅
jaxnn_gelu_approx ✅
jaxnn_gelu_exact ✅
jaxnn_gelu_tanh_dynamic ✅
jaxnn_gelu_tanh ✅ |
0.7.1 |
| nn.identity |
Identity |
jaxnn_identity ✅
jaxnn_identity_f64 ✅
jaxnn_identity_1 ✅
jaxnn_identity_1_f64 ✅
jaxnn_identity_basic ✅
jaxnn_identity_basic_f64 ✅
jaxnn_identity_dynamic_dynamic ✅
jaxnn_identity_dynamic_dynamic_f64 ✅
jaxnn_identity_dynamic ✅
jaxnn_identity_dynamic_f64 ✅ |
0.7.1 |
| nn.leaky_relu |
LeakyRelu |
jaxnn_leaky_relu ✅
jaxnn_leaky_relu_1 ✅
jaxnn_leaky_relu_default_dynamic ✅
jaxnn_leaky_relu_default ✅
jaxnn_leaky_relu_custom ✅ |
0.7.1 |
| nn.mish |
Mish |
jaxnn_mish ✅
jaxnn_mish_1 ✅
jaxnn_mish_basic ✅ |
0.7.1 |
| nn.relu |
Relu |
jaxnn_relu ✅
jaxnn_relu_f64 ✅
jaxnn_relu_1 ✅
jaxnn_relu_1_f64 ✅
jaxnn_relu_basic ✅
jaxnn_relu_basic_f64 ✅
jaxnn_relu_dynamic_dynamic ✅
jaxnn_relu_dynamic_dynamic_f64 ✅
jaxnn_relu_dynamic ✅
jaxnn_relu_dynamic_f64 ✅ |
0.7.1 |
| nn.selu |
Selu |
jaxnn_selu ✅
jaxnn_selu_1 ✅
jaxnn_selu_basic_dynamic ✅
jaxnn_selu_basic ✅ |
0.7.1 |
| nn.sigmoid |
Sigmoid |
jaxnn_sigmoid ✅
jaxnn_sigmoid_f64 ✅
jaxnn_sigmoid_1 ✅
jaxnn_sigmoid_1_f64 ✅ |
0.7.1 |
| nn.soft_sign |
Softsign |
jaxnn_soft_sign ✅
jaxnn_soft_sign_1 ✅
jaxnn_softsign_basic ✅ |
0.7.1 |
| nn.softmax |
Softmax |
softmax ✅
softmax_f64 ✅
softmax_2d ✅
softmax_2d_f64 ✅
softmax_3d ✅
softmax_3d_f64 ✅
softmax_mask_where ✅
softmax_mask_where_f64 ✅ |
0.7.1 |
| nn.softplus |
Softplus |
jaxnn_softplus ✅
jaxnn_softplus_1 ✅
jaxnn_softplus_basic ✅ |
0.7.1 |
| nn.truncated_normal |
➖ |
initializer ✅
random_truncated_normal_positional ✅
flax_dense_like_init ✅ |
0.7.1 |
| nnx.avg_pool |
AveragePool Transpose |
avg_pool_dynamic ✅
avg_pool ✅
avg_pool_same_padding_dynamic ✅
avg_pool_same_padding ✅
avg_pool_default_padding_dynamic ✅
avg_pool_default_padding ✅
avg_pool_stride1_dynamic ✅
avg_pool_stride1 ✅
avg_pool_win3x3_stride2_dynamic ✅
avg_pool_win3x3_stride2 ✅
avg_pool_stride_none_dynamic ✅
avg_pool_stride_none ✅
avg_pool_count_include_pad_false_dynamic ✅
avg_pool_count_include_pad_false ✅ |
0.1.0 |
| nnx.batch_norm |
BatchNormalization |
batch_norm_no_bias_no_scale_dynamic ✅
batch_norm_no_bias_no_scale ✅
batch_norm_bias_no_scale_dynamic ✅
batch_norm_bias_no_scale ✅
batch_norm_no_bias_scale_dynamic ✅
batch_norm_no_bias_scale ✅
batch_norm_bias_scale_dynamic ✅
batch_norm_bias_scale ✅
batch_norm_3d_dynamic ✅
batch_norm_3d ✅
batch_norm_4d_dynamic ✅
batch_norm_4d ✅
batch_norm_4d_no_bias_no_scale_dynamic ✅
batch_norm_4d_no_bias_no_scale ✅ |
0.1.0 |
| nnx.conv |
CastLike Conv Reshape Transpose |
conv_basic_bias_dynamic ✅
conv_basic_bias ✅
conv_basic_bias_2 ✅
conv_basic_bias_3 ✅
conv_stride2_bias ✅
conv_no_bias_dynamic ✅
conv_no_bias ✅
conv_valid_padding ✅
conv_stride1 ✅
conv_stride2 ✅
conv_different_kernel ✅
conv_float64 ✅
conv_single_batch ✅
conv_large_batch ✅
conv_1d ✅
conv_1d_more_1d_inputs ✅
conv_1d_more_2d_inputs ✅
conv_1d_large_kernel ✅
conv_1d_dilation ✅
conv_1d_stride_dilation ✅
conv_2d_asymmetric_kernel ✅
conv_2d_asymmetric_stride ✅
conv_2d_asymmetric_dilation ✅
conv_2d_large_dilation ✅
conv_2d_large_stride ✅
conv_2d_mixed_params ✅
conv_2d_same_padding_mixed_dilation ✅
conv_3d_basic ✅
conv_3d_stride ✅
conv_3d_asymmetric ✅
conv_3d_dilation ✅
conv_2d_small_input ✅
conv_2d_many_channels ✅
conv_1d_wide_input ✅
conv_2d_kernel_1x1 ✅
conv_1d_kernel_1 ✅
conv_2d_group_conv ✅
conv_1d_group_conv_more_dims ✅
conv_2d_depthwise ✅
conv_1d_complex_on_4d ✅
conv_2d_complex_on_5d ✅
conv_2d_asymmetric_on_5d ✅
conv_1d_high_dilation_on_3d ✅
conv_1d_large_kernel_on_4d ✅
conv_2d_group_stride_dilation ✅
conv_1d_group_on_higher_dim ✅
conv_1d_same_padding_on_3d ✅
conv_3d_group_complex ✅
conv_1d_unit_group_on_multi_dim ✅ |
0.1.0 |
| nnx.dot_product_attention |
Add Attention MatMul Mul Softmax Transpose Where |
dpa_basic ✅
dpa_basic_f64 ✅
dpa_with_tensor_mask ✅
dpa_with_bias ✅
dpa_with_causal_mask ✅
dpa_with_causal_mask_f64 ✅
dpa_with_mask_and_bias ✅ |
0.1.0 |
| nnx.dropout |
Dropout |
dropout_init_params_dynamic ✅
dropout_init_params_dynamic_f64 ✅
dropout_init_params ✅
dropout_init_params_f64 ✅
dropout_call_params_dynamic ✅
dropout_call_params_dynamic_f64 ✅
dropout_call_params ✅
dropout_call_params_f64 ✅ |
0.1.0 |
| nnx.einsum |
Add Einsum |
einsum_module_with_bias ✅
einsum_module_with_bias_f64 ✅
einsum_module_no_bias ✅
einsum_module_no_bias_f64 ✅ |
0.4.2 |
| nnx.elu |
Elu |
elu ✅
elu_default_dynamic ✅
elu_default ✅
elu_alpha ✅ |
0.2.0 |
| nnx.embed |
Gather |
token_embedding_dynamic ✅
token_embedding_dynamic_f64 ✅
token_embedding ✅
token_embedding_f64 ✅
positional_embedding_dynamic ✅
positional_embedding_dynamic_f64 ✅
positional_embedding ✅
positional_embedding_f64 ✅ |
0.7.0 |
| nnx.gelu |
Gelu |
gelu ✅
gelu_1 ✅
gelu_2 ✅
gelu_2_f64 ✅
gelu_3_dynamic ✅
gelu_3_dynamic_f64 ✅
gelu_3 ✅
gelu_3_f64 ✅
gelu_4 ✅
gelu_4_f64 ✅
gelu_5_dynamic ✅
gelu_5_dynamic_f64 ✅
gelu_5 ✅
gelu_5_f64 ✅ |
0.1.0 |
| nnx.group_norm |
GroupNormalization |
group_norm ✅
group_norm_rank2_dynamic ✅
group_norm_rank2 ✅
group_norm_rank4 ✅
group_norm_no_bias_dynamic ✅
group_norm_no_bias ✅
group_norm_no_bias_no_scale_dynamic ✅
group_norm_no_bias_no_scale ✅
group_norm_bias_no_scale_dynamic ✅
group_norm_bias_no_scale ✅
group_norm_no_scale_dynamic ✅
group_norm_no_scale ✅
group_norm_no_bias_scale_dynamic ✅
group_norm_no_bias_scale ✅
group_norm_bias_scale_dynamic ✅
group_norm_bias_scale ✅ |
0.2.0 |
| nnx.layer_norm |
LayerNormalization |
layer_norm_dynamic ✅
layer_norm ✅
layer_norm_no_bias_no_scale_dynamic ✅
layer_norm_no_bias_no_scale ✅
layer_norm_bias_no_scale_dynamic ✅
layer_norm_bias_no_scale ✅
layer_norm_no_bias_scale_dynamic ✅
layer_norm_no_bias_scale ✅
layer_norm_bias_scale_dynamic ✅
layer_norm_bias_scale ✅
layer_norm_multiaxis_dynamic ✅
layer_norm_multiaxis ✅
layer_norm_symbolic_batch_dynamic ✅
layer_norm_symbolic_batch ✅
layer_norm_symbolic_batch_seq10_feat3_dynamic ✅
layer_norm_symbolic_batch_seq10_feat3 ✅
layer_norm_symbolic_batch_seq10_feat3_2_dynamic ✅
layer_norm_symbolic_batch_seq10_feat3_2 ✅
layer_norm_negative_axis_no_div_dynamic ✅
layer_norm_negative_axis_no_div ✅ |
0.1.0 |
| nnx.leaky_relu |
LeakyRelu |
leaky_relu ✅
leaky_relu_default_dynamic ✅
leaky_relu_default ✅
leaky_relu_custom ✅ |
0.2.0 |
| nnx.linear |
CastLike Concat Gemm Reshape Shape Slice |
linear_symbolic_batch_dynamic ✅
linear_symbolic_batch_dynamic_f64 ✅
linear_symbolic_batch ✅
linear_symbolic_batch_f64 ✅
linear_high_rank_dynamic ✅
linear_high_rank_dynamic_f64 ✅
linear_high_rank_static ✅
linear_high_rank_static_f64 ✅
linear_no_bias_dynamic ✅
linear_no_bias_dynamic_f64 ✅
linear_no_bias ✅
linear_no_bias_f64 ✅
linear_high_rank_no_bias_dynamic ✅
linear_high_rank_no_bias_dynamic_f64 ✅
linear_high_rank_no_bias ✅
linear_high_rank_no_bias_f64 ✅
linear_merge_symbolic_dim_dynamic ✅ |
0.1.0 |
| nnx.linear_general |
CastLike Concat Gemm Reshape Shape Slice |
linear_general_merge_symbolic_dim_dynamic ✅
linear_general_dynamic ✅
linear_general ✅
linear_general_2 ✅
linear_general_3 ✅
linear_general_4 ✅
linear_general_abstract_eval_axes ✅
linear_general_abstract_eval_axes_pair ✅
dynamic_batch_and_feature_dims_dynamic ✅ |
0.1.0 |
| nnx.log_softmax |
LogSoftmax |
log_softmax ✅
log_softmax_f64 ✅
log_softmax_default_axis_dynamic ✅
log_softmax_default_axis_dynamic_f64 ✅
log_softmax_default_axis ✅
log_softmax_default_axis_f64 ✅
log_softmax_axis0 ✅
log_softmax_axis0_f64 ✅ |
0.2.0 |
| nnx.max_pool |
MaxPool |
max_pool ✅
max_pool_same_padding ✅
max_pool_basic ✅
max_pool_same_dynamic ✅
max_pool_same ✅ |
0.2.0 |
| nnx.relu |
Relu |
relu_1d ✅
relu_1d_f64 ✅
relu_4d_dynamic ✅
relu_4d_dynamic_f64 ✅
relu_4d ✅
relu_4d_f64 ✅ |
0.2.0 |
| nnx.rms_norm |
RMSNormalization |
rms_norm_basic ✅
rms_norm_use_scale_false ✅
rms_norm_4d_dynamic_dynamic ✅
rms_norm_4d_dynamic ✅
rms_norm_4d_dynamic_no_scale_dynamic ✅
rms_norm_4d_dynamic_no_scale ✅ |
0.2.0 |
| nnx.sigmoid |
Sigmoid |
sigmoid_dynamic ✅
sigmoid_dynamic_f64 ✅
sigmoid ✅
sigmoid_f64 ✅ |
0.2.0 |
| nnx.softmax |
Softmax |
softmax_dynamic ✅
softmax_dynamic_f64 ✅
softmax ✅
softmax_f64 ✅ |
0.1.0 |
| nnx.softplus |
Softplus |
softplus ✅ |
0.1.0 |
| nnx.tanh |
Tanh |
tanh ✅
tanh_f64 ✅ |
0.1.0 |
| random.random_bits |
Cast Floor RandomUniform |
random_bits_uint32 ✅
random_bits_uint32_f64 ✅ |
0.7.2 |
| random.random_fold_in |
Identity |
random_fold_in_passthrough ✅
random_fold_in_passthrough_f64 ✅ |
0.2.0 |
| random.random_seed |
Cast Concat |
random_seed_basic ✅
random_seed_basic_f64 ✅ |
0.2.0 |
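
Each row above pairs a JAX callable or primitive with the ONNX operators its plugin emits, the testcases covering it, and the version that introduced support. As a quick orientation, the sketch below (an illustration, not taken from the test suite above; the function `f`, its shapes, and the symbolic dimension name `B` are made up for this example) shows how the listed primitives surface in a traced jaxpr — the representation a converter dispatches on — and how the `*_dynamic` variants relate to symbolic batch dimensions.

```python
# Sketch only: shows which JAX primitives a function lowers to. A converter
# walks this jaxpr and maps each equation to the ONNX ops listed in the
# table's second column (e.g. lax.dot_general -> Einsum/MatMul/Gemm,
# lax.tanh -> Tanh, lax.reduce_sum -> ReduceSum).
import jax
import jax.numpy as jnp

def f(x, w):
    # x @ w lowers to dot_general, jnp.tanh to tanh, jnp.sum to reduce_sum
    return jnp.sum(jnp.tanh(x @ w), axis=-1)

print(jax.make_jaxpr(f)(jnp.ones((2, 3)), jnp.ones((3, 4))))
# { lambda ; a:f32[2,3] b:f32[3,4]. let
#     c:f32[2,4] = dot_general[...] a b
#     d:f32[2,4] = tanh c
#     e:f32[2] = reduce_sum[axes=(1,)] d
#   in (e,) }

# The *_dynamic testcases exercise symbolic (batch) dimensions. One way to
# preview that shape-polymorphic tracing, assuming a recent JAX where
# jax.export is available:
from jax import export

b, = export.symbolic_shape("B")  # symbolic batch dimension
out = jax.eval_shape(
    f,
    jax.ShapeDtypeStruct((b, 3), jnp.float32),
    jax.ShapeDtypeStruct((3, 4), jnp.float32),
)
print(out.shape)  # (B,) -- the symbolic dim propagates to the output
```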