jax2onnx

Coverage Tables

Supported JAX/ONNX Components

JAX Component ONNX Components Testcases Since (jax2onnx version)
core.custom_jvp_generic custom_jvp_square
custom_jvp_square_f64
v0.7.1
core.dim_as_value Cast
Gather
Reshape
Shape
dim_as_value_dynamic
dim_as_value_dynamic_f64
dim_as_value
dim_as_value_f64
v0.5.0
core.jit_inline jit_identity
jit_identity_f64
v0.9.0
eqx.conv Conv eqx_conv2d_nchw
eqx_conv2d_batched_nchw
v0.10.0
eqx.dropout Dropout
Not
eqx_dropout_inference_mode
eqx_dropout_inference_mode_f64
eqx_dropout_training_mode
eqx_dropout_training_mode_f64
eqx_dropout_dynamic_inference
eqx_dropout_dynamic_inference_f64
eqx_dropout_batched_inference_dynamic
eqx_dropout_batched_inference_dynamic_f64
eqx_dropout_batched_inference
eqx_dropout_batched_inference_f64
v0.8.0
eqx.identity Identity eqx_identity_static
eqx_identity_static_f64
eqx_identity_symbolic_batch_dynamic
eqx_identity_symbolic_batch_dynamic_f64
eqx_identity_symbolic_batch
eqx_identity_symbolic_batch_f64
v0.8.0
eqx.layer_norm LayerNormalization layer_norm
layer_norm_f64
layer_norm_multiaxis
layer_norm_multiaxis_f64
batched_layer_norm_dynamic
batched_layer_norm_dynamic_f64
batched_layer_norm
batched_layer_norm_f64
layer_norm_no_bias_no_scale
layer_norm_no_bias_no_scale_f64
v0.8.0
eqx.linear Gemm
Reshape
eqx_linear_symbolic_batch_dynamic
eqx_linear_symbolic_batch_dynamic_f64
eqx_linear_symbolic_batch
eqx_linear_symbolic_batch_f64
eqx_linear_no_bias_symbolic_batch_dynamic
eqx_linear_no_bias_symbolic_batch_dynamic_f64
eqx_linear_no_bias_symbolic_batch
eqx_linear_no_bias_symbolic_batch_f64
eqx_linear_no_bias_vector
eqx_linear_no_bias_vector_f64
eqx_linear_high_rank
eqx_linear_high_rank_f64
eqx_linear_vector
eqx_linear_vector_f64
v0.8.0
eqx.multihead_attention Gemm
MatMul
Softmax
eqx_multihead_attention
eqx_multihead_attention_core_dynamic
eqx_multihead_attention_core_static
v0.10.0
eqx.rotary_positional_embedding Add
Concat
Multiply
eqx_rotary_positional_embedding
eqx_rotary_positional_embedding_heads
v0.10.0
jax_image.resize Resize resize_linear
resize_nearest
v0.10.0
jnp.add Add add
add_f64
jnp_add_vector
jnp_add_vector_f64
jnp_add_broadcast
jnp_add_broadcast_f64
v0.8.0
jnp.arange Range arange_data_dependent_indices
arange_stop_only_concrete_input_val
arange_stop_only_concrete_input_val_f64
arange_start_stop_concrete_input_val
arange_start_stop_concrete_input_val_f64
arange_start_stop_step_concrete_input_val
arange_start_stop_step_concrete_input_val_f64
arange_float_concrete_input_val
arange_float_concrete_input_val_f64
arange_static_stop_only_int
arange_static_stop_only_int_f64
arange_static_stop_only_float
arange_static_stop_only_float_f64
arange_static_start_stop_int
arange_static_start_stop_int_f64
arange_static_start_stop_step_int
arange_static_start_stop_step_int_f64
arange_static_empty_result_pos_step
arange_static_empty_result_pos_step_f64
arange_static_empty_result_neg_step
arange_static_empty_result_neg_step_f64
arange_static_negative_step
arange_static_negative_step_f64
arange_static_float_step_explicit_dtype
arange_static_float_step_explicit_dtype_f64
arange_static_float_step_inferred_dtype
arange_static_float_step_inferred_dtype_f64
arange_static_stop_zero
arange_static_stop_zero_f64
arange_static_start_equals_stop
arange_static_start_equals_stop_f64
arange_static_large_numbers_int
arange_static_large_numbers_int_f64
v0.5.2
jnp.clip Max
Min
clip_i32_scalar_bounds
clip_i32_scalar_bounds_f64
clip_f32_scalar_bounds_no_upcast_f64_mode
clip_only_upper
clip_only_upper_f64
clip_only_lower
clip_only_lower_f64
clip_broadcast_bounds
v0.8.0
jnp.concatenate Concat concatenate_basic
concatenate_basic_f64
concatenate_mixed_dtypes
concatenate_mixed_dtypes_f64
concatenate_with_explicit_dtype
concatenate_with_explicit_dtype_casts_inputs
concatenate_abstract_middle_dim_dynamic
concatenate_abstract_middle_dim_dynamic_f64
concatenate_abstract_middle_dim
concatenate_abstract_middle_dim_f64
concatenate_tile_and_symbolic_dynamic
concatenate_tile_and_symbolic_dynamic_f64
concatenate_tile_and_symbolic
concatenate_tile_and_symbolic_f64
v0.8.0
jnp.cumsum CumSum jnp_cumsum_axis1
jnp_cumsum_axis1_f64
jnp_cumsum_reverse_dtype
cumsum_axis2_i32
cumsum_axis2_i32_f64
cumsum_axis2_reverse_i32
cumsum_axis2_reverse_i32_f64
v0.8.0
jnp.einsum Einsum einsum_vector_dot
einsum_vector_dot_f64
einsum_matrix_vector
einsum_matrix_vector_f64
einsum_matrix_matrix_dynamic
einsum_matrix_matrix_dynamic_f64
einsum_matrix_matrix
einsum_matrix_matrix_f64
einsum_transpose
einsum_transpose_f64
einsum_batch_transpose_dynamic
einsum_batch_transpose_dynamic_f64
einsum_batch_transpose
einsum_batch_transpose_f64
einsum_diag
einsum_diag_f64
einsum_sum_reduce
einsum_sum_reduce_f64
einsum_multi_operand
einsum_multi_operand_f64
einsum_attention_logits_orig_dynamic
einsum_attention_logits_orig_dynamic_f64
einsum_attention_logits_orig
einsum_attention_logits_orig_f64
einsum_attention_output_orig_dynamic
einsum_attention_output_orig_dynamic_f64
einsum_attention_output_orig
einsum_attention_output_orig_f64
einsum_attention_logits_batched_dynamic
einsum_attention_logits_batched_dynamic_f64
einsum_attention_logits_batched
einsum_attention_logits_batched_f64
einsum_attention_output_batched_dynamic
einsum_attention_output_batched_dynamic_f64
einsum_attention_output_batched
einsum_attention_output_batched_f64
einsum_ellipsis_rank_mismatch
einsum_ellipsis_rank_mismatch_f64
einsum_attention_logits_batched_rank_mismatch
einsum_attention_logits_batched_rank_mismatch_f64
v0.1.0
jnp.linspace Add
Range
linspace_static_basic
linspace_static_basic_f64
linspace_static_endpoint_false
linspace_static_endpoint_false_f64
linspace_static_int_inputs_default_dtype
linspace_static_int_inputs_default_dtype_f64
linspace_basic_f32
linspace_basic_f32_f64
linspace_endpoint_false_i32
linspace_endpoint_false_i32_f64
linspace_num_zero
linspace_num_zero_f64
linspace_num_one
linspace_num_one_f64
linspace_static_num_0
linspace_static_num_0_f64
linspace_static_num_1
linspace_static_num_1_f64
v0.5.2
jnp.matmul MatMul matmul_1d
matmul_1d_f64
matmul_1d_2d
matmul_1d_2d_f64
matmul_2d
matmul_2d_f64
matmul_2d_1d
matmul_2d_1d_f64
matmul_3d
matmul_3d_f64
matmul_dynamic_dynamic
matmul_dynamic_dynamic_f64
matmul_dynamic
matmul_dynamic_f64
matmul_dynamic_a_dynamic
matmul_dynamic_a_dynamic_f64
matmul_dynamic_a
matmul_dynamic_a_f64
v0.1.0
jnp.outer Mul outer_vector
outer_vector_f64
outer
outer_f64
v0.10.0
jnp.pow Pow jnp_pow_vector
jnp_pow_vector_f64
pow_jnp_pow
pow_jnp_pow_f64
v0.8.0
jnp.power Pow jnp_power_vector
jnp_power_vector_f64
pow_jnp_power
pow_jnp_power_f64
v0.8.0
jnp.prod ReduceProd basic_prod
basic_prod_f64
prod_with_axis
prod_with_axis_f64
prod_with_keepdims
prod_with_keepdims_f64
jnp_prod_basic
jnp_prod_basic_f64
jnp_prod_axis
jnp_prod_axis_f64
jnp_prod_keepdims
jnp_prod_keepdims_f64
v0.8.0
jnp.reshape Reshape reshape_1
reshape_1_f64
reshape_2
reshape_2_f64
reshape_3
reshape_3_f64
reshape_4_dynamic
reshape_4_dynamic_f64
reshape_4
reshape_4_f64
reshape_to_scalar
reshape_to_scalar_f64
reshape_from_scalar
reshape_from_scalar_f64
reshape_cnn_dynamic
reshape_cnn_dynamic_f64
reshape_cnn
reshape_cnn_f64
reshape_valid_flatten_trailing
reshape_valid_flatten_trailing_f64
reshape_with_target_shape_from_symbolic_dim_computation
reshape_with_target_shape_from_symbolic_dim_computation_f64
reshape_basic
reshape_basic_f64
reshape_infer
reshape_infer_f64
reshape_symbolic_flatten_dynamic
reshape_symbolic_flatten_dynamic_f64
reshape_symbolic_flatten
reshape_symbolic_flatten_f64
v0.1.0
jnp.select Where select_simple
select_simple_f64
select_broadcast
select_broadcast_f64
select_gpt2_attention_mask_dynamic
select_gpt2_attention_mask_dynamic_f64
select_gpt2_attention_mask
select_gpt2_attention_mask_f64
select_basic
select_basic_f64
v0.7.1
jnp.shape Shape shape_basic
shape_basic_f64
shape_dynamic_dynamic
shape_dynamic_dynamic_f64
shape_dynamic
shape_dynamic_f64
v0.4.0
jnp.sort Sort sort_1d
sort_1d_f64
sort_2d_axis0
sort_2d_axis0_f64
sort_basic
sort_basic_f64
v0.5.2
jnp.split Split split_by_sections
split_by_sections_f64
split_by_indices
split_by_indices_f64
split_by_indices_symbolic_dynamic
split_by_indices_symbolic_dynamic_f64
split_by_indices_symbolic
split_by_indices_symbolic_f64
split_sections
split_sections_f64
split_indices_numpy
split_indices_numpy_f64
v0.7.2
jnp.squeeze Squeeze squeeze_single_dim
squeeze_single_dim_f64
squeeze_multiple_dims
squeeze_multiple_dims_f64
squeeze_vit_output
squeeze_vit_output_f64
squeeze_dynamic_batch_dynamic
squeeze_dynamic_batch_dynamic_f64
squeeze_dynamic_batch
squeeze_dynamic_batch_f64
squeeze_all_dims
squeeze_all_dims_f64
squeeze_negative_axis
squeeze_negative_axis_f64
squeeze_negative_axis_tuple
squeeze_negative_axis_tuple_f64
squeeze_dynamic_and_negative_axis_dynamic
squeeze_dynamic_and_negative_axis_dynamic_f64
squeeze_dynamic_and_negative_axis
squeeze_dynamic_and_negative_axis_f64
v0.1.0
jnp.stack Concat
Unsqueeze
stack_axis_0
stack_axis_0_f64
stack_axis_1
stack_axis_1_f64
stack_negative_axis
stack_negative_axis_f64
stack_scalars
stack_scalars_f64
jnp_stack_axis0
jnp_stack_axis0_f64
jnp_stack_axis1
jnp_stack_axis1_f64
jnp_stack_negative_axis
jnp_stack_negative_axis_f64
jnp_stack_scalars
jnp_stack_scalars_f64
v0.8.0
jnp.take Gather take_data_dependent_indices
take_basic_axis1
take_basic_axis1_f64
v0.7.0
jnp.tile Tile tile_repeats
tile_repeats_f64
tile_a
tile_a_f64
tile_b
tile_b_f64
tile_c
tile_c_f64
tile_d
tile_d_f64
tile_dynamic_input_static
tile_dynamic_input_static_f64
tile_dynamic_input_dynamic
tile_dynamic_input_dynamic_f64
tile_dynamic_input
tile_dynamic_input_f64
tile_pad
tile_pad_f64
tile_param_symbolic_dynamic
tile_param_symbolic_dynamic_f64
tile_param_symbolic
tile_param_symbolic_f64
tile_with_symbolic_repeats_static
tile_with_symbolic_repeats_static_f64
tile_with_symbolic_repeats_dynamic
tile_with_symbolic_repeats_dynamic_f64
tile_with_symbolic_repeats
tile_with_symbolic_repeats_f64
jnp_tile_basic
jnp_tile_basic_f64
jnp_tile_scalar_repeats
jnp_tile_scalar_repeats_f64
jnp_tile_pad_rank
jnp_tile_pad_rank_f64
jnp_tile_symbolic_dynamic
jnp_tile_symbolic_dynamic_f64
jnp_tile_symbolic
jnp_tile_symbolic_f64
v0.8.0
jnp.transpose Transpose transpose_basic
transpose_basic_f64
transpose_reverse_default
transpose_reverse_default_f64
transpose_high_dim
transpose_high_dim_f64
transpose_3d
transpose_3d_f64
transpose_4d
transpose_4d_f64
transpose_no_axes
transpose_no_axes_f64
transpose_reverse
transpose_reverse_f64
transpose_square_matrix
transpose_square_matrix_f64
v0.1.0
jnp.unstack Split
Squeeze
unstack_axis_0
unstack_axis_0_f64
unstack_axis_1
unstack_axis_1_f64
unstack_negative_axis
unstack_negative_axis_f64
v0.7.1
jnp.where Where where_simple
where_simple_f64
where_broadcast
where_broadcast_f64
where_gpt_mask_scores_literal_else_dynamic
where_gpt_mask_scores_literal_else_dynamic_f64
where_gpt_mask_scores_literal_else
where_gpt_mask_scores_literal_else_f64
where_multidim_condition_scalar_branches_broadcast
where_multidim_condition_scalar_branches_broadcast_f64
where_A
where_A_f64
where_B
where_B_f64
where_gpt_mask_scores_scalar_else_dynamic
where_gpt_mask_scores_scalar_else_dynamic_f64
where_gpt_mask_scores_scalar_else
where_gpt_mask_scores_scalar_else_f64
where_int_condition_cast
where_int_condition_cast_f64
where_literal_else_pyfloat
where_literal_else_pyfloat_f64
where_jax_int_literals_broadcast_f64_mode
where_dtype_mismatch_f64_vs_i32_promote
jnp_where_basic
jnp_where_basic_f64
jnp_where_broadcast
jnp_where_broadcast_f64
jnp_where_scalar_else
jnp_where_scalar_else_f64
v0.8.0
lax.abs Abs abs
abs_f64
v0.5.0
lax.add Add add
add_f64
add_const
add_const_f64
v0.2.0
lax.add_any Add add_any_via_jvp_on_mul
add_any_via_jvp_on_mul_f64
v0.8.0
lax.and And
BitwiseAnd
and_bool
and_bool_f64
and_int
and_int_f64
v0.6.5
lax.argmax ArgMax argmax_float_axis0
argmax_float_axis0_f64
argmax_float_axis1
argmax_float_axis1_f64
argmax_boolean_input_axis0_specific_values
argmax_boolean_input_axis0_specific_values_f64
argmax_boolean_input_axis1_specific_values
argmax_boolean_input_axis1_specific_values_f64
argmax_boolean_random_input_axis0
argmax_boolean_random_input_axis0_f64
v0.2.0
lax.argmin ArgMin argmin_test1
argmin_test1_f64
argmin_test2
argmin_test2_f64
v0.2.0
lax.bitcast_convert_type Bitcast bitcast_scalar_f32_to_i32
bitcast_scalar_f32_to_i32_f64
bitcast_tensor_i32_to_f32
bitcast_tensor_i32_to_f32_f64
v0.7.2
lax.bitwise_not BitwiseNot
Not
bitwise_not_bool
bitwise_not_bool_f64
bitwise_not_i32
bitwise_not_i32_f64
v0.7.5
lax.broadcast_in_dim Expand
Identity
Reshape
broadcast_in_dim
broadcast_in_dim_f64
broadcast_in_dim_2d_to_3d
broadcast_in_dim_2d_to_3d_f64
broadcast_in_dim_scalar
broadcast_in_dim_scalar_f64
broadcast_in_dim_batch_dynamic
broadcast_in_dim_batch_dynamic_f64
broadcast_in_dim_batch
broadcast_in_dim_batch_f64
broadcast_in_dim_dynamic_B_dynamic
broadcast_in_dim_dynamic_B_dynamic_f64
broadcast_in_dim_dynamic_B
broadcast_in_dim_dynamic_B_f64
v0.2.0
lax.clamp Max
Min
clamp_i32_scalar_bounds
clamp_i32_scalar_bounds_f64
clamp_scalar_float_bounds_match_x
clamp_scalar_float_bounds_match_x_f64
clamp_vector_bounds_match
clamp_pyint_bounds_promote_to_x_dtype
clamp_pyint_bounds_promote_to_x_dtype_f64
v0.7.5
lax.concatenate Cast
Concat
concatenate
concatenate_f64
concatenate_axis1_dynamic
concatenate_axis1_dynamic_f64
concatenate_axis1
concatenate_axis1_f64
concatenate_axis0
concatenate_axis0_f64
concatenate_3d
concatenate_3d_f64
concatenate_internal_int32_then_cast_to_f32_zeroarg
v0.2.0
lax.cond If cond_scalar
cond_scalar_f64
cond_multiple_operands_in_tuple
cond_multiple_operands_in_tuple_f64
cond_my_new_complex_scenario
cond_my_new_complex_scenario_f64
cond_nested_conditional
cond_nested_conditional_f64
cond_variables
cond_variables_f64
cond_internal_constant_f64
cond_passthrough_identity
cond_passthrough_identity_f64
cond_with_scatter
cond_with_scatter_f64
v0.5.1
lax.conv Conv conv
conv2
conv_nchw
conv_nhwc
conv_general_dilated_nhwc_output
v0.2.0
lax.convert_element_type Cast convert_element_type
convert_element_type_f64
v0.2.0
lax.copy Identity copy_float32_array
copy_int64_scalar
lax.cos Cos cos
cos_f64
v0.4.4
lax.cosh Cosh cosh
cosh_f64
v0.4.4
lax.cumsum CumSum cumsum_i32_axis2
cumsum_i32_axis2_f64
cumsum_f32_axism1_reverse
cumsum_f32_axism1_reverse_f64
v0.7.4
lax.device_put Identity device_put_array
device_put_array_f64
device_put_scalar
device_put_scalar_f64
v0.4.0
lax.div Div div
div_f64
div_const
div_const_f64
v0.2.0
lax.dot_general MatMul/Gemm dot_contract_nm
dot_contract_nm_f64
dot_contract_min
dot_contract_min_f64
dot_general
dot_general_f64
dot_general_lhs1_rhs1
dot_general_lhs1_rhs1_f64
v0.2.0
lax.dynamic_slice Slice dynamic_slice_test1
dynamic_slice_test1_f64
dynamic_slice_2d
dynamic_slice_2d_f64
dynamic_slice_3d
dynamic_slice_3d_f64
dynamic_slice_vit_like_dynamic
dynamic_slice_vit_like_dynamic_f64
dynamic_slice_vit_like
dynamic_slice_vit_like_f64
v0.1.0
lax.dynamic_update_slice ScatterND dus_1d_scalar_update
dus_1d_scalar_update_f64
dus_1d_block_update
dus_1d_block_update_f64
dus_2d_block_update
dus_2d_block_update_f64
dus_3d_block_update
dus_3d_block_update_f64
dus_4d_block_update
dus_4d_block_update_f64
v0.8.1
lax.eq Equal eq
eq_f64
v0.2.0
lax.exp Exp exp
exp_f64
v0.2.0
lax.fori_loop Loop fori_loop_counter
fori_loop_counter_f64
fori_loop_zero
fori_loop_zero_f64
fori_loop_vector
fori_loop_vector_f64
fori_loop_example
fori_loop_example_f64
fori_loop_test
fori_loop_test_f64
v0.5.1
lax.gather GatherND gather_trig_where_pipeline_f64_indices_i64
gather_trig_where_pipeline_f64_indices_i32
gather_f64_data_i64_indices_output_is_f64
gather_f64_data_i32_indices_cast_and_output_is_f64
gather_static
gather_static_f64
gather_dynamic_batch_simple_index_dynamic
gather_dynamic_batch_simple_index_dynamic_f64
gather_dynamic_batch_simple_index
gather_dynamic_batch_simple_index_f64
v0.2.0
lax.greater_equal GreaterOrEqual greater_equal
greater_equal_f64
v0.7.5
lax.gt Greater gt
gt_f64
v0.2.0
lax.integer_pow Pow integer_pow
integer_pow_f64
v0.2.0
lax.iota Range iota_int32
iota_int32_f64
iota_float32
iota_float32_f64
broadcasted_iota
broadcasted_iota_f64
v0.5.0
lax.log Log log
log_f64
v0.2.0
lax.logistic Sigmoid lax_logistic_basic
lax_logistic_basic_f64
v0.7.2
lax.lt Less lt
lt_f64
v0.2.0
lax.max Max max
max_f64
v0.2.0
lax.min Min min_test1
min_test1_f64
v0.1.0
lax.mul Mul mul_test1
mul_test1_f64
mul_test2
mul_test2_f64
mul_pyfloat_promotes_to_array_dtype_f64
mul_scalar_broadcast_promote_to_f64
v0.1.0
lax.ne Equal
Not
ne
ne_f64
v0.2.0
lax.neg Neg neg
neg_f64
v0.2.0
lax.or BitwiseOr
Or
or_bool_vec
or_bool_vec_f64
or_int_vec
or_int_vec_f64
v0.7.2
lax.pad Pad pad_const_1d
pad_const_1d_f64
pad_const_2d
pad_const_2d_f64
pad_const_2d_cval
pad_const_2d_cval_f64
pad_inside_scan_smoke_f64
pad_inside_nested_scan_smoke_f64
v0.8.0
lax.pjit pjit_inline_mul
pjit_inline_mul_f64
pjit_inline_tuple
pjit_inline_tuple_f64
v0.1.0
lax.pow Pow pow_basic
pow_basic_f64
pow_lax
pow_lax_f64
v0.8.2
lax.reduce_and ReduceMin reduce_and_all_true
reduce_and_all_true_f64
reduce_and_one_false
reduce_and_one_false_f64
reduce_and_keepdims
reduce_and_keepdims_f64
v0.6.1
lax.reduce_max ReduceMax reduce_max
reduce_max_f64
reduce_max_allaxes
reduce_max_allaxes_f64
reduce_max_axes_input
reduce_max_axes_input_f64
reduce_max_keepdims
reduce_max_keepdims_f64
v0.2.0
lax.reduce_min ReduceMin reduce_min
reduce_min_f64
reduce_min_allaxes
reduce_min_allaxes_f64
reduce_min_keepdims
reduce_min_keepdims_f64
v0.2.0
lax.reduce_or ReduceMax reduce_or_all_false
reduce_or_all_false_f64
reduce_or_one_true
reduce_or_one_true_f64
reduce_or_keepdims
reduce_or_keepdims_f64
v0.6.1
lax.reduce_prod ReduceProd reduce_prod
reduce_prod_f64
reduce_prod_allaxes
reduce_prod_allaxes_f64
reduce_prod_dtype
reduce_prod_dtype_f64
reduce_prod_dtype_f64
reduce_prod_dtype_f64_f64
reduce_prod_keepdims
reduce_prod_keepdims_f64
v0.6.1
lax.reduce_sum ReduceSum reduce_sum
reduce_sum_f64
reduce_sum_allaxes
reduce_sum_allaxes_f64
reduce_sum_dtype
reduce_sum_dtype_f64
reduce_sum_dtype_f64
reduce_sum_dtype_f64_f64
reduce_sum_keepdims
reduce_sum_keepdims_f64
v0.2.0
lax.reduce_xor Mod
ReduceSum
reduce_xor_all_false
reduce_xor_all_false_f64
reduce_xor_one_true
reduce_xor_one_true_f64
reduce_xor_two_true
reduce_xor_two_true_f64
reduce_xor_keepdims
reduce_xor_keepdims_f64
v0.6.1
lax.rem Div
Mod
rem_int
rem_int_f64
rem_float
rem_float_f64
rem_int_neg
rem_int_neg_f64
rem_float_neg
rem_float_neg_f64
v0.6.5
lax.remat2 remat2_scalar_sin_chain
remat2_scalar_sin_chain_f64
remat2_tuple_passthrough
remat2_tuple_passthrough_f64
v0.6.5
lax.reshape Reshape reshape_after_transpose_folds_const_shape
reshape_after_transpose_folds_const_shape_f64
reshape_flatten_trailing_folds_const_shape
reshape_flatten_trailing_folds_const_shape_f64
reshape
reshape_f64
reshape_valid_squeeze_middle_dim_from_problematic_source
reshape_valid_squeeze_middle_dim_from_problematic_source_f64
reshape_valid_flatten_trailing
reshape_valid_flatten_trailing_f64
reshape_with_target_shape_from_symbolic_dim_computation
reshape_with_target_shape_from_symbolic_dim_computation_f64
reshape_with_inferred_dimension_from_input_dynamic_dynamic
reshape_with_inferred_dimension_from_input_dynamic_dynamic_f64
reshape_with_inferred_dimension_from_input_dynamic
reshape_with_inferred_dimension_from_input_dynamic_f64
reshape_with_inferred_dimension_from_input
reshape_with_inferred_dimension_from_input_f64
reshape_merge_symbolic_with_static_and_check_name_dynamic
reshape_merge_symbolic_with_static_and_check_name
v0.2.0
lax.rev Flip rev_vector
rev_vector_f64
rev_matrix_axes01
rev_matrix_axes01_f64
v0.7.5
lax.scan Scan scan_identity_slice_helper
scan_identity_slice_helper_f64
scan_cumsum
scan_cumsum_f64
scan_carry_only
scan_carry_only_f64
scan_multiple_sequences
scan_multiple_sequences_f64
scan_multiple_carry
scan_multiple_carry_f64
scan_matrix_carry_multidim_xs
scan_matrix_carry_multidim_xs_f64
scan_no_xs
scan_no_xs_f64
scan_fn
scan_fn_f64
scan_jit_no_xs
scan_jit_no_xs_f64
scan_captured_scalar
scan_captured_scalar_f64
scan_rank0_sequence_vectorized
scan_rank0_sequence_vectorized_f64
scan_two_diff_lengths
scan_two_diff_lengths_f64
scan_two_diff_lengths_broadcast
scan_two_diff_lengths_broadcast_f64
scan_two_diff_lengths_with_broadcast
scan_nested_len_mismatch
scan_nested_len_mismatch_f64
scan_captured_scalar_with_xs
scan_captured_vector_with_xs_f64
v0.5.1
lax.scatter ScatterND scatter_set_axis0
scatter_set_axis0_f64
scatter_set_middle
scatter_set_middle_f64
scatter_set_single
scatter_set_single_f64
scatter_set_vector
scatter_set_vector_f64
scatter_correct_axis_determination
scatter_correct_axis_determination_f64
scatter_updates_slice_needed_axis0
scatter_updates_slice_needed_axis0_f64
scatter_from_user_warning_shapes_valid_jax
scatter_from_user_warning_shapes_valid_jax_f64
scatter_user_error_scenario_precise
scatter_user_error_scenario_precise_f64
scatter_window_update_f64
scatter_window_update_depth3_shapes_ok
scatter_static_slice_set_f64
scatter_depth2_fp64_type_mismatch
scatter_clip_2d_window_at_edge
scatter_simple_2d_window_out_of_bounds
scatter_depth2_mixed_dtypes_fp_mismatch_f64
scatter_depth2_mixed_dtypes_fp_mismatch
v0.4.4
lax.scatter_add ScatterND(reduction='add') scatter_add_vector
scatter_add_vector_f64
scatter_add_scalar
scatter_add_scalar_f64
scatter_add_simple_1d
scatter_add_simple_1d_f64
scatter_add_batch_updates_1d_operand
scatter_add_batch_updates_1d_operand_f64
scatter_add_window_2d_operand_1d_indices
scatter_add_window_2d_operand_1d_indices_f64
scatter_add_mismatched_window_dims_from_user_report
scatter_add_mismatched_window_dims_from_user_report2
scatter_add_mismatched_window_dims_from_user_report3
scatter_add_fluids_pattern_updates_5_4_1_1
scatter_add_in_cond_float64
scatter_add_fp64_dtype_mismatch
scatter_add_depth2_depth2_helper_regression
scatter_depth2_fp64_type_mismatch
v0.5.3
lax.scatter_max ScatterND(reduction='max') scatter_max_simple_1d
scatter_max_simple_1d_f64
scatter_max_batch_updates_1d_operand
scatter_max_batch_updates_1d_operand_f64
scatter_max_window_2d_operand_1d_indices
scatter_max_window_2d_operand_1d_indices_f64
scatter_max_fp64_dtype_path_check
scatter_max_depth2_helper_regression_fp64
v0.7.5
lax.scatter_min ScatterND(reduction='min') scatter_min_simple_1d
scatter_min_simple_1d_f64
scatter_min_batch_updates_1d_operand
scatter_min_batch_updates_1d_operand_f64
scatter_min_window_2d_operand_1d_indices
scatter_min_window_2d_operand_1d_indices_f64
scatter_min_fp64_dtype_path_check
scatter_min_depth2_helper_regression_fp64
v0.7.5
lax.scatter_mul ScatterND(reduction='mul') scatter_mul_simple_1d
scatter_mul_simple_1d_f64
scatter_mul_batch_updates_1d_operand
scatter_mul_batch_updates_1d_operand_f64
scatter_mul_window_2d_operand_1d_indices
scatter_mul_window_2d_operand_1d_indices_f64
scatter_mul_mismatched_window_dims_from_user_report
scatter_mul_mismatched_window_dims_from_user_report2
scatter_mul_mismatched_window_dims_from_user_report3
scatter_mul_fluids_pattern_updates_5_4_1_1
scatter_mul_in_cond_float64
v0.6.4
lax.select Where select_simple
select_simple_f64
select_basic
select_basic_f64
select_mask_scores_tensor_else_dynamic
select_mask_scores_tensor_else_dynamic_f64
select_mask_scores_tensor_else
select_mask_scores_tensor_else_f64
v0.7.1
lax.select_n Equal
Where
select_n_bool_predicate_two_cases_float
select_n_bool_predicate_two_cases_float_f64
select_n_bool_predicate_two_cases_int
select_n_bool_predicate_two_cases_int_f64
select_n_bool_predicate_scalar_broadcast
select_n_bool_predicate_scalar_broadcast_f64
select_n_int_indices_three_cases
select_n_int_indices_three_cases_f64
select_n_int_indices_four_cases
select_n_int_indices_four_cases_f64
v0.2.0
lax.shift_right_logical BitShift shift_right_logical_vec
shift_right_logical_vec_f64
shift_right_logical_scalar
shift_right_logical_scalar_f64
v0.7.2
lax.sign Sign sign
sign_f64
v0.5.0
lax.sin Sin sin
sin_f64
v0.4.4
lax.sinh Sinh sinh
sinh_f64
v0.4.4
lax.slice Slice slice_test1
slice_test1_f64
slice_3d_none_strides
slice_3d_none_strides_f64
slice_scan_axis_drop
slice_scan_axis_drop_f64
v0.1.0
lax.sort TopK sort_1d
sort_1d_f64
sort_2d
sort_2d_f64
v0.2.0
lax.split Split lax_split_equal_parts
lax_split_equal_parts_f64
lax_split_unequal_parts
lax_split_unequal_parts_f64
v0.7.2
lax.sqrt Sqrt sqrt
sqrt_f64
v0.2.0
lax.square Mul square
square_f64
v0.2.0
lax.squeeze Squeeze squeeze_single_axis
squeeze_single_axis_f64
squeeze_all_unit_dims_default
squeeze_all_unit_dims_default_f64
lax_squeeze_specific_axis_0
lax_squeeze_specific_axis_0_f64
lax_squeeze_multiple_axes
lax_squeeze_multiple_axes_f64
lax_squeeze_no_op_empty_dims
lax_squeeze_no_op_empty_dims_f64
lax_squeeze_problem_case_input_squeeze_only_axis_0
lax_squeeze_problem_case_input_squeeze_only_axis_0_f64
lax_squeeze_problem_case_input_squeeze_axes_0_2
lax_squeeze_problem_case_input_squeeze_axes_0_2_f64
lax_squeeze_problem_case_input_squeeze_all_dims_explicitly
lax_squeeze_problem_case_input_squeeze_all_dims_explicitly_f64
v0.2.0
lax.stop_gradient Identity stop_gradient
stop_gradient_f64
stop_gradient_basic
stop_gradient_basic_f64
v0.2.0
lax.sub Sub sub_test1
sub_test1_f64
sub_test2
sub_test2_f64
sub_const
sub_const_f64
v0.1.0
lax.tanh Tanh tanh
tanh_f64
v0.2.0
lax.transpose Transpose transpose_basic
transpose_basic_f64
transpose_square_matrix
transpose_square_matrix_f64
transpose_3d
transpose_3d_f64
transpose_4d
transpose_4d_f64
transpose_reverse
transpose_reverse_f64
transpose_no_axes
transpose_no_axes_f64
transpose_nhwc_to_nchw
transpose_nhwc_to_nchw_f64
v0.2.0
lax.while_loop Loop while_scalar_counter
while_scalar_counter_f64
while_tuple_state
while_tuple_state_f64
while_loop_counter
while_loop_counter_f64
while_loop_vector
while_loop_vector_f64
while_loop_f64
while_loop_multi_state_f32
while_loop_multi_state_f64
while_loop_with_closure
while_loop_with_closure_f64
while_loop_basic
while_loop_two_state
while_loop_captured_tracer
while_loop_with_scalar_state
while_loop_renamed_passthrough
while_loop_closure_topo
while_loop_mixed_rank
while_loop_tracer_passthrough
while_loop_no_loop_output_reused_as_input
while_loop_4d_and_scalar_state
while_loop_4d_and_scalar_state_f64
while_loop_cnn_scalar_state_bug
while_loop_cnn_scalar_state_bug_f64
while_loop_nnx_repro
while_loop_nnx_repro_f64
v0.5.1
nn.celu Celu jaxnn_celu
jaxnn_celu_1
jaxnn_celu_alpha_default
jaxnn_celu_alpha_custom_dynamic
jaxnn_celu_alpha_custom
v0.7.1
nn.dot_product_attention Add
MatMul
Mul
Softmax
Transpose
Where
dpa_basic
dpa_basic_f64
dpa_positional_bias_mask
dpa_positional_bias_mask_f64
dpa_diff_heads_embed
dpa_diff_heads_embed_f64
dpa_batch4_seq16
dpa_batch4_seq16_f64
dpa_float64
dpa_heads1_embed4
dpa_heads1_embed4_f64
dpa_heads8_embed8
dpa_heads8_embed8_f64
dpa_batch1_seq2
dpa_batch1_seq2_f64
dpa_batch8_seq4
dpa_batch8_seq4_f64
dpa_axis1
dpa_axis1_f64
dpa_with_tensor_mask
dpa_with_tensor_mask_f64
dpa_tiny_mask_all_valid
dpa_tiny_mask_all_valid_f64
dpa_tiny_mask_mixed
dpa_tiny_mask_mixed_f64
dpa_one_false
dpa_one_false_f64
dpa_mostly_false
dpa_mostly_false_f64
dpa_with_causal_mask
dpa_with_causal_mask_f64
dpa_with_padding_mask
dpa_with_padding_mask_f64
dpa_with_local_window_mask
dpa_with_local_window_mask_f64
dpa_mask_none
v0.8.0
nn.elu Elu jaxnn_elu
jaxnn_elu_1
jaxnn_elu_default
jaxnn_elu_custom_alpha_dynamic
jaxnn_elu_custom_alpha
v0.7.1
nn.gelu Gelu jaxnn_gelu
jaxnn_gelu_1
jaxnn_gelu_approx
jaxnn_gelu_exact
jaxnn_gelu_tanh_dynamic
jaxnn_gelu_tanh
v0.7.1
nn.identity Identity jaxnn_identity
jaxnn_identity_f64
jaxnn_identity_1
jaxnn_identity_1_f64
jaxnn_identity_basic
jaxnn_identity_basic_f64
jaxnn_identity_dynamic_dynamic
jaxnn_identity_dynamic_dynamic_f64
jaxnn_identity_dynamic
jaxnn_identity_dynamic_f64
v0.7.1
nn.leaky_relu LeakyRelu jaxnn_leaky_relu
jaxnn_leaky_relu_1
jaxnn_leaky_relu_default_dynamic
jaxnn_leaky_relu_default
jaxnn_leaky_relu_custom
v0.7.1
nn.mish Mish jaxnn_mish
jaxnn_mish_1
jaxnn_mish_basic
v0.7.1
nn.relu Relu jaxnn_relu
jaxnn_relu_f64
jaxnn_relu_1
jaxnn_relu_1_f64
jaxnn_relu_basic
jaxnn_relu_basic_f64
jaxnn_relu_dynamic_dynamic
jaxnn_relu_dynamic_dynamic_f64
jaxnn_relu_dynamic
jaxnn_relu_dynamic_f64
v0.7.1
nn.selu Selu jaxnn_selu
jaxnn_selu_1
jaxnn_selu_basic_dynamic
jaxnn_selu_basic
v0.7.1
nn.sigmoid Sigmoid jaxnn_sigmoid
jaxnn_sigmoid_f64
jaxnn_sigmoid_1
jaxnn_sigmoid_1_f64
v0.7.1
nn.soft_sign Softsign jaxnn_soft_sign
jaxnn_soft_sign_1
jaxnn_softsign_basic
v0.7.1
nn.softmax Softmax softmax
softmax_f64
softmax_2d
softmax_2d_f64
softmax_3d
softmax_3d_f64
v0.7.1
nn.softplus Softplus jaxnn_softplus
jaxnn_softplus_1
jaxnn_softplus_basic
v0.7.1
nn.truncated_normal initializer
random_truncated_normal_positional
flax_dense_like_init
v0.7.1
nnx.avg_pool AveragePool
Transpose
avg_pool_dynamic
avg_pool
avg_pool_same_padding_dynamic
avg_pool_same_padding
avg_pool_default_padding_dynamic
avg_pool_default_padding
avg_pool_stride1_dynamic
avg_pool_stride1
avg_pool_win3x3_stride2_dynamic
avg_pool_win3x3_stride2
avg_pool_stride_none_dynamic
avg_pool_stride_none
avg_pool_count_include_pad_false_dynamic
avg_pool_count_include_pad_false
v0.1.0
nnx.batch_norm BatchNormalization batch_norm_no_bias_no_scale_dynamic
batch_norm_no_bias_no_scale
batch_norm_bias_no_scale_dynamic
batch_norm_bias_no_scale
batch_norm_no_bias_scale_dynamic
batch_norm_no_bias_scale
batch_norm_bias_scale_dynamic
batch_norm_bias_scale
batch_norm_3d_dynamic
batch_norm_3d
batch_norm_4d_dynamic
batch_norm_4d
batch_norm_4d_no_bias_no_scale_dynamic
batch_norm_4d_no_bias_no_scale
v0.1.0
nnx.conv CastLike
Conv
Reshape
Transpose
conv_basic_bias_dynamic
conv_basic_bias
conv_basic_bias_2
conv_basic_bias_3
conv_stride2_bias
conv_no_bias_dynamic
conv_no_bias
conv_valid_padding
conv_stride1
conv_stride2
conv_different_kernel
conv_float64
conv_single_batch
conv_large_batch
conv_1d
conv_1d_more_1d_inputs
conv_1d_more_2d_inputs
conv_1d_large_kernel
conv_1d_dilation
conv_1d_stride_dilation
conv_2d_asymmetric_kernel
conv_2d_asymmetric_stride
conv_2d_asymmetric_dilation
conv_2d_large_dilation
conv_2d_large_stride
conv_2d_mixed_params
conv_2d_same_padding_mixed_dilation
conv_3d_basic
conv_3d_stride
conv_3d_asymmetric
conv_3d_dilation
conv_2d_small_input
conv_2d_many_channels
conv_1d_wide_input
conv_2d_kernel_1x1
conv_1d_kernel_1
conv_2d_group_conv
conv_1d_group_conv_more_dims
conv_2d_depthwise
conv_1d_complex_on_4d
conv_2d_complex_on_5d
conv_2d_asymmetric_on_5d
conv_1d_high_dilation_on_3d
conv_1d_large_kernel_on_4d
conv_2d_group_stride_dilation
conv_1d_group_on_higher_dim
conv_1d_same_padding_on_3d
conv_3d_group_complex
conv_1d_unit_group_on_multi_dim
v0.1.0
nnx.dot_product_attention Add
MatMul
Mul
Softmax
Transpose
Where
dpa_basic
dpa_basic_f64
dpa_with_tensor_mask
dpa_with_bias
dpa_with_causal_mask
dpa_with_causal_mask_f64
dpa_with_mask_and_bias
v0.1.0
nnx.dropout Dropout dropout_init_params_dynamic
dropout_init_params_dynamic_f64
dropout_init_params
dropout_init_params_f64
dropout_call_params_dynamic
dropout_call_params_dynamic_f64
dropout_call_params
dropout_call_params_f64
v0.1.0
nnx.einsum Add
Einsum
einsum_module_with_bias
einsum_module_with_bias_f64
einsum_module_no_bias
einsum_module_no_bias_f64
v0.4.2
nnx.elu Elu elu
elu_default_dynamic
elu_default
elu_alpha
v0.2.0
nnx.embed Gather token_embedding_dynamic
token_embedding_dynamic_f64
token_embedding
token_embedding_f64
positional_embedding_dynamic
positional_embedding_dynamic_f64
positional_embedding
positional_embedding_f64
v0.7.0
nnx.gelu Gelu gelu
gelu_1
gelu_2
gelu_2_f64
gelu_3_dynamic
gelu_3_dynamic_f64
gelu_3
gelu_3_f64
gelu_4
gelu_4_f64
gelu_5_dynamic
gelu_5_dynamic_f64
gelu_5
gelu_5_f64
v0.1.0
nnx.group_norm GroupNormalization group_norm
group_norm_rank2_dynamic
group_norm_rank2
group_norm_rank4
group_norm_no_bias_dynamic
group_norm_no_bias
group_norm_no_bias_no_scale_dynamic
group_norm_no_bias_no_scale
group_norm_bias_no_scale_dynamic
group_norm_bias_no_scale
group_norm_no_scale_dynamic
group_norm_no_scale
group_norm_no_bias_scale_dynamic
group_norm_no_bias_scale
group_norm_bias_scale_dynamic
group_norm_bias_scale
v0.2.0
nnx.layer_norm LayerNormalization layer_norm_dynamic
layer_norm
layer_norm_no_bias_no_scale_dynamic
layer_norm_no_bias_no_scale
layer_norm_bias_no_scale_dynamic
layer_norm_bias_no_scale
layer_norm_no_bias_scale_dynamic
layer_norm_no_bias_scale
layer_norm_bias_scale_dynamic
layer_norm_bias_scale
layer_norm_multiaxis_dynamic
layer_norm_multiaxis
layer_norm_symbolic_batch_dynamic
layer_norm_symbolic_batch
layer_norm_symbolic_batch_seq10_feat3_dynamic
layer_norm_symbolic_batch_seq10_feat3
layer_norm_symbolic_batch_seq10_feat3_2_dynamic
layer_norm_symbolic_batch_seq10_feat3_2
layer_norm_negative_axis_no_div_dynamic
layer_norm_negative_axis_no_div
v0.1.0
nnx.leaky_relu LeakyRelu leaky_relu
leaky_relu_default_dynamic
leaky_relu_default
leaky_relu_custom
v0.2.0
nnx.linear CastLike
Concat
Gemm
Reshape
Shape
Slice
linear_symbolic_batch_dynamic
linear_symbolic_batch_dynamic_f64
linear_symbolic_batch
linear_symbolic_batch_f64
linear_high_rank_dynamic
linear_high_rank_dynamic_f64
linear_high_rank_static
linear_high_rank_static_f64
linear_no_bias_dynamic
linear_no_bias_dynamic_f64
linear_no_bias
linear_no_bias_f64
linear_high_rank_no_bias_dynamic
linear_high_rank_no_bias_dynamic_f64
linear_high_rank_no_bias
linear_high_rank_no_bias_f64
linear_merge_symbolic_dim_dynamic
v0.1.0
nnx.linear_general CastLike
Concat
Gemm
Reshape
Shape
Slice
linear_general_merge_symbolic_dim_dynamic
linear_general_dynamic
linear_general
linear_general_2
linear_general_3
linear_general_4
linear_general_abstract_eval_axes
linear_general_abstract_eval_axes_pair
dynamic_batch_and_feature_dims_dynamic
v0.1.0
nnx.log_softmax LogSoftmax log_softmax
log_softmax_f64
log_softmax_default_axis_dynamic
log_softmax_default_axis_dynamic_f64
log_softmax_default_axis
log_softmax_default_axis_f64
log_softmax_axis0
log_softmax_axis0_f64
v0.2.0
nnx.max_pool MaxPool max_pool
max_pool_same_padding
max_pool_basic
max_pool_same_dynamic
max_pool_same
v0.2.0
nnx.relu Relu relu_1d
relu_1d_f64
relu_4d_dynamic
relu_4d_dynamic_f64
relu_4d
relu_4d_f64
v0.2.0
nnx.rms_norm RMSNormalization rms_norm_basic
rms_norm_use_scale_false
rms_norm_4d_dynamic_dynamic
rms_norm_4d_dynamic
rms_norm_4d_dynamic_no_scale_dynamic
rms_norm_4d_dynamic_no_scale
v0.2.0
nnx.sigmoid Sigmoid sigmoid_dynamic
sigmoid_dynamic_f64
sigmoid
sigmoid_f64
v0.2.0
nnx.softmax Softmax softmax_dynamic
softmax_dynamic_f64
softmax
softmax_f64
v0.1.0
nnx.softplus Softplus softplus v0.1.0
nnx.tanh Tanh tanh
tanh_f64
v0.1.0
random.random_bits Cast
Floor
RandomUniform
random_bits_uint32
random_bits_uint32_f64
v0.7.2
random.random_fold_in Identity random_fold_in_passthrough
random_fold_in_passthrough_f64
v0.2.0
random.random_seed Cast
Concat
random_seed_basic
random_seed_basic_f64
v0.2.0

Legend:
✅ = Passed
❌ = Failed
➖ = No testcase yet


Examples

Component Description Testcases Since
MlpExample A simple Equinox MLP (converter pipeline). mlp_training_mode
mlp_training_mode_f64
mlp_inference_mode
mlp_inference_mode_f64
mlp_batched_training_mode
mlp_batched_training_mode_f64
v0.8.0
SimpleLinearExample A simple linear layer example using Equinox (converter). simple_linear
simple_linear_f64
nn_linear
nn_linear_f64
v0.7.1
Attention Multi-Head Self-Attention using Equinox modules. attention_dynamic
attention
v0.10.0
AttentionCore Multi-Head Self-Attention without rotary processing. attention_core_dynamic
attention_core
v0.10.0
Block Transformer Block. transformer_block_dynamic
transformer_block
v0.10.0
DINOv3VisionTransformer DINOv3 Vision Transformer eqx_dinov3_vit_Ti14_dynamic
eqx_dinov3_vit_Ti14
eqx_dinov3_vit_S14_dynamic
eqx_dinov3_vit_S14
eqx_dinov3_vit_B14_dynamic
eqx_dinov3_vit_B14
eqx_dinov3_vit_S16_dynamic
eqx_dinov3_vit_S16
v0.10.0
PatchEmbed Image to Patch Embedding. patch_embed v0.10.0
GPT A simple GPT model that reuses nnx.MultiHeadAttention. gpt_dynamic
gpt
v0.7.0
GPT_Attention A multi-head attention layer. gpt_attention v0.7.1
GPT_CausalSelfAttention A causal self-attention module. gpt_causal_self_attention_dynamic
gpt_causal_self_attention
v0.7.0
GPT_Embeddings Combines token and position embeddings with dropout. gpt_embeddings_dynamic
gpt_embeddings
v0.7.0
GPT_Head The head of the GPT model. gpt_head_dynamic
gpt_head
v0.7.0
GPT_MLP An MLP block with GELU activation from nanoGPT. gpt_mlp_dynamic
gpt_mlp
v0.7.0
GPT_PositionEmbedding A positional embedding layer using nnx.Embed. gpt_position_embedding v0.7.0
GPT_TokenEmbedding A token embedding layer using nnx.Embed. gpt_token_embedding_dynamic
gpt_token_embedding
v0.7.0
GPT_TransformerBlock A transformer block combining attention and MLP. gpt_block_dynamic
gpt_block
v0.7.0
GPT_TransformerStack A stack of transformer blocks. gpt_transformer_stack_dynamic
gpt_transformer_stack
v0.7.0
broadcast_add Simple dynamic broadcast + add gpt_broadcast_add_dynamic_dynamic
gpt_broadcast_add_dynamic_dynamic_f64
gpt_broadcast_add_dynamic
gpt_broadcast_add_dynamic_f64
v0.7.0
cfl_timestep Tests the CFL condition timestep calculation. cfl_timestep_f64 v0.6.5
weno_reconstruction Tests the complex arithmetic pattern found in WENO schemes. weno_reconstruction_f64 v0.6.5
fori_loop_test Demonstrates jax.lax.fori_loop with a simple loop. fori_loop_test
fori_loop_test_f64
v0.6.3
issue18_abs Test jnp.abs from issue 18 abs_fn
abs_fn_f64
v0.6.3
issue18_arange Test jnp.arange from issue 18 arange_fn v0.6.3
issue18_fori_loop Test jax.lax.fori_loop from issue 18 fori_loop_fn
fori_loop_fn_f64
v0.6.3
issue18_linspace Test jnp.linspace from issue 18 linspace_fn v0.6.3
issue18_scan Test jax.lax.scan from issue 18 (no xs) scan_fn v0.6.3
issue18_sign Test jnp.sign from issue 18 sign_fn
sign_fn_f64
v0.6.3
issue18_where Test jnp.where from issue 18 where_fn
where_fn_f64
v0.6.3
issue18_while_loop Test jax.lax.while_loop from issue 18 while_loop_fn v0.9.0
select_test Demonstrates jnp.select with scalar and tensor predicates. select_test_all_options
select_test_scalar_select_option_0
select_test_scalar_select_option_1
select_test_scalar_select_option_2
select_test_default_case
v0.9.0
sort_test Demonstrates jnp.sort on slices of an input array. sort_test_basic v0.9.0
cond_scatter_add_mul Scatter add/mul inside conditional branches (converter). cond_scatter_add_mul_f64_a
cond_scatter_add_mul_f64_b
v0.8.0
cond_scatter_repro Reproduces a bug where lax.cond subgraphs do not inherit parent initializers. cond_scatter_repro_f64 v0.6.4
remat2 Tests a simple case of jax.checkpoint (also available as jax.remat). checkpoint_scalar_f32
checkpoint_scalar_f32_f64
v0.6.5
scatter_window Window-scatter (H×W patch) with implicit batch (depth-3 path). Exercises GatherScatterMode.FILL_OR_DROP and double precision. Regression test for a prior conversion failure. scatter_window_update_f64_example v0.7.4
AutoEncoder A simple autoencoder example (converter pipeline). simple_autoencoder
simple_autoencoder_f64
v0.2.0
CNN A simple convolutional neural network (CNN). simple_cnn_static
simple_cnn_dynamic
v0.2.0
ForiLoop fori_loop example using nnx-compatible primitives (converter). fori_loop_counter
fori_loop_counter_f64
v0.5.1
GRUCell Flax/nnx GRUCell lowered through converter primitives. gru_cell_basic v0.7.2
MLP A simple Multi-Layer Perceptron (MLP) with BatchNorm, Dropout, and GELU activation. simple_mlp_static
simple_mlp_static_f64
simple_mlp_dynamic
simple_mlp_dynamic_f64
simple_mlp_with_call_params_dynamic
simple_mlp_with_call_params_dynamic_f64
simple_mlp_with_call_params
simple_mlp_with_call_params_f64
v0.1.0
MultiHeadAttention nnx.MultiHeadAttention exercised in several configurations, including custom attention_fn and symbolic batch variants. multihead_attention_nn_dynamic
multihead_attention_nn
multihead_attention_nnx_dynamic
multihead_attention_nnx
multihead_attention_2_nnx_dynamic
multihead_attention_2_nnx
v0.2.0
SequentialReLU Two stateless nnx.relu activations chained via nnx.Sequential. sequential_double_relu
sequential_double_relu_f64
v0.7.1
SequentialWithResidual nnx.Sequential nested within a residual block; guards against regressions of earlier bugs. sequential_nested_with_residual v0.7.1
TransformerDecoderWithSequential Tiny nnx Transformer decoder using nnx.Sequential in the FFN block. tiny_decoder_with_sequential
tiny_decoder_with_sequential_and_full_dynamic_shapes_dynamic
v0.7.1
TransformerDecoderWithoutSequential Tiny nnx Transformer decoder with explicit FFN layers (no Sequential). tiny_decoder_without_sequential v0.7.1
onnx_functions_000 One function boundary on an outer NNX module (new-world). 000_one_function_on_outer_layer_dynamic
000_one_function_on_outer_layer
v0.4.0
onnx_functions_001 One function on an inner layer. 001_one_function_inner_dynamic
001_one_function_inner
v0.4.0
onnx_functions_002 Two nested functions. 002_two_nested_functions_dynamic
002_two_nested_functions
v0.4.0
onnx_functions_003 Two simple nested functions. 003_two_simple_nested_functions_dynamic
003_two_simple_nested_functions
v0.4.0
onnx_functions_004 Nested function plus a component. 004_nested_function_plus_component_dynamic
004_nested_function_plus_component
v0.4.0
onnx_functions_005 Nested function plus more components. 005_nested_function_plus_component_dynamic
005_nested_function_plus_component
v0.4.0
onnx_functions_006 One function on an outer layer. 006_one_function_outer_dynamic
006_one_function_outer
v0.4.0
onnx_functions_007 Transformer block with a nested MLP block, with a call parameter. 007_transformer_block_dynamic
007_transformer_block
v0.4.0
onnx_functions_008 Transformer block with a nested MLP block, without a call parameter. 008_transformer_block_dynamic
008_transformer_block
v0.4.0
onnx_functions_009 Transformer block using the decorator on both a class and a function. 009_transformer_block_dynamic
009_transformer_block
v0.4.0
onnx_functions_010 Transformer stack. 010_transformer_stack_dynamic
010_transformer_stack
v0.4.0
onnx_functions_012 Vision Transformer (ViT) 012_vit_conv_embedding_dynamic
012_vit_conv_embedding
v0.4.0
onnx_functions_013 Vision Transformer (ViT) 013_vit_conv_embedding_with_call_params_dynamic
013_vit_conv_embedding_with_call_params
013_vit_conv_embedding_with_internal_call_params_dynamic
013_vit_conv_embedding_with_internal_call_params
v0.4.0
onnx_functions_014 One function on an outer layer. 014_one_function_with_input_param_with_default_value
014_one_function_without_input_param_with_default_value_dynamic
014_one_function_without_input_param_with_default_value
v0.4.0
onnx_functions_015 One function on an outer layer. 015_one_function_with_input_param_without_default_value_dynamic
015_one_function_with_input_param_without_default_value
v0.4.0
onnx_functions_016 Nested function plus more components. 016_internal_function_with_input_param_with_default_value_dynamic
016_internal_function_with_input_param_with_default_value
v0.4.0
onnx_functions_017 Demonstrates @onnx_function(unique=True) reuse across call sites. 017_unique_function_reuse v0.10.0
ClassificationHead Classification head for a Vision Transformer. vit_classification_head_dynamic
vit_classification_head
v0.4.0
ClassificationHeadFlatten Classification head for a Vision Transformer. vit_classification_head_flat_dynamic
vit_classification_head_flat
v0.4.0
ConcatClsToken Concatenates the CLS token to the input embedding. vit_concat_cls_token_dynamic
vit_concat_cls_token
v0.4.0
ConcatClsTokenFlatten Concatenates the CLS token to the input embedding. vit_concat_cls_token_flat_dynamic
vit_concat_cls_token_flat
v0.4.0
ConvEmbedding Convolutional Token Embedding for MNIST with hierarchical downsampling. vit_mnist_conv_embedding_dynamic
vit_mnist_conv_embedding
v0.1.0
ConvEmbeddingFlatten Convolutional Token Embedding for MNIST with hierarchical downsampling. vit_mnist_conv_embedding_flat_dynamic
vit_mnist_conv_embedding_flat
v0.1.0
FeedForward MLP block of a Transformer. vit_feed_forward_dynamic
vit_feed_forward
v0.1.0
FeedForwardFlatten MLP block of a Transformer. vit_feed_forward_flat_dynamic
vit_feed_forward_flat
v0.1.0
GetToken Gets the CLS token from the input embedding. vit_get_token_dynamic
vit_get_token
v0.4.0
GetTokenFlatten Gets the CLS token from the input embedding. vit_get_token_flat_dynamic
vit_get_token_flat
v0.4.0
PatchEmbedding Cuts the image into patches and linearly embeds them. vit_patch_embedding_dynamic
vit_patch_embedding
v0.1.0
PatchEmbeddingFlatten Cuts the image into patches and linearly embeds them. vit_patch_embedding_flat_dynamic
vit_patch_embedding_flat
v0.1.0
PositionalEmbedding Adds a positional embedding to the input embedding. vit_positional_embedding_dynamic
vit_positional_embedding
v0.4.0
PositionalEmbeddingFlatten Adds a positional embedding to the input embedding. vit_positional_embedding_flat_dynamic
vit_positional_embedding_flat
v0.4.0
TransformerBlock Transformer from ‘Attention Is All You Need.’ vit_transformer_block_dynamic
vit_transformer_block
v0.1.0
TransformerBlockFlatten Transformer from ‘Attention Is All You Need.’ vit_transformer_block_flat_dynamic
vit_transformer_block_flat
v0.1.0
TransformerStack Stack of Transformer blocks vit_transformer_stack_dynamic
vit_transformer_stack
v0.1.0
TransformerStackFlatten Stack of Transformer blocks vit_transformer_stack_flat_dynamic
vit_transformer_stack_flat
v0.1.0
VisionTransformer A Vision Transformer (ViT) model for MNIST with configurable embedding type. vit_conv_embedding_dynamic
vit_conv_embedding
vit_patch_embedding
v0.2.0
VisionTransformerFlatten A Vision Transformer (ViT) model for MNIST with configurable embedding type. vit_conv_embedding_flat_dynamic
vit_conv_embedding_flat
vit_patch_embedding_flat_dynamic
vit_patch_embedding_flat
v0.2.0
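Several rows in the Examples table (fori_loop_test, issue18_fori_loop, ForiLoop) exercise functions built around jax.lax.fori_loop. A minimal JAX sketch of such a function follows, for orientation only. The name simple_fori_loop and its loop body are hypothetical and are not taken from the jax2onnx test suite; the conversion call itself is omitted because its signature is not documented in these tables.

```python
import jax
import jax.numpy as jnp


def simple_fori_loop(x):
    """Hypothetical function in the spirit of the fori_loop examples:
    adds 0.1 to every element of x, five times, inside jax.lax.fori_loop."""

    def body(i, val):
        # i is the loop index (unused here); val is the carried array.
        return val + 0.1

    # Static bounds (0, 5): the body runs exactly five times.
    return jax.lax.fori_loop(0, 5, body, x)


out = simple_fori_loop(jnp.zeros((3,), dtype=jnp.float32))
```

Because the loop bounds are static Python integers, the trip count is fixed when the function is traced, which keeps the carried value's shape and dtype unchanged across iterations.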