
Supported JAX/ONNX Components

This is the public support matrix for JAX, JAX NumPy, Flax, Equinox, and related components. Testcase links open representative ONNX models in Netron, and the table is generated from plugin metadata plus validation tests so support status stays aligned with the converter.
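Several jnp rows in the table list comparison operators plus a reduction rather than one dedicated ONNX operator. jnp.searchsorted lists Cast, Less, LessOrEqual and ReduceSum, and jnp.digitize also lists Greater, GreaterOrEqual and Where. The plain-Python sketch below shows one way such a composition can produce the result for a sorted 1-D array: compare the query against every edge, cast the boolean mask to integers, and sum it. It assumes ascending edges (the decreasing-bins case is not covered), the helper names are hypothetical, and it illustrates the idea rather than reproducing the converter's actual lowering.

```python
import bisect


def searchsorted(a, v, side="left"):
    """Insertion index of v into ascending list a, via compare + cast + sum."""
    if side == "left":
        # Less(a, v) -> Cast(int) -> ReduceSum: count of elements strictly below v
        return sum(int(x < v) for x in a)
    if side == "right":
        # LessOrEqual(a, v) -> Cast(int) -> ReduceSum: count of elements <= v
        return sum(int(x <= v) for x in a)
    raise ValueError(f"unknown side: {side!r}")


def digitize(x, bins, right=False):
    """Bin index of x for monotonically increasing bins (NumPy semantics)."""
    if right:
        # bins[i-1] < x <= bins[i]  ->  count of edges strictly below x
        return sum(int(b < x) for b in bins)
    # bins[i-1] <= x < bins[i]  ->  count of edges <= x
    return sum(int(b <= x) for b in bins)


if __name__ == "__main__":
    edges = [0.0, 1.0, 2.5, 4.0]
    for q in [-1.0, 0.0, 1.0, 3.0, 4.0, 9.0]:
        # Cross-check against the standard library's binary search
        assert searchsorted(edges, q, "left") == bisect.bisect_left(edges, q)
        assert searchsorted(edges, q, "right") == bisect.bisect_right(edges, q)
    print(digitize(1.0, edges), digitize(1.0, edges, right=True))
```

Each query costs one comparison per edge, so this formulation does O(n) work per query instead of the O(log n) of a binary search, but it maps directly onto broadcast tensor operations.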
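Some elementwise rows combine a few basic operators with Where. jnp.copysign lists Abs, Less, Neg and Where; jnp.floor_divide lists Div, Floor and Where; jnp.fmod lists Div, Mod and Sub. The scalar Python sketch below shows one plausible way to assemble such operators. It is an assumption for illustration, not the converter's code, and the function names are hypothetical. For floats, floor division is just Floor(Div(a, b)); the integer version here starts from a quotient truncated toward zero and applies a Where-style correction when the signs differ and the division is inexact.

```python
import math


def copysign_via_where(x, y):
    """copysign as Where(Less(y, 0), Neg(Abs(x)), Abs(x))."""
    mag = abs(x)                    # Abs
    return -mag if y < 0 else mag   # Less + Neg + Where


def trunc_div(a, b):
    """Integer quotient truncated toward zero (C-style division)."""
    q = abs(a) // abs(b)
    return q if (a < 0) == (b < 0) else -q


def floor_divide_int(a, b):
    """Floor division: truncated quotient, minus one when signs differ and a remainder exists."""
    q = trunc_div(a, b)
    needs_fix = (q * b != a) and ((a < 0) != (b < 0))
    return q - 1 if needs_fix else q  # Where(needs_fix, q - 1, q)


def fmod_int(a, b):
    """Remainder with the dividend's sign: a - b * trunc(a / b).

    ONNX Mod with fmod=1 has the same sign-of-dividend semantics.
    """
    return a - b * trunc_div(a, b)


if __name__ == "__main__":
    pairs = [(a, b) for a in range(-9, 10) for b in (-4, -3, -1, 2, 5)]
    assert all(floor_divide_int(a, b) == a // b for a, b in pairs)
    assert all(fmod_int(a, b) == math.fmod(a, b) for a, b in pairs)
    assert copysign_via_where(3.0, -2.0) == -3.0
    # In this sketch, Less(y, 0) cannot see the sign bit of -0.0, unlike math.copysign
    assert copysign_via_where(3.0, -0.0) == 3.0
    assert math.copysign(3.0, -0.0) == -3.0
```

The last two checks highlight an edge case of any sign test built on Less(y, 0): it treats -0.0 as non-negative, so the sketch returns 3.0 where math.copysign returns -3.0.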
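The jnp.ravel_multi_index row lists only Add, Mod and Mul, which is enough for the wrap mode its testcase name suggests. The index arithmetic wraps each index into range with Mod, then accumulates a row-major flat index Horner-style with Mul and Add. The plain-Python sketch below illustrates that arithmetic, assuming C (row-major) order; the function name is hypothetical.

```python
def ravel_multi_index_wrap(multi_index, dims):
    """Row-major flat index with NumPy-style mode="wrap" (Mod, then Mul/Add)."""
    if len(multi_index) != len(dims):
        raise ValueError("multi_index and dims must have the same length")
    flat = 0
    for i, d in zip(multi_index, dims):
        # Mod wraps any index, including negative ones, into [0, d);
        # Python's % takes the divisor's sign, like ONNX Mod with fmod=0
        flat = flat * d + (i % d)  # Mul, then Add
    return flat


if __name__ == "__main__":
    print(ravel_multi_index_wrap((1, 2), (3, 4)))   # 1 * 4 + 2
    print(ravel_multi_index_wrap((-1, 5), (3, 4)))  # wraps to (2, 1) before flattening
```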
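jnp.linalg_inv and jnp.linalg_solve are listed with elementwise operators (Concat, Div, Gather, Mul, Sub and Unsqueeze, plus Neg for the inverse), and their testcases cover only 1x1 and 2x2 systems. This pattern is consistent with closed-form formulas rather than a general factorization. The plain-Python sketch below shows those closed forms for the 2x2 case: the adjugate-over-determinant inverse and Cramer's rule. Treat it as an assumption about the approach, not the converter's implementation. The helper names are hypothetical, and the sketch raises on a singular matrix purely for readability.

```python
def det_2x2(m):
    """ad - bc for m = [[a, b], [c, d]] (Gather the entries, then Mul and Sub)."""
    (a, b), (c, d) = m
    return a * d - b * c


def inv_2x2(m):
    """Inverse as the adjugate divided by the determinant."""
    (a, b), (c, d) = m
    det = det_2x2(m)
    if det == 0:
        raise ZeroDivisionError("singular matrix")
    # Swap the diagonal, Neg the off-diagonal, Div by det, Concat rows back together
    return [[d / det, -b / det], [-c / det, a / det]]


def solve_2x2(m, rhs):
    """Cramer's rule for m @ x = rhs with a length-2 right-hand side."""
    (a, b), (c, d) = m
    r0, r1 = rhs
    det = det_2x2(m)
    if det == 0:
        raise ZeroDivisionError("singular matrix")
    return [(r0 * d - b * r1) / det, (a * r1 - c * r0) / det]


if __name__ == "__main__":
    print(inv_2x2([[4.0, 7.0], [2.0, 6.0]]))
    print(solve_2x2([[2.0, 1.0], [1.0, 3.0]], [3.0, 5.0]))
```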

Columns: JAX Component, ONNX Components, Testcases, Since (the converter version that first supported the component)
core.custom_jvp_generic (Generic passthrough for custom JVP calls) custom_jvp_square
custom_jvp_square_f64
0.7.1
core.custom_vjp_generic (Generic passthrough for custom VJP calls) custom_vjp_square
custom_vjp_square_f64
0.7.1
core.dim_as_value Cast
Gather
Reshape
Shape
dim_as_value_dynamic
dim_as_value_dynamic_f64
dim_as_value
dim_as_value_f64
0.5.0
core.jit_inline jit_identity
jit_identity_f64
0.9.0
dm_pix.depth_to_space DepthToSpace depth_to_space_nhwc 0.12.0
dm_pix.space_to_depth SpaceToDepth space_to_depth_nhwc 0.12.1
eqx.adaptive_pool AveragePool
MaxPool
eqx_adaptive_avg_pool2d_divisible
eqx_adaptive_max_pool2d_divisible
eqx_adaptive_pool_generic_avg_vmap_dynamic
eqx_adaptive_pool_generic_avg_vmap
0.12.2
eqx.avg_pool AveragePool eqx_avg_pool2d_basic
eqx_avg_pool2d_batched_dynamic
eqx_avg_pool2d_batched
0.12.2
eqx.batch_norm BatchNormalization eqx_batch_norm_inference 0.12.2
eqx.conv Conv eqx_conv2d_nchw
eqx_conv2d_batched_nchw
0.10.0
eqx.dropout Dropout
Not
eqx_dropout_inference_mode
eqx_dropout_inference_mode_f64
eqx_dropout_training_mode
eqx_dropout_training_mode_f64
eqx_dropout_dynamic_inference
eqx_dropout_dynamic_inference_f64
eqx_dropout_batched_inference_dynamic
eqx_dropout_batched_inference_dynamic_f64
eqx_dropout_batched_inference
eqx_dropout_batched_inference_f64
0.8.0
eqx.embedding Gather eqx_embedding_scalar_index
eqx_embedding_scalar_index_f64
eqx_embedding_batched_vmap_dynamic
eqx_embedding_batched_vmap_dynamic_f64
eqx_embedding_batched_vmap
eqx_embedding_batched_vmap_f64
0.12.2
eqx.group_norm GroupNormalization eqx_group_norm_rank3
eqx_group_norm_no_affine
0.12.2
eqx.gru_cell Add
Gemm
Mul
Sigmoid
Tanh
eqx_gru_cell_basic 0.12.2
eqx.identity Identity eqx_identity_static
eqx_identity_static_f64
eqx_identity_symbolic_batch_dynamic
eqx_identity_symbolic_batch_dynamic_f64
eqx_identity_symbolic_batch
eqx_identity_symbolic_batch_f64
0.8.0
eqx.lambda Relu
Sigmoid
Tanh
eqx_lambda_relu
eqx_lambda_relu_f64
eqx_lambda_sigmoid
eqx_lambda_sigmoid_f64
0.12.2
eqx.layer_norm LayerNormalization layer_norm
layer_norm_f64
layer_norm_multiaxis
layer_norm_multiaxis_f64
batched_layer_norm_dynamic
batched_layer_norm_dynamic_f64
batched_layer_norm
batched_layer_norm_f64
layer_norm_no_bias_no_scale
layer_norm_no_bias_no_scale_f64
0.8.0
eqx.linear Gemm
Reshape
eqx_linear_symbolic_batch_dynamic
eqx_linear_symbolic_batch_dynamic_f64
eqx_linear_symbolic_batch
eqx_linear_symbolic_batch_f64
eqx_linear_no_bias_symbolic_batch_dynamic
eqx_linear_no_bias_symbolic_batch_dynamic_f64
eqx_linear_no_bias_symbolic_batch
eqx_linear_no_bias_symbolic_batch_f64
eqx_linear_no_bias_vector
eqx_linear_no_bias_vector_f64
eqx_linear_high_rank
eqx_linear_high_rank_f64
eqx_linear_vector
eqx_linear_vector_f64
0.8.0
eqx.lstm_cell Add
Gemm
Mul
Sigmoid
Tanh
eqx_lstm_cell_basic 0.12.2
eqx.max_pool MaxPool eqx_max_pool2d_basic
eqx_max_pool2d_basic_f64
eqx_max_pool2d_batched_dynamic
eqx_max_pool2d_batched_dynamic_f64
eqx_max_pool2d_batched
eqx_max_pool2d_batched_f64
0.12.2
eqx.multihead_attention Attention
Gemm
MatMul
Softmax
eqx_multihead_attention
eqx_multihead_attention_opset23
eqx_multihead_attention_core_dynamic
eqx_multihead_attention_core_opset23_dynamic
eqx_multihead_attention_core_static
eqx_multihead_attention_core_static_opset23
0.10.0
eqx.pool MaxPool eqx_pool_max2d_basic
eqx_pool_max2d_basic_f64
eqx_pool_max2d_batched_dynamic
eqx_pool_max2d_batched_dynamic_f64
eqx_pool_max2d_batched
eqx_pool_max2d_batched_f64
0.12.2
eqx.prelu PRelu eqx_prelu_default_dynamic
eqx_prelu_default
eqx_prelu_channelwise
0.12.2
eqx.rms_norm RMSNormalization rms_norm_vector
rms_norm_vector_f64
rms_norm_no_affine
rms_norm_no_affine_f64
0.10.2
eqx.rotary_positional_embedding Add
Concat
Mul
RotaryEmbedding
eqx_rotary_positional_embedding
eqx_rotary_positional_embedding_opset23
eqx_rotary_positional_embedding_heads
0.10.0
eqx.sequential Identity
Relu
Sigmoid
Tanh
eqx_sequential_relu_tanh_dynamic
eqx_sequential_relu_tanh_dynamic_f64
eqx_sequential_relu_tanh
eqx_sequential_relu_tanh_f64
eqx_sequential_identity_sigmoid
eqx_sequential_identity_sigmoid_f64
0.12.2
eqx.spectral_norm Div
MatMul
eqx_spectral_norm_linear_inference 0.12.2
eqx.weight_norm Div
Mul
eqx_weight_norm_linear_dynamic
eqx_weight_norm_linear
eqx_weight_norm_conv2d
0.12.2
jax_image.resize Resize
Upsample
resize_linear
resize_nearest
resize_nearest_antialias_ignored
resize_nearest_opset9_upsample
resize_linear_opset9_upsample
resize_nearest_rank3_opset9_upsample
resize_area_static_exact
resize_cubic_pytorch_static_exact
0.10.0
jnp.abs Abs jnp_abs_basic
jnp_abs_basic_f64
jnp_abs_int
jnp_abs_int_f64
abs_vmap_batching
abs_vmap_batching_f64
abs_grad_issue_batch_diff_rules
0.12.2
jnp.acos Acos jnp_acos_basic
jnp_acos_int_promote
acos_vmap_batching
acos_grad_issue_batch_diff_rules
0.12.2
jnp.acosh Acosh jnp_acosh_basic
jnp_acosh_int_promote
acosh_vmap_batching
acosh_grad_issue_batch_diff_rules
0.12.2
jnp.add Add add
add_f64
jnp_add_vector
jnp_add_vector_f64
jnp_add_broadcast
jnp_add_broadcast_f64
add_vmap_batching
add_vmap_batching_f64
add_grad_issue_batch_diff_rules
0.8.0
jnp.all ReduceMin jnp_all_basic
jnp_all_basic_f64
jnp_all_axis_keepdims
jnp_all_axis_keepdims_f64
all_vmap_batching
all_vmap_batching_f64
0.12.2
jnp.amax ReduceMax jnp_amax_basic
jnp_amax_basic_f64
jnp_amax_axis
jnp_amax_axis_f64
jnp_amax_keepdims
jnp_amax_keepdims_f64
amax_vmap_batching
amax_vmap_batching_f64
0.12.2
jnp.amin ReduceMin jnp_amin_basic
jnp_amin_basic_f64
jnp_amin_axis
jnp_amin_axis_f64
jnp_amin_keepdims
jnp_amin_keepdims_f64
amin_vmap_batching
amin_vmap_batching_f64
0.12.2
jnp.any ReduceMax jnp_any_basic
jnp_any_basic_f64
jnp_any_axis_keepdims
jnp_any_axis_keepdims_f64
any_vmap_batching
any_vmap_batching_f64
0.12.2
jnp.arange Range arange_data_dependent_indices
arange_stop_only_concrete_input_val
arange_stop_only_concrete_input_val_f64
arange_start_stop_concrete_input_val
arange_start_stop_concrete_input_val_f64
arange_start_stop_step_concrete_input_val
arange_start_stop_step_concrete_input_val_f64
arange_float_concrete_input_val
arange_float_concrete_input_val_f64
arange_static_stop_only_int
arange_static_stop_only_int_f64
arange_static_stop_only_float
arange_static_stop_only_float_f64
arange_static_start_stop_int
arange_static_start_stop_int_f64
arange_static_start_stop_step_int
arange_static_start_stop_step_int_f64
arange_static_empty_result_pos_step
arange_static_empty_result_pos_step_f64
arange_static_empty_result_neg_step
arange_static_empty_result_neg_step_f64
arange_static_negative_step
arange_static_negative_step_f64
arange_static_float_step_explicit_dtype
arange_static_float_step_explicit_dtype_f64
arange_static_float_step_inferred_dtype
arange_static_float_step_inferred_dtype_f64
arange_static_stop_zero
arange_static_stop_zero_f64
arange_static_start_equals_stop
arange_static_start_equals_stop_f64
arange_static_large_numbers_int
arange_static_large_numbers_int_f64
0.5.2
jnp.argmax ArgMax jnp_argmax_axis1
jnp_argmax_axis0_bool
0.12.2
jnp.argmin ArgMin jnp_argmin_axis1
jnp_argmin_axis0
0.12.2
jnp.asin Asin jnp_asin_basic
jnp_asin_int_promote
asin_vmap_batching
asin_grad_issue_batch_diff_rules
0.12.2
jnp.asinh Asinh jnp_asinh_basic
jnp_asinh_int_promote
asinh_vmap_batching
asinh_grad_issue_batch_diff_rules
0.12.2
jnp.atan Atan jnp_atan_basic
jnp_atan_int_promote
atan_vmap_batching
atan_grad_issue_batch_diff_rules
0.12.2
jnp.atan2 Atan
Div
Where
jnp_atan2_quadrants
jnp_atan2_broadcast
atan2_vmap_batching
atan2_grad_issue_batch_diff_rules
0.12.2
jnp.atanh Atanh jnp_atanh_basic
jnp_atanh_int_promote
atanh_vmap_batching
atanh_grad_issue_batch_diff_rules
0.12.2
jnp.bartlett Constant jnp_bartlett_basic
jnp_bartlett_basic_f64
jnp_bartlett_empty
jnp_bartlett_empty_f64
0.12.1
jnp.bitwise_and And
BitwiseAnd
jnp_bitwise_and_int
jnp_bitwise_and_int_f64
jnp_bitwise_and_bool
jnp_bitwise_and_bool_f64
jnp_bitwise_and_mixed_dtype
jnp_bitwise_and_mixed_dtype_f64
jnp_bitwise_and_broadcast
jnp_bitwise_and_broadcast_f64
bitwise_and_vmap_batching
bitwise_and_vmap_batching_f64
0.12.2
jnp.bitwise_left_shift BitShift jnp_bitwise_left_shift_basic
jnp_bitwise_left_shift_basic_f64
jnp_bitwise_left_shift_bool_bool
jnp_bitwise_left_shift_bool_bool_f64
bitwise_left_shift_vmap_batching
bitwise_left_shift_vmap_batching_f64
0.12.2
jnp.bitwise_not BitwiseNot
Not
jnp_bitwise_not_bool
jnp_bitwise_not_bool_f64
jnp_bitwise_not_int
jnp_bitwise_not_int_f64
bitwise_not_vmap_batching
bitwise_not_vmap_batching_f64
0.12.2
jnp.bitwise_or BitwiseOr
Or
jnp_bitwise_or_int
jnp_bitwise_or_int_f64
jnp_bitwise_or_bool
jnp_bitwise_or_bool_f64
jnp_bitwise_or_mixed_dtype
jnp_bitwise_or_mixed_dtype_f64
jnp_bitwise_or_broadcast
jnp_bitwise_or_broadcast_f64
bitwise_or_vmap_batching
bitwise_or_vmap_batching_f64
0.12.2
jnp.bitwise_right_shift BitShift jnp_bitwise_right_shift_signed
jnp_bitwise_right_shift_signed_f64
jnp_bitwise_right_shift_bool_bool
jnp_bitwise_right_shift_bool_bool_f64
bitwise_right_shift_vmap_batching
bitwise_right_shift_vmap_batching_f64
0.12.2
jnp.bitwise_xor BitwiseXor
Xor
jnp_bitwise_xor_int
jnp_bitwise_xor_int_f64
jnp_bitwise_xor_bool
jnp_bitwise_xor_bool_f64
jnp_bitwise_xor_mixed_dtype
jnp_bitwise_xor_mixed_dtype_f64
jnp_bitwise_xor_broadcast
jnp_bitwise_xor_broadcast_f64
bitwise_xor_vmap_batching
bitwise_xor_vmap_batching_f64
0.12.2
jnp.blackman BlackmanWindow jnp_blackman_basic 0.12.1
jnp.block Concat jnp_block_basic 0.12.1
jnp.ceil Ceil
Identity
jnp_ceil_basic
jnp_ceil_basic_f64
jnp_ceil_int_identity
jnp_ceil_int_identity_f64
ceil_vmap_batching
ceil_vmap_batching_f64
ceil_grad_issue_batch_diff_rules
0.12.2
jnp.choose Concat
Gather
jnp_choose_wrap 0.12.1
jnp.clip Clip
Max
Min
clip_i32_scalar_bounds
clip_i32_scalar_bounds_f64
clip_f32_scalar_bounds_no_upcast_f64_mode
clip_only_upper
clip_only_upper_f64
clip_only_lower
clip_only_lower_f64
clip_broadcast_bounds
clip_grad_issue_batch_diff_rules
clip_keyword_min_alias
clip_keyword_min_alias_f64
clip_keyword_max_alias
clip_keyword_max_alias_f64
clip_keyword_min_max_aliases
clip_keyword_min_max_aliases_f64
clip_keyword_legacy_a_min_a_max_aliases
clip_keyword_legacy_a_min_a_max_aliases_f64
0.8.0
jnp.compress Compress jnp_compress_axis0
jnp_compress_axis0_f64
jnp_compress_axis1
jnp_compress_axis1_f64
compress_jvp_issue_batch_diff_rules
0.12.1
jnp.concatenate Concat concatenate_basic
concatenate_basic_f64
concatenate_mixed_dtypes
concatenate_mixed_dtypes_f64
concatenate_with_explicit_dtype
concatenate_with_explicit_dtype_casts_inputs
concatenate_abstract_middle_dim_dynamic
concatenate_abstract_middle_dim_dynamic_f64
concatenate_abstract_middle_dim
concatenate_abstract_middle_dim_f64
concatenate_tile_and_symbolic_dynamic
concatenate_tile_and_symbolic_dynamic_f64
concatenate_tile_and_symbolic
concatenate_tile_and_symbolic_f64
concatenate_grad_issue191
0.8.0
jnp.conj Identity jnp_conj_real
jnp_conj_real_f64
jnp_conj_complex64
jnp_conj_complex64_f64
conj_vmap_batching
conj_vmap_batching_f64
conj_grad_issue_batch_diff_rules
0.10.1
jnp.copysign Abs
Less
Neg
Where
jnp_copysign_basic
jnp_copysign_basic_f64
jnp_copysign_int_promote
jnp_copysign_int_promote_f64
jnp_copysign_broadcast
jnp_copysign_broadcast_f64
copysign_vmap_batching
copysign_vmap_batching_f64
0.12.2
jnp.corrcoef Div
MatMul
ReduceMean
jnp_corrcoef_basic 0.12.1
jnp.cos Cos
Sin
jnp_cos_basic
jnp_cos_basic_f64
jnp_cos_int_promote
jnp_cos_int_promote_f64
jnp_cos_f64_via_sin
cos_vmap_batching
cos_vmap_batching_f64
cos_grad_issue_batch_diff_rules
0.12.2
jnp.cosh Cosh jnp_cosh_basic
jnp_cosh_basic_f64
jnp_cosh_int_promote
jnp_cosh_int_promote_f64
cosh_vmap_batching
cosh_vmap_batching_f64
cosh_grad_issue_batch_diff_rules
0.12.2
jnp.cov MatMul
ReduceMean
jnp_cov_basic
jnp_cov_basic_f64
0.12.1
jnp.cross Concat
Mul
Sub
jnp_cross_vectors
jnp_cross_vectors_f64
0.12.1
jnp.cumprod CumProd jnp_cumprod_axis1
jnp_cumprod_axis1_f64
jnp_cumprod_axis_none_flatten
jnp_cumprod_axis_none_flatten_f64
jnp_cumprod_dtype_cast
jnp_cumprod_dtype_cast_f64
0.12.1
jnp.cumsum CumSum jnp_cumsum_axis1
jnp_cumsum_axis1_f64
jnp_cumsum_reverse_dtype
cumsum_axis2_i32
cumsum_axis2_i32_f64
cumsum_axis2_reverse_i32
cumsum_axis2_reverse_i32_f64
cumsum_vmap_batching
cumsum_vmap_batching_f64
0.8.0
jnp.delete Concat
Slice
jnp_delete_axis0_index1
jnp_delete_axis0_index1_f64
0.12.1
jnp.diag Gather
Mul
Reshape
jnp_diag_from_vector
jnp_diag_from_vector_f64
jnp_diag_from_matrix
jnp_diag_from_matrix_f64
jnp_diag_from_matrix_k1
jnp_diag_from_matrix_k1_f64
0.12.2
jnp.diagonal Gather
Reshape
jnp_diagonal_basic
jnp_diagonal_basic_f64
jnp_diagonal_offset_neg1
jnp_diagonal_offset_neg1_f64
0.12.2
jnp.digitize Cast
GreaterOrEqual
Greater
LessOrEqual
Less
ReduceSum
Where
jnp_digitize_increasing_left
jnp_digitize_increasing_left_f64
jnp_digitize_increasing_right
jnp_digitize_increasing_right_f64
jnp_digitize_decreasing_left
jnp_digitize_decreasing_left_f64
jnp_digitize_scalar_value
jnp_digitize_scalar_value_f64
digitize_vmap_values
digitize_vmap_values_f64
0.13.0
jnp.divide Div jnp_divide_basic
jnp_divide_basic_f64
jnp_divide_int_promote
jnp_divide_int_promote_f64
jnp_divide_broadcast
jnp_divide_broadcast_f64
divide_vmap_batching
divide_vmap_batching_f64
divide_grad_issue_batch_diff_rules
0.12.2
jnp.dot MatMul jnp_dot_vector_vector
jnp_dot_vector_vector_f64
jnp_dot_matrix_vector
jnp_dot_matrix_vector_f64
jnp_dot_matrix_matrix
jnp_dot_matrix_matrix_f64
0.12.2
jnp.einsum Einsum einsum_vector_dot
einsum_vector_dot_f64
einsum_matrix_vector
einsum_matrix_vector_f64
einsum_matrix_matrix_dynamic
einsum_matrix_matrix_dynamic_f64
einsum_matrix_matrix
einsum_matrix_matrix_f64
einsum_transpose
einsum_transpose_f64
einsum_batch_transpose_dynamic
einsum_batch_transpose_dynamic_f64
einsum_batch_transpose
einsum_batch_transpose_f64
einsum_diag
einsum_diag_f64
einsum_sum_reduce
einsum_sum_reduce_f64
einsum_multi_operand
einsum_multi_operand_f64
einsum_attention_logits_orig_dynamic
einsum_attention_logits_orig_dynamic_f64
einsum_attention_logits_orig
einsum_attention_logits_orig_f64
einsum_attention_output_orig_dynamic
einsum_attention_output_orig_dynamic_f64
einsum_attention_output_orig
einsum_attention_output_orig_f64
einsum_attention_logits_batched_dynamic
einsum_attention_logits_batched_dynamic_f64
einsum_attention_logits_batched
einsum_attention_logits_batched_f64
einsum_attention_output_batched_dynamic
einsum_attention_output_batched_dynamic_f64
einsum_attention_output_batched
einsum_attention_output_batched_f64
einsum_ellipsis_rank_mismatch
einsum_ellipsis_rank_mismatch_f64
einsum_attention_logits_batched_rank_mismatch
einsum_attention_logits_batched_rank_mismatch_f64
einsum_grad_issue_batch_diff_rules
0.1.0
jnp.equal Equal jnp_equal_basic
jnp_equal_basic_f64
jnp_equal_broadcast
jnp_equal_broadcast_f64
jnp_equal_mixed_dtype
jnp_equal_mixed_dtype_f64
0.12.2
jnp.exp Exp jnp_exp_basic
jnp_exp_basic_f64
jnp_exp_int_promote
jnp_exp_int_promote_f64
exp_vmap_batching
exp_vmap_batching_f64
exp_grad_issue_batch_diff_rules
0.12.2
jnp.exp2 Pow jnp_exp2_basic
jnp_exp2_basic_f64
jnp_exp2_int_promote
jnp_exp2_int_promote_f64
exp2_vmap_batching
exp2_vmap_batching_f64
exp2_grad_issue_batch_diff_rules
0.12.2
jnp.expm1 Exp
Sub
jnp_expm1_basic
jnp_expm1_basic_f64
jnp_expm1_int_promote
jnp_expm1_int_promote_f64
expm1_vmap_batching
expm1_vmap_batching_f64
expm1_grad_issue_batch_diff_rules
0.12.2
jnp.eye EyeLike jnp_eye_square
jnp_eye_square_f64
jnp_eye_rect_k1
jnp_eye_rect_k1_f64
0.12.1
jnp.fabs Abs jnp_fabs_basic
jnp_fabs_basic_f64
jnp_fabs_int_promote
jnp_fabs_int_promote_f64
fabs_vmap_batching
fabs_vmap_batching_f64
fabs_grad_issue_batch_diff_rules
0.12.2
jnp.fft DFT jnp_fft_complex64
jnp_fft_complex128
0.10.1
jnp.fft_fftfreq Div
Range
jnp_fft_fftfreq_n8
jnp_fft_fftfreq_n8_f64
0.12.1
jnp.fft_fftshift Concat
Slice
jnp_fft_fftshift_1d
jnp_fft_fftshift_1d_f64
0.12.1
jnp.fft_ifftshift Concat
Slice
jnp_fft_ifftshift_1d
jnp_fft_ifftshift_1d_f64
0.12.1
jnp.fft_irfft DFT jnp_irfft_complex64 0.12.1
jnp.fft_rfftfreq Div
Range
jnp_fft_rfftfreq_n8
jnp_fft_rfftfreq_n8_f64
0.12.1
jnp.fill_diagonal ScatterND jnp_fill_diagonal_basic
jnp_fill_diagonal_basic_f64
0.12.1
jnp.floor Floor
Identity
jnp_floor_basic
jnp_floor_basic_f64
jnp_floor_int_identity
jnp_floor_int_identity_f64
floor_vmap_batching
floor_vmap_batching_f64
floor_grad_issue_batch_diff_rules
0.12.2
jnp.floor_divide Div
Floor
Where
jnp_floor_divide_float
jnp_floor_divide_float_f64
jnp_floor_divide_int_neg
jnp_floor_divide_int_neg_f64
jnp_floor_divide_broadcast
jnp_floor_divide_broadcast_f64
floor_divide_vmap_batching
floor_divide_vmap_batching_f64
0.12.2
jnp.fmod Div
Mod
Sub
jnp_fmod_float
jnp_fmod_float_f64
jnp_fmod_int
jnp_fmod_int_f64
jnp_fmod_broadcast
jnp_fmod_broadcast_f64
fmod_vmap_batching
fmod_vmap_batching_f64
0.12.2
jnp.frexp Cast
IsInf
IsNaN
Log
Pow
Where
jnp_frexp_basic
jnp_frexp_special
frexp_vmap_batching
0.13.0
jnp.full ConstantOfShape jnp_full_const_scalar
jnp_full_const_scalar_f64
jnp_full_expand_scalar_input
jnp_full_expand_scalar_input_f64
jnp_full_expand_row_broadcast
jnp_full_expand_row_broadcast_f64
0.12.1
jnp.gcd Equal
Loop
jnp_gcd_basic
jnp_gcd_basic_f64
0.12.1
jnp.geomspace Exp
Log
Range
jnp_geomspace_scalar
jnp_geomspace_scalar_f64
0.12.1
jnp.gradient Div
Slice
Sub
jnp_gradient_1d
jnp_gradient_1d_f64
0.12.1
jnp.greater Greater jnp_greater_basic
jnp_greater_basic_f64
jnp_greater_broadcast
jnp_greater_broadcast_f64
jnp_greater_mixed_dtype
jnp_greater_mixed_dtype_f64
greater_vmap_batching
greater_vmap_batching_f64
0.12.2
jnp.greater_equal GreaterOrEqual jnp_greater_equal_basic
jnp_greater_equal_basic_f64
jnp_greater_equal_broadcast
jnp_greater_equal_broadcast_f64
jnp_greater_equal_mixed_dtype
jnp_greater_equal_mixed_dtype_f64
greater_equal_vmap_batching
greater_equal_vmap_batching_f64
0.12.2
jnp.hamming HammingWindow jnp_hamming_basic 0.12.1
jnp.hanning HannWindow jnp_hanning_basic 0.12.1
jnp.histogram And
Cast
Gather
GreaterOrEqual
Identity
LessOrEqual
Less
Or
ReduceSum
Reshape
jnp_histogram_explicit_bins
jnp_histogram_explicit_bins_f64
jnp_histogram_out_of_range
jnp_histogram_out_of_range_f64
jnp_histogram_integer_inputs
jnp_histogram_integer_inputs_f64
0.13.0
jnp.histogram2d And
Cast
Gather
GreaterOrEqual
Identity
LessOrEqual
Less
Or
ReduceSum
Reshape
jnp_histogram2d_explicit_bins
jnp_histogram2d_explicit_bins_f64
jnp_histogram2d_out_of_range
jnp_histogram2d_out_of_range_f64
jnp_histogram2d_integer_inputs
jnp_histogram2d_integer_inputs_f64
0.13.0
jnp.histogram_bin_edges Range
ReduceMax
ReduceMin
jnp_histogram_bin_edges_basic
jnp_histogram_bin_edges_basic_f64
0.12.1
jnp.histogramdd And
Cast
Gather
GreaterOrEqual
Identity
LessOrEqual
Less
Or
ReduceSum
jnp_histogramdd_2d_explicit_bins
jnp_histogramdd_2d_explicit_bins_f64
jnp_histogramdd_2d_out_of_range
jnp_histogramdd_2d_out_of_range_f64
jnp_histogramdd_2d_integer_inputs
jnp_histogramdd_2d_integer_inputs_f64
0.13.0
jnp.i0 Add
Exp
Mul
jnp_i0_basic
jnp_i0_basic_f64
0.12.1
jnp.ifft DFT jnp_ifft_complex64 0.10.1
jnp.interp Add
Cast
Div
Gather
Greater
LessOrEqual
Less
Max
Min
Mul
ReduceSum
Sub
Where
jnp_interp_vector_default_bounds
jnp_interp_vector_default_bounds_f64
jnp_interp_matrix_values
jnp_interp_matrix_values_f64
jnp_interp_integer_inputs
0.13.0
jnp.invert BitwiseNot
Not
jnp_invert_bool
jnp_invert_bool_f64
jnp_invert_int
jnp_invert_int_f64
invert_vmap_batching
invert_vmap_batching_f64
0.12.2
jnp.isfinite IsInf
IsNaN
Not
Or
jnp_isfinite_basic
jnp_isfinite_basic_f64
0.12.2
jnp.lcm Div
Loop
Mul
jnp_lcm_basic
jnp_lcm_basic_f64
0.12.1
jnp.ldexp Cast
Mul
Pow
jnp_ldexp_vector
jnp_ldexp_vector_f64
jnp_ldexp_broadcast
jnp_ldexp_broadcast_f64
ldexp_vmap_batching
ldexp_vmap_batching_f64
0.13.0
jnp.left_shift BitShift jnp_left_shift_int
jnp_left_shift_int_f64
jnp_left_shift_unsigned
jnp_left_shift_unsigned_f64
jnp_left_shift_mixed_bool_rhs
jnp_left_shift_mixed_bool_rhs_f64
jnp_left_shift_broadcast
jnp_left_shift_broadcast_f64
left_shift_vmap_batching
left_shift_vmap_batching_f64
0.12.2
jnp.less Less jnp_less_basic
jnp_less_basic_f64
jnp_less_broadcast
jnp_less_broadcast_f64
jnp_less_mixed_dtype
jnp_less_mixed_dtype_f64
less_vmap_batching
less_vmap_batching_f64
0.12.2
jnp.less_equal LessOrEqual jnp_less_equal_basic
jnp_less_equal_basic_f64
jnp_less_equal_broadcast
jnp_less_equal_broadcast_f64
jnp_less_equal_mixed_dtype
jnp_less_equal_mixed_dtype_f64
less_equal_vmap_batching
less_equal_vmap_batching_f64
0.12.2
jnp.linalg_cholesky Cholesky jnp_linalg_cholesky_basic 0.12.1
jnp.linalg_cond Div
SVD
jnp_linalg_cond_basic 0.12.1
jnp.linalg_det Det linalg_det_2x2
linalg_det_batched
0.12.1
jnp.linalg_eig Eig jnp_linalg_eig_basic 0.12.1
jnp.linalg_eigh Eigh jnp_linalg_eigh_basic 0.12.1
jnp.linalg_eigvals Eig jnp_linalg_eigvals_basic 0.12.1
jnp.linalg_eigvalsh Eigh jnp_linalg_eigvalsh_basic 0.12.1
jnp.linalg_inv Concat
Div
Gather
Mul
Neg
Sub
Unsqueeze
linalg_inv_1x1
linalg_inv_2x2
linalg_inv_batched_2x2
0.13.0
jnp.linalg_lstsq MatMul
SVD
jnp_linalg_lstsq_basic 0.12.1
jnp.linalg_matrix_norm Mul
ReduceSum
Sqrt
jnp_linalg_matrix_norm_basic 0.12.1
jnp.linalg_matrix_power MatMul jnp_linalg_matrix_power_basic 0.12.1
jnp.linalg_matrix_rank ReduceSum
SVD
jnp_linalg_matrix_rank_basic 0.12.1
jnp.linalg_multi_dot MatMul jnp_linalg_multi_dot_basic 0.12.1
jnp.linalg_norm GlobalLpPool
Transpose
linalg_norm_global_fro
linalg_norm_global_default
linalg_norm_global_ord1
linalg_norm_global_ord1_nokeepdims
0.12.1
jnp.linalg_pinv MatMul
SVD
jnp_linalg_pinv_basic 0.12.1
jnp.linalg_qr QR jnp_linalg_qr_basic 0.12.1
jnp.linalg_solve Cast
Concat
Div
Gather
Mul
Sub
Unsqueeze
linalg_solve_1x1_vector
linalg_solve_2x2_vector
linalg_solve_batched_2x2_matrix_rhs
0.13.0
jnp.linalg_svd SVD jnp_linalg_svd_basic 0.12.1
jnp.linalg_svdvals SVD jnp_linalg_svdvals_basic 0.12.1
jnp.linalg_tensordot MatMul
Transpose
jnp_linalg_tensordot_basic 0.12.1
jnp.linalg_tensorinv Concat
Div
Gather
Mul
Neg
Reshape
Sub
Unsqueeze
linalg_tensorinv_2x2_tensor
linalg_tensorinv_1x1_tensor
0.13.0
jnp.linalg_tensorsolve Concat
Div
Gather
Mul
Reshape
Sub
Unsqueeze
linalg_tensorsolve_2x2_vector
linalg_tensorsolve_1x1_vector
0.13.0
jnp.linalg_trace Gather
ReduceSum
jnp_linalg_trace_basic 0.12.1
jnp.linalg_vecdot Mul
ReduceSum
jnp_linalg_vecdot_basic 0.12.1
jnp.linalg_vector_norm Mul
ReduceSum
Sqrt
jnp_linalg_vector_norm_basic 0.12.1
jnp.linspace Add
Range
linspace_static_basic
linspace_static_basic_f64
linspace_static_endpoint_false
linspace_static_endpoint_false_f64
linspace_static_int_inputs_default_dtype
linspace_static_int_inputs_default_dtype_f64
linspace_basic_f32
linspace_basic_f32_f64
linspace_endpoint_false_i32
linspace_endpoint_false_i32_f64
linspace_num_zero
linspace_num_zero_f64
linspace_num_one
linspace_num_one_f64
linspace_static_num_0
linspace_static_num_0_f64
linspace_static_num_1
linspace_static_num_1_f64
linspace_vmap_batching
linspace_vmap_batching_f64
0.5.2
jnp.log Log
ReduceLogSumExp
ReduceLogSum
jnp_log_basic
jnp_log_basic_f64
jnp_log_int_promote
jnp_log_int_promote_f64
log_vmap_batching
log_vmap_batching_f64
log_grad_issue_batch_diff_rules
0.12.2
jnp.logspace Pow
Range
jnp_logspace_basic 0.12.1
jnp.matmul MatMul matmul_1d
matmul_1d_f64
matmul_1d_2d
matmul_1d_2d_f64
matmul_2d
matmul_2d_f64
matmul_2d_1d
matmul_2d_1d_f64
matmul_3d
matmul_3d_f64
matmul_dynamic_dynamic
matmul_dynamic_dynamic_f64
matmul_dynamic
matmul_dynamic_f64
matmul_dynamic_a_dynamic
matmul_dynamic_a_dynamic_f64
matmul_dynamic_a
matmul_dynamic_a_f64
matmul_complex64
matmul_complex64_f64
matmul_vmap_batching
matmul_vmap_batching_f64
matmul_grad_issue_batch_diff_rules
0.1.0
jnp.matvec MatMul jnp_matvec_basic 0.12.1
jnp.max ReduceMax jnp_max_basic
jnp_max_basic_f64
jnp_max_axis
jnp_max_axis_f64
jnp_max_keepdims
jnp_max_keepdims_f64
max_vmap_batching
max_vmap_batching_f64
0.12.2
jnp.maximum Max jnp_maximum_basic
jnp_maximum_basic_f64
jnp_maximum_broadcast_scalar
jnp_maximum_broadcast_scalar_f64
0.12.2
jnp.mean ReduceMean basic_mean
basic_mean_f64
mean_with_axis
mean_with_axis_f64
mean_with_keepdims
mean_with_keepdims_f64
jnp_mean_basic
jnp_mean_basic_f64
jnp_mean_axis
jnp_mean_axis_f64
mean_grad_issue_batch_diff_rules
0.12.0
jnp.median ReduceMean
TopK
jnp_median_basic 0.12.1
jnp.min ReduceMin jnp_min_basic
jnp_min_basic_f64
jnp_min_axis
jnp_min_axis_f64
jnp_min_keepdims
jnp_min_keepdims_f64
min_vmap_batching
min_vmap_batching_f64
0.12.2
jnp.minimum Min jnp_minimum_basic
jnp_minimum_basic_f64
jnp_minimum_broadcast_scalar
jnp_minimum_broadcast_scalar_f64
0.12.2
jnp.moveaxis Transpose jnp_moveaxis_2d
jnp_moveaxis_2d_f64
jnp_moveaxis_3d_tuple
jnp_moveaxis_3d_tuple_f64
jnp_moveaxis_vmap_batching
jnp_moveaxis_vmap_batching_f64
0.12.2
jnp.nanargmax ArgMax
IsNaN
Where
jnp_nanargmax_basic 0.12.1
jnp.nanargmin ArgMin
IsNaN
Where
jnp_nanargmin_basic 0.12.1
jnp.nancumprod Cast
Expand
IsNaN
ReduceProd
Reshape
Where
jnp_nancumprod_axis1
jnp_nancumprod_axis_none_flatten
jnp_nancumprod_dtype_cast_i32
0.13.0
jnp.nancumsum CumSum
IsNaN
Where
jnp_nancumsum_basic 0.12.1
jnp.nanmax IsNaN
ReduceMax
Where
jnp_nanmax_basic 0.12.1
jnp.nanmean IsNaN
ReduceSum
Where
jnp_nanmean_basic 0.12.1
jnp.nanmedian TopK
Where
jnp_nanmedian_basic 0.12.1
jnp.nanmin IsNaN
ReduceMin
Where
jnp_nanmin_basic 0.12.1
jnp.nanpercentile TopK
Where
jnp_nanpercentile_basic 0.12.1
jnp.nanprod IsNaN
ReduceProd
Where
jnp_nanprod_basic 0.12.1
jnp.nanquantile TopK
Where
jnp_nanquantile_basic 0.12.1
jnp.nanstd ReduceSum
Sqrt
jnp_nanstd_basic 0.12.1
jnp.nansum IsNaN
ReduceSum
Where
jnp_nansum_basic 0.12.1
jnp.nanvar Mul
ReduceSum
jnp_nanvar_basic 0.12.1
jnp.ndarray_at ScatterND jnp_ndarray_at_set_basic 0.12.1
jnp.ones ConstantOfShape jnp_ones_2x3
jnp_ones_2x3_f64
jnp_ones_bool
jnp_ones_bool_f64
0.12.2
jnp.outer Mul outer_vector
outer_vector_f64
outer
outer_f64
outer_vmap_batching
outer_vmap_batching_f64
outer_grad_issue_batch_diff_rules
0.10.0
jnp.pad Pad jnp_pad_constant_1d
jnp_pad_constant_1d_f64
jnp_pad_constant_2d_tuple
jnp_pad_constant_2d_tuple_f64
jnp_pad_constant_vmap_batching
jnp_pad_constant_vmap_batching_f64
0.12.2
jnp.partition Concat
TopK
jnp_partition_basic 0.12.1
jnp.percentile ReduceMean
TopK
jnp_percentile_basic 0.12.1
jnp.polyadd Add
Pad
jnp_polyadd_basic 0.12.1
jnp.polyder Mul
Slice
jnp_polyder_basic 0.12.1
jnp.polydiv Div
Sub
jnp_polydiv_basic 0.12.1
jnp.polyfit Cast
Concat
Div
Mul
ReduceSum
Sub
Unsqueeze
jnp_polyfit_linear_two_points
jnp_polyfit_linear_two_points_f64
jnp_polyfit_linear_three_points
jnp_polyfit_linear_three_points_f64
0.13.0
jnp.polyint Concat
Div
jnp_polyint_basic 0.12.1
jnp.polysub Pad
Sub
jnp_polysub_basic 0.12.1
jnp.pow Pow jnp_pow_vector
jnp_pow_vector_f64
pow_jnp_pow
pow_jnp_pow_f64
pow_vmap_batching
pow_vmap_batching_f64
pow_grad_issue_batch_diff_rules
0.8.0
jnp.power Pow jnp_power_vector
jnp_power_vector_f64
pow_jnp_power
pow_jnp_power_f64
power_vmap_batching
power_vmap_batching_f64
power_grad_issue_batch_diff_rules
0.8.0
jnp.prod ReduceProd basic_prod
basic_prod_f64
prod_with_axis
prod_with_axis_f64
prod_with_keepdims
prod_with_keepdims_f64
jnp_prod_basic
jnp_prod_basic_f64
jnp_prod_axis
jnp_prod_axis_f64
jnp_prod_keepdims
jnp_prod_keepdims_f64
prod_grad_issue_batch_diff_rules
prod_vmap_batching
prod_vmap_batching_f64
0.8.0
jnp.put ScatterND jnp_put_basic 0.12.1
jnp.put_along_axis ScatterND jnp_put_along_axis_basic 0.12.1
jnp.quantile ReduceMean
TopK
jnp_quantile_basic 0.12.1
jnp.ravel_multi_index Add
Mod
Mul
jnp_ravel_multi_index_wrap 0.12.1
jnp.reshape Flatten
Reshape
reshape_1
reshape_1_f64
reshape_2
reshape_2_f64
reshape_flatten_static_uses_flatten
reshape_flatten_static_uses_flatten_f64
reshape_3
reshape_3_f64
reshape_4_dynamic
reshape_4_dynamic_f64
reshape_4
reshape_4_f64
reshape_to_scalar
reshape_to_scalar_f64
reshape_from_scalar
reshape_from_scalar_f64
reshape_cnn_dynamic
reshape_cnn_dynamic_f64
reshape_cnn
reshape_cnn_f64
reshape_valid_flatten_trailing
reshape_valid_flatten_trailing_f64
reshape_with_target_shape_from_symbolic_dim_computation
reshape_with_target_shape_from_symbolic_dim_computation_f64
reshape_basic
reshape_basic_f64
reshape_infer
reshape_infer_f64
reshape_symbolic_flatten_dynamic
reshape_symbolic_flatten_dynamic_f64
reshape_symbolic_flatten
reshape_symbolic_flatten_f64
reshape_vmap_batching_issue_144
reshape_vmap_batching_issue_144_f64
reshape_grad_issue_batch_diff_rules
0.1.0
jnp.rfft DFT jnp_rfft_float32 0.10.1
jnp.right_shift BitShift jnp_right_shift_signed_arithmetic
jnp_right_shift_signed_arithmetic_f64
jnp_right_shift_unsigned_logical
jnp_right_shift_unsigned_logical_f64
jnp_right_shift_mixed_bool_rhs
jnp_right_shift_mixed_bool_rhs_f64
jnp_right_shift_broadcast
jnp_right_shift_broadcast_f64
right_shift_vmap_batching
right_shift_vmap_batching_f64
0.12.2
jnp.roll Concat
Slice
jnp_roll_basic 0.12.1
jnp.roots Add
And
Concat
Div
Equal
Gather
Less
Mul
Sign
Sqrt
Sub
Where
jnp_roots_linear
jnp_roots_linear_f64
jnp_roots_quadratic_real
jnp_roots_quadratic_real_f64
jnp_roots_quadratic_complex_pair
jnp_roots_quadratic_complex_pair_f64
jnp_roots_quadratic_leading_zero
jnp_roots_quadratic_leading_zero_f64
0.13.0
jnp.rot90 Gather
Range
Transpose
jnp_rot90_basic 0.12.1
jnp.searchsorted Cast
LessOrEqual
Less
ReduceSum
jnp_searchsorted_left_vector
jnp_searchsorted_left_vector_f64
jnp_searchsorted_right_vector
jnp_searchsorted_right_vector_f64
jnp_searchsorted_scalar_value
jnp_searchsorted_scalar_value_f64
searchsorted_vmap_values
searchsorted_vmap_values_f64
0.13.0
jnp.select Where select_simple
select_simple_f64
select_broadcast
select_broadcast_f64
select_gpt2_attention_mask_dynamic
select_gpt2_attention_mask_dynamic_f64
select_gpt2_attention_mask
select_gpt2_attention_mask_f64
select_basic
select_basic_f64
select_vmap_batching
select_vmap_batching_f64
select_grad_issue_batch_diff_rules
0.7.1
jnp.shape Shape shape_basic
shape_basic_f64
shape_dynamic_dynamic
shape_dynamic_dynamic_f64
shape_dynamic
shape_dynamic_f64
shape_vmap_batching
shape_vmap_batching_f64
0.4.0
jnp.sign Sign jnp_sign_basic
jnp_sign_basic_f64
jnp_sign_int
jnp_sign_int_f64
sign_vmap_batching
sign_vmap_batching_f64
sign_grad_issue_batch_diff_rules
0.12.2
jnp.sin Sin jnp_sin_basic
jnp_sin_basic_f64
jnp_sin_int_promote
jnp_sin_int_promote_f64
sin_vmap_batching
sin_vmap_batching_f64
sin_grad_issue_batch_diff_rules
0.12.2
jnp.sinc Div
Sin
jnp_sinc_basic 0.12.1
jnp.sinh Sinh jnp_sinh_basic
jnp_sinh_basic_f64
jnp_sinh_int_promote
jnp_sinh_int_promote_f64
sinh_vmap_batching
sinh_vmap_batching_f64
sinh_grad_issue_batch_diff_rules
0.12.2
jnp.size Size jnp_size_all
jnp_size_all_f64
jnp_size_axis
jnp_size_axis_f64
jnp_size_axis_tuple
jnp_size_axis_tuple_f64
jnp_size_dynamic_dynamic
jnp_size_dynamic_dynamic_f64
jnp_size_dynamic
jnp_size_dynamic_f64
0.12.1
jnp.sort TopK sort_1d
sort_1d_f64
sort_2d_axis0
sort_2d_axis0_f64
sort_basic
sort_basic_f64
sort_vmap_batching
sort_vmap_batching_f64
0.5.2
jnp.spacing IsInf
IsNaN
Log
Pow
Where
jnp_spacing_basic
jnp_spacing_subnormal_and_special
spacing_vmap_batching
0.13.0
jnp.split Split split_by_sections
split_by_sections_f64
split_by_indices
split_by_indices_f64
split_by_indices_symbolic_dynamic
split_by_indices_symbolic_dynamic_f64
split_by_indices_symbolic
split_by_indices_symbolic_f64
split_sections
split_sections_f64
split_indices_numpy
split_indices_numpy_f64
split_single_section
split_single_section_f64
split_grad_issue_batch_diff_rules
0.7.2
jnp.sqrt ReduceL2
Sqrt
jnp_sqrt_basic
jnp_sqrt_basic_f64
jnp_sqrt_int_promote
jnp_sqrt_int_promote_f64
jnp_sqrt_reduce_sum_square_axis1
jnp_sqrt_reduce_sum_square_axis1_f64
sqrt_vmap_batching
sqrt_vmap_batching_f64
sqrt_grad_issue_batch_diff_rules
0.12.2
jnp.squeeze Squeeze squeeze_single_dim
squeeze_single_dim_f64
squeeze_multiple_dims
squeeze_multiple_dims_f64
squeeze_vit_output
squeeze_vit_output_f64
squeeze_dynamic_batch_dynamic
squeeze_dynamic_batch_dynamic_f64
squeeze_dynamic_batch
squeeze_dynamic_batch_f64
squeeze_all_dims
squeeze_all_dims_f64
squeeze_negative_axis
squeeze_negative_axis_f64
squeeze_negative_axis_tuple
squeeze_negative_axis_tuple_f64
squeeze_dynamic_and_negative_axis_dynamic
squeeze_dynamic_and_negative_axis_dynamic_f64
squeeze_dynamic_and_negative_axis
squeeze_dynamic_and_negative_axis_f64
squeeze_vmap_batching
squeeze_vmap_batching_f64
squeeze_grad_issue_batch_diff_rules
0.1.0
jnp.stack Concat
Unsqueeze
stack_axis_0
stack_axis_0_f64
stack_axis_1
stack_axis_1_f64
stack_negative_axis
stack_negative_axis_f64
stack_scalars
stack_scalars_f64
jnp_stack_axis0
jnp_stack_axis0_f64
jnp_stack_axis1
jnp_stack_axis1_f64
jnp_stack_negative_axis
jnp_stack_negative_axis_f64
jnp_stack_scalars
jnp_stack_scalars_f64
stack_grad_issue_batch_diff_rules
0.8.0
jnp.std ReduceMean
Sqrt
jnp_std_basic 0.12.1
jnp.sum ReduceSum jnp_sum_basic
jnp_sum_basic_f64
jnp_sum_axis
jnp_sum_axis_f64
jnp_sum_keepdims
jnp_sum_keepdims_f64
jnp_sum_int8_promote
sum_vmap_batching
sum_vmap_batching_f64
sum_grad_issue_batch_diff_rules
0.12.2
jnp.take Gather take_data_dependent_indices
take_basic_axis1
take_basic_axis1_f64
take_vmap_batching
take_vmap_batching_f64
take_grad_issue_batch_diff_rules
0.7.0
jnp.tan Tan jnp_tan_basic
jnp_tan_basic_f64
jnp_tan_int_promote
jnp_tan_int_promote_f64
tan_vmap_batching
tan_vmap_batching_f64
tan_grad_issue_batch_diff_rules
0.12.2
jnp.tanh Tanh jnp_tanh_basic
jnp_tanh_basic_f64
jnp_tanh_int_promote
jnp_tanh_int_promote_f64
tanh_vmap_batching
tanh_vmap_batching_f64
tanh_grad_issue_batch_diff_rules
0.12.2
jnp.tensordot MatMul
Transpose
jnp_tensordot_basic 0.12.1
jnp.tile Tile tile_repeats
tile_repeats_f64
tile_a
tile_a_f64
tile_b
tile_b_f64
tile_c
tile_c_f64
tile_d
tile_d_f64
tile_dynamic_input_static
tile_dynamic_input_static_f64
tile_dynamic_input_dynamic
tile_dynamic_input_dynamic_f64
tile_dynamic_input
tile_dynamic_input_f64
tile_pad
tile_pad_f64
tile_param_symbolic_dynamic
tile_param_symbolic_dynamic_f64
tile_param_symbolic
tile_param_symbolic_f64
tile_with_symbolic_repeats_static
tile_with_symbolic_repeats_static_f64
tile_with_symbolic_repeats_dynamic
tile_with_symbolic_repeats_dynamic_f64
tile_with_symbolic_repeats
tile_with_symbolic_repeats_f64
jnp_tile_basic
jnp_tile_basic_f64
jnp_tile_scalar_repeats
jnp_tile_scalar_repeats_f64
jnp_tile_pad_rank
jnp_tile_pad_rank_f64
jnp_tile_symbolic_dynamic
jnp_tile_symbolic_dynamic_f64
jnp_tile_symbolic
jnp_tile_symbolic_f64
tile_vmap_batching
tile_vmap_batching_f64
tile_grad_issue_batch_diff_rules
0.8.0
jnp.trace Gather
ReduceSum
jnp_trace_basic 0.12.1
jnp.transpose Transpose transpose_basic
transpose_basic_f64
transpose_reverse_default
transpose_reverse_default_f64
transpose_high_dim
transpose_high_dim_f64
transpose_3d
transpose_3d_f64
transpose_4d
transpose_4d_f64
transpose_no_axes
transpose_no_axes_f64
transpose_reverse
transpose_reverse_f64
transpose_square_matrix
transpose_square_matrix_f64
transpose_vmap_batching
transpose_vmap_batching_f64
transpose_grad_issue_batch_diff_rules
0.1.0
jnp.trapezoid Add
ReduceSum
Slice
jnp_trapezoid_basic 0.12.1
jnp.tri LessOrEqual
Range
jnp_tri_basic 0.12.1
jnp.trilu Trilu jnp_triu_basic
jnp_triu_basic_f64
jnp_tril_negative_k
jnp_tril_negative_k_f64
jnp_triu_symbolic_batch_dynamic
jnp_triu_symbolic_batch_dynamic_f64
jnp_triu_symbolic_batch
jnp_triu_symbolic_batch_f64
jnp_trilu_vmap_batching
jnp_trilu_vmap_batching_f64
0.12.1
jnp.unique Unique jnp_unique_f32_size_fill
jnp_unique_i32_size_fill
jnp_unique_symbolic_size_fill_dynamic
jnp_unique_symbolic_size_fill
0.12.1
jnp.unravel_index Concat
Div
Mod
jnp_unravel_index_basic
jnp_unravel_index_basic_f64
0.12.1
jnp.unstack Split
Squeeze
unstack_axis_0
unstack_axis_0_f64
unstack_axis_1
unstack_axis_1_f64
unstack_negative_axis
unstack_negative_axis_f64
unstack_axis_1_single
unstack_axis_1_single_f64
unstack_vmap_batching
unstack_vmap_batching_f64
0.7.1
jnp.var Mul
ReduceMean
jnp_var_basic 0.12.1
jnp.vdot Mul
ReduceSum
jnp_vdot_basic 0.12.1
jnp.vecdot Mul
ReduceSum
jnp_vecdot_basic 0.12.1
jnp.vecmat MatMul jnp_vecmat_basic 0.12.1
jnp.vectorize Add jnp_vectorize_basic 0.12.1
jnp.where Where where_simple
where_simple_f64
where_broadcast
where_broadcast_f64
where_gpt_mask_scores_literal_else_dynamic
where_gpt_mask_scores_literal_else_dynamic_f64
where_gpt_mask_scores_literal_else
where_gpt_mask_scores_literal_else_f64
where_multidim_condition_scalar_branches_broadcast
where_multidim_condition_scalar_branches_broadcast_f64
where_A
where_A_f64
where_B
where_B_f64
where_gpt_mask_scores_scalar_else_dynamic
where_gpt_mask_scores_scalar_else_dynamic_f64
where_gpt_mask_scores_scalar_else
where_gpt_mask_scores_scalar_else_f64
where_int_condition_cast
where_int_condition_cast_f64
where_literal_else_pyfloat
where_literal_else_pyfloat_f64
where_jax_int_literals_broadcast_f64_mode
where_dtype_mismatch_f64_vs_i32_promote
jnp_where_basic
jnp_where_basic_f64
jnp_where_broadcast
jnp_where_broadcast_f64
jnp_where_scalar_else
jnp_where_scalar_else_f64
where_vmap_batching
where_vmap_batching_f64
where_grad_issue_batch_diff_rules
0.8.0
jnp.zeros ConstantOfShape jnp_zeros_2x3
jnp_zeros_2x3_f64
jnp_zeros_int32
jnp_zeros_int32_f64
0.12.2
lax.abs Abs abs
abs_f64
0.5.0
lax.acos Acos acos 0.12.1
lax.acosh Acosh acosh 0.12.1
lax.add Add add
add_f64
add_const
add_const_f64
add_complex64
add_complex64_f64
0.2.0
lax.add_any Add
Sum
add_any_via_jvp_on_mul
add_any_via_jvp_on_mul_f64
0.8.0
lax.and And
BitwiseAnd
and_bool
and_bool_f64
and_int
and_int_f64
0.6.5
lax.approx_top_k TopK approx_max_k_matrix
approx_max_k_matrix_f64
approx_min_k_axis0
approx_min_k_axis0_f64
0.12.1
lax.argmax ArgMax argmax_float_axis0
argmax_float_axis0_f64
argmax_float_axis1
argmax_float_axis1_f64
argmax_boolean_input_axis0_specific_values
argmax_boolean_input_axis0_specific_values_f64
argmax_boolean_input_axis1_specific_values
argmax_boolean_input_axis1_specific_values_f64
argmax_boolean_random_input_axis0
argmax_boolean_random_input_axis0_f64
0.2.0
lax.argmin ArgMin argmin_test1
argmin_test1_f64
argmin_test2
argmin_test2_f64
0.2.0
lax.asin Asin asin 0.12.1
lax.asinh Asinh asinh 0.12.1
lax.atan Atan atan 0.12.1
lax.atan2 Add
Atan
Div
Equal
Greater
Less
Or
Sub
Where
atan2_quadrants_and_zero
atan2_quadrants_and_zero_f64
atan2_broadcast
0.12.1
lax.atanh Atanh atanh 0.12.1
lax.bessel_i0e Abs
Exp
Where
bessel_i0e_basic
bessel_i0e_basic_f64
0.12.1
lax.bessel_i1e Abs
Sign
Where
bessel_i1e_basic
bessel_i1e_basic_f64
0.12.1
lax.betainc Exp
Pow
Where
betainc_basic
betainc_edge_x
0.12.1
lax.bitcast_convert_type BitCast bitcast_scalar_f32_to_i32
bitcast_tensor_i32_to_f32
0.7.2
lax.bitwise_not BitwiseNot
Not
bitwise_not_bool
bitwise_not_bool_f64
bitwise_not_i32
bitwise_not_i32_f64
0.7.5
lax.broadcast_in_dim Expand
Identity
Reshape
broadcast_in_dim
broadcast_in_dim_f64
broadcast_in_dim_2d_to_3d
broadcast_in_dim_2d_to_3d_f64
broadcast_in_dim_scalar
broadcast_in_dim_scalar_f64
broadcast_in_dim_batch_dynamic
broadcast_in_dim_batch_dynamic_f64
broadcast_in_dim_batch
broadcast_in_dim_batch_f64
broadcast_in_dim_dynamic_B_dynamic
broadcast_in_dim_dynamic_B_dynamic_f64
broadcast_in_dim_dynamic_B
broadcast_in_dim_dynamic_B_f64
0.2.0
lax.cbrt Abs
Mul
Pow
Sign
cbrt
cbrt_f64
0.12.1
lax.ceil Ceil ceil
ceil_f64
0.12.1
lax.cholesky Concat
Div
Gather
Mul
ScatterND
Sqrt
Squeeze
Sub
Unsqueeze
cholesky_spd_3x3
cholesky_spd_3x3_f64
cholesky_diagonal
cholesky_diagonal_f64
cholesky_batched_2x3x3
cholesky_batched_2x3x3_f64
0.12.1
lax.cholesky_update Add
Div
Mul
ScatterND
Sqrt
Sub
cholesky_update_upper_3x3
cholesky_update_upper_3x3_f64
cholesky_update_identity
cholesky_update_identity_f64
0.12.1
lax.clamp Max
Min
clamp_i32_scalar_bounds
clamp_i32_scalar_bounds_f64
clamp_scalar_float_bounds_match_x
clamp_scalar_float_bounds_match_x_f64
clamp_vector_bounds_match
clamp_pyint_bounds_promote_to_x_dtype
clamp_pyint_bounds_promote_to_x_dtype_f64
0.7.5
lax.clz BitShift
BitwiseAnd
clz_i32
clz_i32_f64
clz_u8
clz_u8_f64
0.12.1
lax.complex Concat
Unsqueeze
complex_scalar
complex_scalar_f64
complex_array
complex_array_f64
complex_double_precision
0.12.2
lax.concatenate Cast
Concat
concatenate
concatenate_f64
concatenate_axis1_dynamic
concatenate_axis1_dynamic_f64
concatenate_axis1
concatenate_axis1_f64
concatenate_axis0
concatenate_axis0_f64
concatenate_3d
concatenate_3d_f64
concatenate_internal_int32_then_cast_to_f32_zeroarg
0.2.0
lax.cond If cond_scalar
cond_scalar_f64
cond_multiple_operands_in_tuple
cond_multiple_operands_in_tuple_f64
cond_my_new_complex_scenario
cond_my_new_complex_scenario_f64
cond_nested_conditional
cond_nested_conditional_f64
cond_variables
cond_variables_f64
cond_internal_constant_f64
cond_passthrough_identity
cond_passthrough_identity_f64
cond_with_scatter
cond_with_scatter_f64
0.5.1
lax.conj Identity conj_real
conj_real_f64
conj_complex64
conj_complex64_f64
0.10.1
lax.conv Conv conv
conv2
conv_nchw
conv_general_dilated_1d
conv_nhwc
conv_general_dilated_nhwc_output
conv_complex64
conv_complex64_nhwc
conv_complex128_grouped
0.2.0
lax.convert_element_type Cast convert_element_type
convert_element_type_f64
0.2.0
lax.copy (Handles the JAX primitive lax.copy_p. The public API jax.lax.copy is deprecated, but the primitive still appears in transformed jaxprs.) Identity copy_float32_array
copy_int64_scalar
0.13.0
lax.cos Cos cos
cos_f64
0.4.4
lax.cosh Cosh cosh
cosh_f64
0.4.4
lax.cumlogsumexp CumSum
Exp
Log
cumlogsumexp_axis1
cumlogsumexp_axis1_f64
cumlogsumexp_reverse_last_axis
cumlogsumexp_reverse_last_axis_f64
0.12.1
lax.cummax MaxPool
Reshape
Transpose
cummax_axis1
cummax_axis1_f64
cummax_reverse_last_axis
cummax_reverse_last_axis_f64
0.12.1
lax.cummin MaxPool
Neg
Reshape
Transpose
cummin_axis1
cummin_axis1_f64
cummin_reverse_last_axis
cummin_reverse_last_axis_f64
0.12.1
lax.cumprod CumProd cumprod_i32_axis2
cumprod_i32_axis2_f64
cumprod_f32_axism1_reverse
cumprod_f32_axism1_reverse_f64
0.12.1
lax.cumsum CumSum cumsum_i32_axis2
cumsum_i32_axis2_f64
cumsum_f32_axism1_reverse
cumsum_f32_axism1_reverse_f64
0.7.4
lax.custom_linear_solve custom_linear_solve_via_matvec
custom_linear_solve_via_matvec_f64
0.12.1
lax.device_put Identity device_put_array
device_put_array_f64
device_put_scalar
device_put_scalar_f64
0.4.0
lax.digamma Cos
Log
Sin
Where
digamma_positive
digamma_mixed
0.12.1
lax.div Div
LpNormalization
Mean
div
div_f64
div_const
div_const_f64
div_add_half_fuses_to_mean
div_add_half_fuses_to_mean_symbolic_dynamic
div_add_half_fuses_to_mean_symbolic
div_add_third_no_mean
div_add_half_f64_no_mean
div_lpnorm_l2_axis1
div_lpnorm_l1_axis1
div_lpnorm_l2_axis2
div_sqrt_of_norm_no_lpnorm_fusion
div_complex64
div_complex64_f64
0.2.0
lax.dot_general Einsum
MatMul/Gemm
dot_contract_nm
dot_contract_nm_f64
dot_contract_min
dot_contract_min_f64
dot_general
dot_general_f64
dot_general_lhs1_rhs1
dot_general_lhs1_rhs1_f64
dot_double_contract
dot_double_contract_f64
dot_batched_double_contract
dot_batched_double_contract_f64
dot_highrank_batch
dot_highrank_batch_f64
dot_contract_inner_lhs_with_middle_rhs
dot_contract_inner_lhs_with_middle_rhs_f64
dot_outer_product
dot_outer_product_f64
dot_full_contract_scalar
dot_full_contract_scalar_f64
dot_general_complex_matmul
dot_general_complex_matmul_f64
0.2.0
lax.dynamic_slice Slice dynamic_slice_test1
dynamic_slice_test1_f64
dynamic_slice_2d
dynamic_slice_2d_f64
dynamic_slice_3d
dynamic_slice_3d_f64
dynamic_slice_vit_like_dynamic
dynamic_slice_vit_like_dynamic_f64
dynamic_slice_vit_like
dynamic_slice_vit_like_f64
0.1.0
lax.dynamic_update_slice ScatterND
TensorScatter
dus_1d_scalar_update
dus_1d_scalar_update_f64
dus_1d_block_update
dus_1d_block_update_f64
dus_2d_block_update
dus_2d_block_update_f64
dus_3d_block_update
dus_3d_block_update_f64
dus_4d_block_update
dus_4d_block_update_f64
dus_tensorscatter_axis1_opset24
0.8.1
lax.eig Concat
Gather
Identity
Reshape
Unsqueeze
eig_1x1_values_only
eig_1x1_values_only_f64
eig_1x1_left_only
eig_1x1_left_only_f64
eig_1x1_right_only
eig_1x1_right_only_f64
eig_1x1_full
eig_1x1_full_f64
eig_2x2_values_only_real
eig_2x2_values_only_real_f64
eig_2x2_values_only_complex128
0.12.1
lax.eigh Add
Cast
Concat
Div
Gather
Identity
LessOrEqual
Mul
Reshape
Slice
Sqrt
Sub
Where
eigh_1x1
eigh_1x1_f64
eigh_2x2_lower_true
eigh_2x2_lower_true_f64
eigh_2x2_lower_false
eigh_2x2_lower_false_f64
eigh_2x2_subset_top1
eigh_2x2_subset_top1_f64
0.12.1
lax.eq Equal eq
eq_f64
0.2.0
lax.erf Erf erf 0.4.4
lax.erf_inv Erf
Exp
Log
Sqrt
erf_inv_midrange
erf_inv_matrix
0.12.1
lax.erfc Erf
Sub
erfc 0.12.1
lax.exp Exp exp
exp_f64
0.2.0
lax.exp2 Pow exp2
exp2_f64
0.12.1
lax.expm1 Exp
Sub
expm1
expm1_f64
0.12.1
lax.fft DFT fft_complex64_1d
fft_complex64_len8
fft_complex64_batch
fft_complex128_1d
fft_complex128_len8
ifft_complex64_1d
ifft_complex64_len8
ifft_complex128_batch
rfft_real32_1d
rfft_real64_len8
irfft_complex64_1d
irfft_complex128_len8
0.10.1
lax.floor Floor floor
floor_f64
0.12.1
lax.fori_loop Loop fori_loop_counter
fori_loop_counter_f64
fori_loop_zero
fori_loop_zero_f64
fori_loop_vector
fori_loop_vector_f64
fori_loop_example
fori_loop_example_f64
fori_loop_test
fori_loop_test_f64
0.5.1
lax.gather GatherND gather_trig_where_pipeline_f64_indices_i64
gather_trig_where_pipeline_f64_indices_i32
gather_f64_data_i64_indices_output_is_f64
gather_f64_data_i32_indices_cast_and_output_is_f64
gather_static
gather_static_f64
gather_fill_or_drop_oob_i32
gather_fill_or_drop_oob_i32_f64
gather_dynamic_batch_simple_index_dynamic
gather_dynamic_batch_simple_index_dynamic_f64
gather_dynamic_batch_simple_index
gather_dynamic_batch_simple_index_f64
0.2.0
lax.greater_equal GreaterOrEqual greater_equal
greater_equal_f64
0.7.5
lax.gt Greater gt
gt_f64
0.2.0
lax.hessenberg Div
MatMul
ScatterND
Sqrt
Where
hessenberg_square_4x4
hessenberg_square_4x4_f64
hessenberg_diagonal
hessenberg_diagonal_f64
0.12.1
lax.householder_product MatMul
Mul
ScatterND
Sub
Transpose
householder_product_basic
householder_product_basic_f64
householder_product_k0
householder_product_k0_f64
0.12.1
lax.igamma Exp
Pow
Where
igamma_basic
igamma_zero_x
0.12.1
lax.igamma_grad_a Add
Div
Sub
Where
igamma_grad_a_basic 0.12.1
lax.igammac Sub
Where
igammac_basic
igammac_zero_x
0.12.1
lax.imag Mul imag_complex64_input
imag_complex64_input_f64
0.10.2
lax.integer_pow Pow
Reciprocal
integer_pow
integer_pow_f64
integer_pow_reciprocal
integer_pow_reciprocal_f64
0.2.0
lax.iota Range iota_int32
iota_int32_f64
iota_float32
iota_float32_f64
iota_bfloat16_opset23_cast_fallback
iota_bfloat16_opset27_native_range
iota_float16_opset27_native_range
iota_uint8
iota_uint8_f64
broadcasted_iota
broadcasted_iota_f64
0.5.0
lax.is_finite IsInf
IsNaN
Not
Or
is_finite_vec
is_finite_vec_f64
0.12.1
lax.less_equal LessOrEqual less_equal
less_equal_f64
0.7.5
lax.lgamma Log
Sin
Where
lgamma_positive
lgamma_positive_f64
lgamma_negative_noninteger
lgamma_negative_noninteger_f64
0.12.1
lax.log Log
ReduceLogSumExp
ReduceLogSum
log
log_f64
log_of_reduce_sum_axis1
log_of_reduce_sum_axis1_f64
log_of_reduce_sum_exp_axis1
log_of_reduce_sum_exp_axis1_f64
0.2.0
lax.log1p Add
Log
log1p
log1p_f64
0.11.0
lax.logistic Sigmoid lax_logistic_basic
lax_logistic_basic_f64
0.7.2
lax.lt Less lt
lt_f64
0.2.0
lax.lu Abs
ArgMax
Div
Mul
ScatterElements
ScatterND
Sub
lu_square_3x3
lu_square_3x3_f64
lu_rectangular_4x3
lu_rectangular_4x3_f64
0.12.1
lax.lu_pivots_to_permutation Concat
Gather
ScatterElements
Squeeze
Unsqueeze
lu_pivots_to_permutation_basic
lu_pivots_to_permutation_basic_f64
lu_pivots_to_permutation_identity
lu_pivots_to_permutation_identity_f64
lu_pivots_to_permutation_batched
lu_pivots_to_permutation_batched_f64
0.12.1
lax.max Max max
max_f64
0.2.0
lax.min Min min_test1
min_test1_f64
0.1.0
lax.mul Mul mul_test1
mul_test1_f64
mul_test2
mul_test2_f64
mul_pyfloat_promotes_to_array_dtype_f64
mul_scalar_broadcast_promote_to_f64
mul_complex128
mul_complex64
0.1.0
lax.ne Equal
Not
ne
ne_f64
0.2.0
lax.neg Neg neg
neg_f64
0.2.0
lax.nextafter IsNaN
Log
Pow
Sign
Where
nextafter_vector
nextafter_special_values
0.12.1
lax.optimization_barrier Identity optimization_barrier_single
optimization_barrier_single_f64
optimization_barrier_tuple
optimization_barrier_tuple_f64
0.12.1
lax.or BitwiseOr
Or
or_bool_vec
or_bool_vec_f64
or_int_vec
or_int_vec_f64
0.7.2
lax.ormqr MatMul
Mul
ScatterND
Sub
Transpose
ormqr_left
ormqr_left_f64
ormqr_left_transpose
ormqr_left_transpose_f64
ormqr_right
ormqr_right_f64
0.13.0
lax.pad Pad pad_const_1d
pad_const_1d_f64
pad_const_2d
pad_const_2d_f64
pad_const_2d_cval
pad_const_2d_cval_f64
pad_inside_scan_smoke_f64
pad_inside_nested_scan_smoke_f64
0.8.0
lax.pjit pjit_inline_mul
pjit_inline_mul_f64
pjit_inline_tuple
pjit_inline_tuple_f64
0.1.0
lax.polygamma Exp
Floor
Pow
Round
Where
polygamma_orders
polygamma_zero_order
0.12.1
lax.population_count BitShift
BitwiseAnd
population_count_i32
population_count_i32_f64
population_count_u8
population_count_u8_f64
0.12.1
lax.pow Pow pow_basic
pow_basic_f64
pow_lax
pow_lax_f64
0.8.2
lax.qr MatMul
Mul
ScatterND
Sqrt
Sub
qr_reduced_tall
qr_reduced_tall_f64
qr_reduced_wide
qr_reduced_wide_f64
qr_full_tall
qr_full_tall_f64
qr_full_wide
qr_full_wide_f64
0.12.1
lax.real Identity real_complex64_input
real_complex64_input_f64
0.10.2
lax.reduce ReduceMax
ReduceMin
reduce_max_lambda
reduce_max_lambda_f64
reduce_min_lambda
reduce_min_lambda_f64
0.12.1
lax.reduce_and ReduceMin reduce_and_all_true
reduce_and_all_true_f64
reduce_and_one_false
reduce_and_one_false_f64
reduce_and_keepdims
reduce_and_keepdims_f64
0.6.1
lax.reduce_max ReduceMax reduce_max
reduce_max_f64
reduce_max_allaxes
reduce_max_allaxes_f64
reduce_max_axes_input
reduce_max_axes_input_f64
reduce_max_keepdims
reduce_max_keepdims_f64
0.2.0
lax.reduce_min ReduceMin reduce_min
reduce_min_f64
reduce_min_allaxes
reduce_min_allaxes_f64
reduce_min_keepdims
reduce_min_keepdims_f64
0.2.0
lax.reduce_or ReduceMax reduce_or_all_false
reduce_or_all_false_f64
reduce_or_one_true
reduce_or_one_true_f64
reduce_or_keepdims
reduce_or_keepdims_f64
reduce_or_no_axes
reduce_or_no_axes_f64
0.6.1
lax.reduce_precision Floor
Log
Pow
Round
Where
reduce_precision_mantissa10
reduce_precision_underflow_overflow
0.12.1
lax.reduce_prod ReduceProd reduce_prod
reduce_prod_f64
reduce_prod_allaxes
reduce_prod_allaxes_f64
reduce_prod_dtype
reduce_prod_dtype_f64
reduce_prod_keepdims
reduce_prod_keepdims_f64
0.6.1
lax.reduce_sum ReduceL1
ReduceSumSquare
ReduceSum
reduce_sum
reduce_sum_f64
reduce_sum_allaxes
reduce_sum_allaxes_f64
reduce_sum_dtype
reduce_sum_dtype_f64
reduce_sum_uint32_axis
reduce_sum_keepdims
reduce_sum_keepdims_f64
reduce_sum_no_axes
reduce_sum_no_axes_f64
reduce_sum_of_abs_axis1
reduce_sum_of_abs_axis1_f64
reduce_sum_of_square_axis1
reduce_sum_of_square_axis1_f64
0.2.0
lax.reduce_window Conv
MaxPool
Neg
reduce_window_add_lambda
reduce_window_max_lambda
reduce_window_min_lambda_stride
0.12.1
lax.reduce_window_max MaxPool reduce_window_max_primitive 0.12.2
lax.reduce_window_sum Conv
LpPool
reduce_window_sum_valid
reduce_window_sum_same_padding
reduce_window_sum_stride_dilate
reduce_window_sum_int32
reduce_window_sum_base_dilation
reduce_window_sum_abs_lppool_opset23
reduce_window_sum_abs_lppool_dilated_opset23
0.10.1
lax.reduce_xor Mod
ReduceSum
reduce_xor_all_false
reduce_xor_all_false_f64
reduce_xor_one_true
reduce_xor_one_true_f64
reduce_xor_two_true
reduce_xor_two_true_f64
reduce_xor_keepdims
reduce_xor_keepdims_f64
0.6.1
lax.rem Div
Mod
rem_int
rem_int_f64
rem_float
rem_float_f64
rem_int_neg
rem_int_neg_f64
rem_float_neg
rem_float_neg_f64
0.6.5
lax.remat2 remat2_scalar_sin_chain
remat2_scalar_sin_chain_f64
remat2_tuple_passthrough
remat2_tuple_passthrough_f64
0.6.5
lax.reshape Reshape reshape_after_transpose_folds_const_shape
reshape_after_transpose_folds_const_shape_f64
reshape_flatten_trailing_folds_const_shape
reshape_flatten_trailing_folds_const_shape_f64
reshape
reshape_f64
reshape_valid_squeeze_middle_dim_from_problematic_source
reshape_valid_squeeze_middle_dim_from_problematic_source_f64
reshape_valid_flatten_trailing
reshape_valid_flatten_trailing_f64
reshape_with_target_shape_from_symbolic_dim_computation
reshape_with_target_shape_from_symbolic_dim_computation_f64
reshape_with_inferred_dimension_from_input_dynamic_dynamic
reshape_with_inferred_dimension_from_input_dynamic_dynamic_f64
reshape_with_inferred_dimension_from_input_dynamic
reshape_with_inferred_dimension_from_input_dynamic_f64
reshape_with_inferred_dimension_from_input
reshape_with_inferred_dimension_from_input_f64
reshape_merge_symbolic_with_static_and_check_name_dynamic
reshape_merge_symbolic_with_static_and_check_name
0.2.0
lax.rev Gather
Range
rev_vector
rev_vector_f64
rev_matrix_axes01
rev_matrix_axes01_f64
0.7.5
lax.rng_bit_generator Cast
Identity
RandomUniform
rng_bit_generator_u32 0.12.1
lax.rng_uniform Add
Mul
RandomUniform
Sub
rng_uniform_f32 0.12.1
lax.round Round round
round_f64
0.12.1
lax.rsqrt Div
Sqrt
rsqrt
rsqrt_f64
0.10.2
lax.scan Loop scan_identity_slice_helper
scan_identity_slice_helper_f64
scan_cumsum
scan_cumsum_f64
scan_carry_only
scan_carry_only_f64
scan_multiple_sequences
scan_multiple_sequences_f64
scan_multiple_carry
scan_multiple_carry_f64
scan_matrix_carry_multidim_xs
scan_matrix_carry_multidim_xs_f64
scan_unroll_reuses_loop
scan_no_xs
scan_no_xs_f64
scan_fn
scan_fn_f64
scan_jit_no_xs
scan_jit_no_xs_f64
scan_captured_scalar
scan_captured_scalar_f64
scan_rank0_sequence_vectorized
scan_rank0_sequence_vectorized_f64
scan_two_diff_lengths
scan_two_diff_lengths_f64
scan_two_diff_lengths_broadcast
scan_two_diff_lengths_broadcast_f64
scan_two_diff_lengths_with_broadcast
scan_nested_len_mismatch
scan_nested_len_mismatch_f64
scan_captured_scalar_with_xs
scan_captured_vector_with_xs_f64
0.5.1
lax.scatter NonZero
ScatterElements
ScatterND
Scatter
scatter_set_axis0
scatter_set_axis0_f64
scatter_set_middle
scatter_set_middle_f64
scatter_set_single
scatter_set_single_f64
scatter_set_vector
scatter_set_vector_f64
scatter_elements_set_vector_promise_in_bounds
scatter_elements_set_vector_promise_in_bounds_f64
scatter_correct_axis_determination
scatter_correct_axis_determination_f64
scatter_updates_slice_needed_axis0
scatter_updates_slice_needed_axis0_f64
scatter_from_user_warning_shapes_valid_jax
scatter_from_user_warning_shapes_valid_jax_f64
scatter_user_error_scenario_precise
scatter_user_error_scenario_precise_f64
scatter_window_update_f64
scatter_window_update_depth3_shapes_ok
scatter_static_slice_set_f64
scatter_depth2_fp64_type_mismatch
scatter_clip_2d_window_at_edge
scatter_simple_2d_window_out_of_bounds
scatter_depth2_mixed_dtypes_fp_mismatch_f64
scatter_depth2_mixed_dtypes_fp_mismatch
0.4.4
lax.scatter_add ScatterND(reduction='add') scatter_add_vector
scatter_add_vector_f64
scatter_add_scalar
scatter_add_scalar_f64
scatter_add_simple_1d
scatter_add_simple_1d_f64
scatter_add_fill_or_drop_oob_1d
scatter_add_fill_or_drop_oob_1d_f64
scatter_add_batch_updates_1d_operand
scatter_add_batch_updates_1d_operand_f64
scatter_add_window_2d_operand_1d_indices
scatter_add_window_2d_operand_1d_indices_f64
scatter_add_mismatched_window_dims_from_user_report
scatter_add_mismatched_window_dims_from_user_report2
scatter_add_mismatched_window_dims_from_user_report3
scatter_add_fluids_pattern_updates_5_4_1_1
scatter_add_in_cond_float64
scatter_add_fp64_dtype_mismatch
scatter_add_depth2_depth2_helper_regression
scatter_depth2_fp64_type_mismatch
0.5.3
lax.scatter_max ScatterND(reduction='max') scatter_max_simple_1d
scatter_max_simple_1d_f64
scatter_max_batch_updates_1d_operand
scatter_max_batch_updates_1d_operand_f64
scatter_max_window_2d_operand_1d_indices
scatter_max_window_2d_operand_1d_indices_f64
scatter_max_fp64_dtype_path_check
scatter_max_depth2_helper_regression_fp64
0.7.5
lax.scatter_min ScatterND(reduction='min') scatter_min_simple_1d
scatter_min_simple_1d_f64
scatter_min_batch_updates_1d_operand
scatter_min_batch_updates_1d_operand_f64
scatter_min_window_2d_operand_1d_indices
scatter_min_window_2d_operand_1d_indices_f64
scatter_min_fp64_dtype_path_check
scatter_min_depth2_helper_regression_fp64
0.7.5
lax.scatter_mul ScatterND(reduction='mul') scatter_mul_simple_1d
scatter_mul_simple_1d_f64
scatter_mul_batch_updates_1d_operand
scatter_mul_batch_updates_1d_operand_f64
scatter_mul_window_2d_operand_1d_indices
scatter_mul_window_2d_operand_1d_indices_f64
scatter_mul_mismatched_window_dims_from_user_report
scatter_mul_mismatched_window_dims_from_user_report2
scatter_mul_mismatched_window_dims_from_user_report3
scatter_mul_fluids_pattern_updates_5_4_1_1
scatter_mul_in_cond_float64
0.6.4
lax.scatter_sub Neg
ScatterND(reduction='add')
scatter_sub_simple_1d
scatter_sub_simple_1d_f64
scatter_sub_window_2d_operand_1d_indices
scatter_sub_window_2d_operand_1d_indices_f64
0.12.1
lax.schur Identity schur_1x1_default
schur_1x1_default_f64
schur_1x1_no_vectors
schur_1x1_no_vectors_f64
0.12.1
lax.select Where select_simple
select_simple_f64
select_basic
select_basic_f64
select_mask_scores_tensor_else_dynamic
select_mask_scores_tensor_else_dynamic_f64
select_mask_scores_tensor_else
select_mask_scores_tensor_else_f64
0.7.1
lax.select_n Equal
Where
select_n_bool_predicate_two_cases_float
select_n_bool_predicate_two_cases_float_f64
select_n_bool_predicate_two_cases_int
select_n_bool_predicate_two_cases_int_f64
select_n_bool_predicate_scalar_broadcast
select_n_bool_predicate_scalar_broadcast_f64
select_n_int_indices_three_cases
select_n_int_indices_three_cases_f64
select_n_int_indices_four_cases
select_n_int_indices_four_cases_f64
0.2.0
lax.shard_map shard_map_inline_add
shard_map_inline_add_f64
0.10.2
lax.shift_left BitShift shift_left_vec
shift_left_vec_f64
shift_left_scalar
shift_left_scalar_f64
0.12.1
lax.shift_right_arithmetic BitShift shift_right_arithmetic_vec
shift_right_arithmetic_vec_f64
shift_right_arithmetic_scalar
shift_right_arithmetic_scalar_f64
0.12.1
lax.shift_right_logical BitShift shift_right_logical_vec
shift_right_logical_vec_f64
shift_right_logical_scalar
shift_right_logical_scalar_f64
0.7.2
lax.sign Sign sign
sign_f64
0.5.0
lax.sin Sin sin
sin_f64
0.4.4
lax.sinh Sinh sinh
sinh_f64
0.4.4
lax.slice Slice slice_test1
slice_test1_f64
slice_3d_none_strides
slice_3d_none_strides_f64
slice_scan_axis_drop
slice_scan_axis_drop_f64
0.1.0
lax.sort GatherElements
TopK
sort_1d
sort_1d_f64
sort_2d
sort_2d_f64
sort_two_keys
sort_two_keys_f64
0.2.0
lax.split Split lax_split_equal_parts
lax_split_equal_parts_f64
lax_split_unequal_parts
lax_split_unequal_parts_f64
lax_split_single_output
lax_split_single_output_f64
0.7.2
lax.sqrt ReduceL2
Sqrt
sqrt
sqrt_f64
sqrt_reduce_sum_square_axis1
sqrt_reduce_sum_square_axis1_f64
0.2.0
lax.square Mul square
square_f64
0.2.0
lax.squeeze Squeeze squeeze_single_axis
squeeze_single_axis_f64
squeeze_all_unit_dims_default
squeeze_all_unit_dims_default_f64
lax_squeeze_specific_axis_0
lax_squeeze_specific_axis_0_f64
lax_squeeze_multiple_axes
lax_squeeze_multiple_axes_f64
lax_squeeze_no_op_empty_dims
lax_squeeze_no_op_empty_dims_f64
lax_squeeze_problem_case_input_squeeze_only_axis_0
lax_squeeze_problem_case_input_squeeze_only_axis_0_f64
lax_squeeze_problem_case_input_squeeze_axes_0_2
lax_squeeze_problem_case_input_squeeze_axes_0_2_f64
lax_squeeze_problem_case_input_squeeze_all_dims_explicitly
lax_squeeze_problem_case_input_squeeze_all_dims_explicitly_f64
0.2.0
lax.stop_gradient Identity stop_gradient
stop_gradient_f64
stop_gradient_basic
stop_gradient_basic_f64
0.2.0
lax.sub Sub sub_test1
sub_test1_f64
sub_test2
sub_test2_f64
sub_const
sub_const_f64
sub_complex64
sub_complex64_f64
0.1.0
lax.svd Abs
Add
Div
Gather
Identity
Max
Mul
Reshape
Sign
Sub
Where
svd_1x1_default
svd_1x1_default_f64
svd_1x1_values_only
svd_1x1_values_only_f64
svd_2x2_values_only
svd_2x2_values_only_f64
svd_3x2_values_only
svd_3x2_values_only_f64
svd_2x4_values_only
svd_2x4_values_only_f64
0.12.1
lax.symmetric_product Add
MatMul
Mul
Transpose
symmetric_product_default
symmetric_product_default_f64
symmetric_product_alpha_beta
symmetric_product_alpha_beta_f64
0.12.1
lax.tan Tan tan 0.12.1
lax.tanh Tanh tanh
tanh_f64
0.2.0
lax.top_k TopK top_k_last_axis
top_k_last_axis_f64
top_k_matrix
top_k_matrix_f64
0.10.2
lax.transpose Transpose transpose_basic
transpose_basic_f64
transpose_square_matrix
transpose_square_matrix_f64
transpose_3d
transpose_3d_f64
transpose_4d
transpose_4d_f64
transpose_reverse
transpose_reverse_f64
transpose_no_axes
transpose_no_axes_f64
transpose_nhwc_to_nchw
transpose_nhwc_to_nchw_f64
0.2.0
lax.triangular_solve Div
Gather
triangular_solve_left_basic
triangular_solve_left_basic_f64
triangular_solve_batched_unit_diag
0.12.1
lax.tridiagonal Gather
Identity
ScatterND
tridiagonal_2x2_lower_true
tridiagonal_2x2_lower_true_f64
tridiagonal_2x2_lower_false
tridiagonal_2x2_lower_false_f64
0.12.1
lax.tridiagonal_solve Concat
Div
Gather
Mul
Sub
tridiagonal_solve_single_rhs
tridiagonal_solve_single_rhs_f64
tridiagonal_solve_multi_rhs
tridiagonal_solve_multi_rhs_f64
0.12.1
lax.while_loop Loop while_scalar_counter
while_scalar_counter_f64
while_tuple_state
while_tuple_state_f64
while_loop_counter
while_loop_counter_f64
while_loop_vector
while_loop_vector_f64
while_loop_f64
while_loop_multi_state_f32
while_loop_multi_state_f64
while_loop_with_closure
while_loop_with_closure_f64
while_loop_basic
while_loop_two_state
while_loop_captured_tracer
while_loop_with_scalar_state
while_loop_renamed_passthrough
while_loop_closure_topo
while_loop_mixed_rank
while_loop_tracer_passthrough
while_loop_no_loop_output_reused_as_input
while_loop_4d_and_scalar_state
while_loop_4d_and_scalar_state_f64
while_loop_cnn_scalar_state_bug
while_loop_cnn_scalar_state_bug_f64
while_loop_nnx_repro
while_loop_nnx_repro_f64
0.5.1
lax.xor BitwiseXor
Xor
xor_bool_vec
xor_bool_vec_f64
xor_int_vec
xor_int_vec_f64
0.12.1
lax.zeta Add
Div
Pow
Where
zeta_positive
zeta_broadcast
0.12.1
linen.activation activation_glu_basic
activation_glu_basic_f64
activation_hard_sigmoid_basic
activation_hard_silu_basic
activation_hard_silu_basic_f64
activation_hard_swish_basic
activation_hard_tanh_basic
activation_hard_tanh_basic_f64
activation_log_sigmoid_basic
activation_log_sigmoid_basic_f64
activation_log_softmax_basic
activation_log_softmax_basic_f64
activation_relu6_basic
activation_relu6_basic_f64
activation_silu_basic
activation_silu_basic_f64
activation_swish_basic
activation_swish_basic_f64
activation_tanh_basic
activation_tanh_basic_f64
activation_normalize_basic
activation_normalize_basic_f64
activation_one_hot_basic
0.11.0
linen.avg_pool AveragePool
GlobalAveragePool
Transpose
avg_pool_dynamic
avg_pool
avg_pool_same_padding_dynamic
avg_pool_same_padding
avg_pool_default_padding_dynamic
avg_pool_default_padding
avg_pool_stride1_dynamic
avg_pool_stride1
avg_pool_win3x3_stride2_dynamic
avg_pool_win3x3_stride2
avg_pool_stride_none_dynamic
avg_pool_stride_none
avg_pool_count_include_pad_false_dynamic
avg_pool_count_include_pad_false
avg_pool_global_window_dynamic
avg_pool_global_window
0.11.0
linen.batch_norm BatchNormalization batch_norm_no_bias_no_scale_dynamic
batch_norm_no_bias_no_scale
batch_norm_bias_no_scale_dynamic
batch_norm_bias_no_scale
batch_norm_no_bias_scale_dynamic
batch_norm_no_bias_scale
batch_norm_bias_scale_dynamic
batch_norm_bias_scale
batch_norm_3d_dynamic
batch_norm_3d
batch_norm_4d_dynamic
batch_norm_4d
batch_norm_4d_no_bias_no_scale_dynamic
batch_norm_4d_no_bias_no_scale
0.11.0
linen.bidirectional Concat
Loop
bidirectional_basic_dynamic
bidirectional_basic_dynamic_f64
bidirectional_basic
bidirectional_basic_f64
0.11.0
linen.conv CastLike
Conv
Reshape
Transpose
conv_basic_dynamic
conv_basic
conv_no_bias
conv_stride
0.11.0
linen.conv_local Conv
Gemm
Reshape
Transpose
conv_local_valid
conv_local_same
0.11.0
linen.conv_lstm_cell Add
Conv
Gemm
Mul
Sigmoid
Tanh
conv_lstm_cell_basic_dynamic
conv_lstm_cell_basic
0.11.0
linen.conv_transpose ConvTranspose conv_transpose_basic_dynamic
conv_transpose_basic
conv_transpose_valid_stride
0.11.0
linen.dense Gemm dense_basic_dynamic
dense_basic_dynamic_f64
dense_basic
dense_basic_f64
dense_high_rank_dynamic_dynamic
dense_high_rank_dynamic_dynamic_f64
dense_high_rank_static
dense_high_rank_static_f64
dense_high_rank_no_bias
dense_high_rank_no_bias_f64
dense_no_bias_dynamic
dense_no_bias_dynamic_f64
dense_no_bias
dense_no_bias_f64
0.11.0
linen.dense_general CastLike
Concat
Gemm
Reshape
Shape
Slice
dense_general_basic_dynamic
dense_general_basic_dynamic_f64
dense_general_basic
dense_general_basic_f64
dense_general_multi_out
dense_general_multi_out_f64
dense_general_contract_last_two
dense_general_contract_last_two_f64
dense_general_dynamic_batch_dynamic
dense_general_dynamic_batch_dynamic_f64
dense_general_no_bias
dense_general_no_bias_f64
0.11.0
linen.dot_product_attention MatMul
Mul
Softmax
Transpose
Where
dot_product_attention_basic
dot_product_attention_basic_f64
0.11.0
linen.dot_product_attention_weights Add
Div
MatMul
Mul
ReduceSum
Softmax
Where
dot_product_attention_weights_basic
dot_product_attention_weights_basic_f64
0.11.0
linen.dropout Dropout dropout_init_params_dynamic
dropout_init_params_dynamic_f64
dropout_init_params
dropout_init_params_f64
dropout_call_params_dynamic
dropout_call_params_dynamic_f64
dropout_call_params
dropout_call_params_f64
0.11.0
linen.einsum Add
Einsum
Reshape
einsum_with_bias
einsum_with_bias_f64
einsum_no_bias
einsum_no_bias_f64
0.11.0
linen.embed Gather token_embedding_dynamic
token_embedding_dynamic_f64
token_embedding
token_embedding_f64
positional_embedding_dynamic
positional_embedding_dynamic_f64
positional_embedding
positional_embedding_f64
0.11.0
linen.group_norm GroupNormalization group_norm_rank4
group_norm_rank2_dynamic
group_norm_rank2
group_norm_no_bias_no_scale_dynamic
group_norm_no_bias_no_scale
0.11.0
linen.gru_cell Add
Gemm
Mul
Sigmoid
Tanh
gru_cell_basic_dynamic
gru_cell_basic_dynamic_f64
gru_cell_basic
gru_cell_basic_f64
0.11.0
linen.instance_norm GroupNormalization
InstanceNormalization
instance_norm_rank4_dynamic
instance_norm_rank4
instance_norm_rank2_dynamic
instance_norm_rank2
0.11.0
linen.layer_norm LayerNormalization layer_norm_dynamic
layer_norm
layer_norm_no_bias_no_scale_dynamic
layer_norm_no_bias_no_scale
layer_norm_multiaxis_dynamic
layer_norm_multiaxis
layer_norm_default_epsilon_dynamic
layer_norm_default_epsilon
0.11.0
linen.lstm_cell Add
Gemm
Mul
Sigmoid
Tanh
lstm_cell_basic_dynamic
lstm_cell_basic_dynamic_f64
lstm_cell_basic
lstm_cell_basic_f64
0.11.0
linen.make_attention_mask Cast
Mul
make_attention_mask_basic
make_attention_mask_basic_f64
0.11.0
linen.make_causal_mask Cast
Less
make_causal_mask_basic
make_causal_mask_basic_f64
0.11.0
linen.max_pool GlobalMaxPool
MaxPool
max_pool
max_pool_same_padding
max_pool_basic
max_pool_same_dynamic
max_pool_same
max_pool_global_window_dynamic
max_pool_global_window
0.11.0
linen.mgu_cell Add
Gemm
Mul
Sigmoid
Tanh
mgu_cell_basic_dynamic
mgu_cell_basic_dynamic_f64
mgu_cell_basic
mgu_cell_basic_f64
0.11.0
linen.min_pool MaxPool
Neg
min_pool
min_pool_same_padding
min_pool_basic
min_pool_same_dynamic
min_pool_same
0.11.0
linen.multi_head_attention Add
Gemm
MatMul
Mul
Reshape
Softmax
Transpose
multi_head_attention_basic_dynamic
multi_head_attention_basic_dynamic_f64
multi_head_attention_basic
multi_head_attention_basic_f64
multi_head_attention_no_bias_dynamic
multi_head_attention_no_bias_dynamic_f64
multi_head_attention_no_bias
multi_head_attention_no_bias_f64
0.11.0
linen.multi_head_dot_product_attention Add
Gemm
MatMul
Mul
Reshape
Softmax
Transpose
multi_head_dot_product_attention_basic_dynamic
multi_head_dot_product_attention_basic_dynamic_f64
multi_head_dot_product_attention_basic
multi_head_dot_product_attention_basic_f64
multi_head_dot_product_attention_no_bias_dynamic
multi_head_dot_product_attention_no_bias_dynamic_f64
multi_head_dot_product_attention_no_bias
multi_head_dot_product_attention_no_bias_f64
0.11.0
linen.optimized_lstm_cell Add
Gemm
Mul
Sigmoid
Tanh
optimized_lstm_cell_basic_dynamic
optimized_lstm_cell_basic_dynamic_f64
optimized_lstm_cell_basic
optimized_lstm_cell_basic_f64
0.11.0
linen.pool AveragePool
MaxPool
Mul
Neg
Transpose
pool_max_basic_dynamic
pool_max_basic
pool_min_basic_dynamic
pool_min_basic
pool_sum_basic_dynamic
pool_sum_basic
0.11.0
linen.prelu PRelu linen_prelu_default_dynamic
linen_prelu_default
linen_prelu_custom_slope
0.12.1
linen.rms_norm RMSNormalization rms_norm_basic
rms_norm_use_scale_false
rms_norm_4d_dynamic_dynamic
rms_norm_4d_dynamic
0.11.0
linen.rnn Loop rnn_basic_dynamic
rnn_basic_dynamic_f64
rnn_basic
rnn_basic_f64
0.11.0
linen.self_attention Add
Gemm
MatMul
Mul
Reshape
Softmax
Transpose
self_attention_basic_dynamic
self_attention_basic_dynamic_f64
self_attention_basic
self_attention_basic_f64
self_attention_no_bias_dynamic
self_attention_no_bias_dynamic_f64
self_attention_no_bias
self_attention_no_bias_f64
0.11.0
linen.simple_cell Add
Gemm
Mul
Sigmoid
Tanh
simple_cell_basic_dynamic
simple_cell_basic_dynamic_f64
simple_cell_basic
simple_cell_basic_f64
0.11.0
linen.spectral_norm Div
MatMul
spectral_norm_dense_dynamic
spectral_norm_dense
0.11.0
linen.weight_norm Div
Mul
weight_norm_dense_dynamic
weight_norm_dense
0.11.0
nn.celu Celu jaxnn_celu
jaxnn_celu_1
jaxnn_celu_alpha_default
jaxnn_celu_alpha_custom_dynamic
jaxnn_celu_alpha_custom
celu_grad_issue_batch_diff_rules
0.7.1
nn.dot_product_attention Add
MatMul
Mul
Softmax
Transpose
Where
dpa_basic
dpa_basic_f64
dpa_positional_bias_mask
dpa_positional_bias_mask_f64
dpa_tnh_unbatched
dpa_tnh_unbatched_f64
dpa_diff_heads_embed
dpa_diff_heads_embed_f64
dpa_batch4_seq16
dpa_batch4_seq16_f64
dpa_float64
dpa_heads1_embed4
dpa_heads1_embed4_f64
dpa_heads8_embed8
dpa_heads8_embed8_f64
dpa_batch1_seq2
dpa_batch1_seq2_f64
dpa_batch8_seq4
dpa_batch8_seq4_f64
dpa_axis1
dpa_axis1_f64
dpa_with_tensor_mask
dpa_with_tensor_mask_f64
dpa_with_bias
dpa_with_bias_f64
dpa_with_custom_scale
dpa_with_custom_scale_f64
dpa_gqa_basic
dpa_gqa_basic_f64
dpa_return_residual_false
dpa_return_residual_false_f64
dpa_tiny_mask_all_valid
dpa_tiny_mask_all_valid_f64
dpa_tiny_mask_mixed
dpa_tiny_mask_mixed_f64
dpa_one_false
dpa_one_false_f64
dpa_mostly_false
dpa_mostly_false_f64
dpa_with_causal_mask
dpa_with_causal_mask_f64
dpa_with_padding_mask
dpa_with_padding_mask_f64
dpa_with_local_window_mask
dpa_with_local_window_mask_f64
dpa_vmap_tnh_issue190
dpa_mask_none
0.8.0
nn.elu Elu jaxnn_elu
jaxnn_elu_1
jaxnn_elu_default
jaxnn_elu_custom_alpha_dynamic
jaxnn_elu_custom_alpha
elu_grad_issue_batch_diff_rules
0.7.1
nn.gelu Gelu jaxnn_gelu
jaxnn_gelu_1
jaxnn_gelu_approx
jaxnn_gelu_exact
jaxnn_gelu_tanh_dynamic
jaxnn_gelu_tanh
gelu_grad_issue_batch_diff_rules
0.7.1
nn.glu Mul
Sigmoid
Split
jaxnn_glu_axis_last
jaxnn_glu_axis_last_f64
jaxnn_glu_axis1_dynamic_dynamic
jaxnn_glu_axis1_dynamic_dynamic_f64
jaxnn_glu_axis1_dynamic
jaxnn_glu_axis1_dynamic_f64
0.12.1
nn.hard_sigmoid HardSigmoid jaxnn_hard_sigmoid
jaxnn_hard_sigmoid_dynamic_dynamic
jaxnn_hard_sigmoid_dynamic
0.12.1
nn.hard_swish HardSwish jaxnn_hard_swish
jaxnn_hard_swish_dynamic_dynamic
jaxnn_hard_swish_dynamic
0.12.1
nn.hard_tanh Max
Min
jaxnn_hard_tanh
jaxnn_hard_tanh_f64
jaxnn_hard_tanh_dynamic_dynamic
jaxnn_hard_tanh_dynamic_dynamic_f64
jaxnn_hard_tanh_dynamic
jaxnn_hard_tanh_dynamic_f64
0.12.1
nn.hardmax Hardmax jaxnn_hardmax_default_axis
jaxnn_hardmax_axis0
jaxnn_hardmax_dynamic_dynamic
jaxnn_hardmax_dynamic
jaxnn_hardmax_vmap_batching
0.12.1
nn.identity Identity jaxnn_identity
jaxnn_identity_f64
jaxnn_identity_1
jaxnn_identity_1_f64
jaxnn_identity_basic
jaxnn_identity_basic_f64
jaxnn_identity_dynamic_dynamic
jaxnn_identity_dynamic_dynamic_f64
jaxnn_identity_dynamic
jaxnn_identity_dynamic_f64
0.7.1
nn.leaky_relu LeakyRelu jaxnn_leaky_relu
jaxnn_leaky_relu_1
jaxnn_leaky_relu_default_dynamic
jaxnn_leaky_relu_default
jaxnn_leaky_relu_custom
leaky_relu_grad_issue_batch_diff_rules
0.7.1
nn.log1mexp Exp
Less
Log
Neg
Sub
Where
jaxnn_log1mexp_basic
jaxnn_log1mexp_basic_f64
jaxnn_log1mexp_dynamic_dynamic
jaxnn_log1mexp_dynamic_dynamic_f64
jaxnn_log1mexp_dynamic
jaxnn_log1mexp_dynamic_f64
0.12.1
nn.log_sigmoid Log
Sigmoid
jaxnn_log_sigmoid
jaxnn_log_sigmoid_f64
jaxnn_log_sigmoid_dynamic_dynamic
jaxnn_log_sigmoid_dynamic_dynamic_f64
jaxnn_log_sigmoid_dynamic
jaxnn_log_sigmoid_dynamic_f64
0.12.1
nn.log_softmax LogSoftmax jaxnn_log_softmax_default_axis_dynamic
jaxnn_log_softmax_default_axis_dynamic_f64
jaxnn_log_softmax_default_axis
jaxnn_log_softmax_default_axis_f64
jaxnn_log_softmax_axis0
jaxnn_log_softmax_axis0_f64
jaxnn_log_softmax_axis_last_3d
jaxnn_log_softmax_axis_last_3d_f64
log_softmax_grad_issue_batch_diff_rules
0.12.1
nn.logmeanexp ReduceLogSumExp
Sub
jaxnn_logmeanexp_axis1
jaxnn_logmeanexp_axis1_f64
jaxnn_logmeanexp_axis12_keepdims
jaxnn_logmeanexp_axis12_keepdims_f64
jaxnn_logmeanexp_dynamic_static_reduction_axis_dynamic
jaxnn_logmeanexp_dynamic_static_reduction_axis_dynamic_f64
jaxnn_logmeanexp_dynamic_static_reduction_axis
jaxnn_logmeanexp_dynamic_static_reduction_axis_f64
0.12.1
nn.logsumexp ReduceLogSumExp jaxnn_logsumexp_axis1
jaxnn_logsumexp_axis1_f64
jaxnn_logsumexp_axis12_keepdims
jaxnn_logsumexp_axis12_keepdims_f64
jaxscipy_logsumexp_axis_last_dynamic
jaxscipy_logsumexp_axis_last_dynamic_f64
jaxscipy_logsumexp_axis_last
jaxscipy_logsumexp_axis_last_f64
logsumexp_grad_issue_batch_diff_rules
0.12.1
nn.mish Mish jaxnn_mish
jaxnn_mish_1
jaxnn_mish_basic
mish_grad_issue_batch_diff_rules
0.7.1
nn.one_hot OneHot jaxnn_one_hot_default
jaxnn_one_hot_axis0
0.12.1
nn.relu Relu jaxnn_relu
jaxnn_relu_f64
jaxnn_relu_1
jaxnn_relu_1_f64
jaxnn_relu_basic
jaxnn_relu_basic_f64
jaxnn_relu_dynamic_dynamic
jaxnn_relu_dynamic_dynamic_f64
jaxnn_relu_dynamic
jaxnn_relu_dynamic_f64
relu_grad_issue_batch_diff_rules
0.7.1
nn.relu6 Max
Min
jaxnn_relu6
jaxnn_relu6_f64
jaxnn_relu6_dynamic_dynamic
jaxnn_relu6_dynamic_dynamic_f64
jaxnn_relu6_dynamic
jaxnn_relu6_dynamic_f64
0.12.1
nn.scaled_dot_general Einsum
Gemm
MatMul
jaxnn_scaled_dot_general_basic
jaxnn_scaled_dot_general_batched
jaxnn_scaled_dot_general_batched_f64
0.12.1
nn.scaled_matmul MatMul
Mul
Reshape
Tile
Transpose
Unsqueeze
jaxnn_scaled_matmul_basic 0.12.1
nn.selu Selu jaxnn_selu
jaxnn_selu_1
jaxnn_selu_basic_dynamic
jaxnn_selu_basic
selu_grad_issue_batch_diff_rules
0.7.1
nn.sigmoid Sigmoid jaxnn_sigmoid
jaxnn_sigmoid_f64
jaxnn_sigmoid_1
jaxnn_sigmoid_1_f64
sigmoid_grad_issue_batch_diff_rules
0.7.1
nn.silu Mul
Sigmoid
Swish
jaxnn_silu
jaxnn_silu_f64
jaxnn_silu_dynamic_dynamic
jaxnn_silu_dynamic_dynamic_f64
jaxnn_silu_dynamic
jaxnn_silu_dynamic_f64
jaxnn_swish_alias
jaxnn_swish_alias_f64
silu_grad_issue_batch_diff_rules
0.12.1
nn.soft_sign Softsign jaxnn_soft_sign
jaxnn_soft_sign_1
jaxnn_softsign_basic
softsign_grad_issue_batch_diff_rules
0.7.1
nn.softmax Softmax softmax
softmax_f64
softmax_2d
softmax_2d_f64
softmax_3d
softmax_3d_f64
softmax_mask_where
softmax_mask_where_f64
softmax_grad_issue_batch_diff_rules
0.7.1
nn.softplus Softplus jaxnn_softplus
jaxnn_softplus_1
jaxnn_softplus_basic
softplus_grad_issue_batch_diff_rules
0.7.1
nn.sparse_plus Add
Greater
Less
Mul
Where
jaxnn_sparse_plus
jaxnn_sparse_plus_f64
jaxnn_sparse_plus_dynamic_dynamic
jaxnn_sparse_plus_dynamic_dynamic_f64
jaxnn_sparse_plus_dynamic
jaxnn_sparse_plus_dynamic_f64
0.12.1
nn.sparse_sigmoid Add
Greater
Less
Mul
Where
jaxnn_sparse_sigmoid
jaxnn_sparse_sigmoid_f64
jaxnn_sparse_sigmoid_dynamic_dynamic
jaxnn_sparse_sigmoid_dynamic_dynamic_f64
jaxnn_sparse_sigmoid_dynamic
jaxnn_sparse_sigmoid_dynamic_f64
0.12.1
nn.squareplus Add
Mul
Sqrt
jaxnn_squareplus_default_b
jaxnn_squareplus_default_b_f64
jaxnn_squareplus_broadcast_b
jaxnn_squareplus_broadcast_b_f64
0.12.1
nn.standardize MeanVarianceNormalization jaxnn_standardize_axis_last_eps0
jaxnn_standardize_axis_tuple_eps0_dynamic
jaxnn_standardize_axis_tuple_eps0
standardize_grad_issue_batch_diff_rules
0.12.1
nn.tanh Tanh jaxnn_tanh
jaxnn_tanh_f64
jaxnn_tanh_dynamic_dynamic
jaxnn_tanh_dynamic_dynamic_f64
jaxnn_tanh_dynamic
jaxnn_tanh_dynamic_f64
0.12.1
nn.thresholded_relu ThresholdedRelu jaxnn_thresholded_relu_default
jaxnn_thresholded_relu_custom_dynamic
jaxnn_thresholded_relu_custom
0.12.1
nn.truncated_normal initializer
random_truncated_normal_positional
flax_dense_like_init
0.7.1
nnx.avg_pool AveragePool
GlobalAveragePool
Transpose
avg_pool_dynamic
avg_pool
avg_pool_same_padding_dynamic
avg_pool_same_padding
avg_pool_default_padding_dynamic
avg_pool_default_padding
avg_pool_stride1_dynamic
avg_pool_stride1
avg_pool_win3x3_stride2_dynamic
avg_pool_win3x3_stride2
avg_pool_stride_none_dynamic
avg_pool_stride_none
avg_pool_count_include_pad_false_dynamic
avg_pool_count_include_pad_false
avg_pool_global_window_dynamic
avg_pool_global_window
0.1.0
nnx.batch_norm BatchNormalization batch_norm_no_bias_no_scale_dynamic
batch_norm_no_bias_no_scale
batch_norm_bias_no_scale_dynamic
batch_norm_bias_no_scale
batch_norm_no_bias_scale_dynamic
batch_norm_no_bias_scale
batch_norm_bias_scale_dynamic
batch_norm_bias_scale
batch_norm_3d_dynamic
batch_norm_3d
batch_norm_4d_dynamic
batch_norm_4d
batch_norm_4d_no_bias_no_scale_dynamic
batch_norm_4d_no_bias_no_scale
0.1.0
nnx.combine_masks And
Cast
combine_masks_two
combine_masks_with_none
0.12.2
nnx.conv CastLike
Conv
Reshape
Transpose
conv_basic_bias_dynamic
conv_basic_bias
conv_basic_bias_2
conv_basic_bias_3
conv_stride2_bias
conv_no_bias_dynamic
conv_no_bias
conv_valid_padding
conv_stride1
conv_stride2
conv_2d_reflect_padding
conv_2d_circular_padding
conv_different_kernel
conv_float64
conv_single_batch
conv_large_batch
conv_1d_causal_padding
conv_1d
conv_1d_more_1d_inputs
conv_1d_more_2d_inputs
conv_1d_large_kernel
conv_1d_dilation
conv_1d_stride_dilation
conv_2d_asymmetric_kernel
conv_2d_asymmetric_stride
conv_2d_asymmetric_dilation
conv_2d_large_dilation
conv_2d_large_stride
conv_2d_mixed_params
conv_2d_same_padding_mixed_dilation
conv_3d_basic
conv_3d_stride
conv_3d_asymmetric
conv_3d_dilation
conv_2d_small_input
conv_2d_many_channels
conv_1d_wide_input
conv_2d_kernel_1x1
conv_1d_kernel_1
conv_2d_group_conv
conv_1d_group_conv_more_dims
conv_2d_depthwise
conv_1d_complex_on_4d
conv_2d_complex_on_5d
conv_2d_asymmetric_on_5d
conv_1d_high_dilation_on_3d
conv_1d_large_kernel_on_4d
conv_2d_group_stride_dilation
conv_1d_group_on_higher_dim
conv_1d_same_padding_on_3d
conv_3d_group_complex
conv_1d_unit_group_on_multi_dim
0.1.0
nnx.dot_product_attention Add
Attention
MatMul
Mul
Softmax
Transpose
Where
dpa_basic
dpa_basic_f64
dpa_with_tensor_mask
dpa_with_bias
dpa_with_causal_mask
dpa_with_causal_mask_f64
dpa_with_mask_and_bias
dpa_gqa_basic
dpa_gqa_basic_opset23
0.1.0
nnx.dropout Constant
Dropout
dropout_init_params_dynamic
dropout_init_params_dynamic_f64
dropout_init_params
dropout_init_params_f64
dropout_call_params_dynamic
dropout_call_params_dynamic_f64
dropout_call_params
dropout_call_params_f64
0.1.0
nnx.einsum Add
Einsum
einsum_module_with_bias
einsum_module_with_bias_f64
einsum_module_no_bias
einsum_module_no_bias_f64
0.4.2
nnx.elu Elu elu
elu_default_dynamic
elu_default
elu_alpha
0.2.0
nnx.embed Gather token_embedding_dynamic
token_embedding_dynamic_f64
token_embedding
token_embedding_f64
positional_embedding_dynamic
positional_embedding_dynamic_f64
positional_embedding
positional_embedding_f64
0.7.0
nnx.flip_sequences GatherElements
Slice
flip_sequences_batch_major_with_lengths
flip_sequences_batch_major_with_lengths_f64
flip_sequences_batch_major_no_lengths_dynamic
flip_sequences_batch_major_no_lengths_dynamic_f64
flip_sequences_batch_major_no_lengths
flip_sequences_batch_major_no_lengths_f64
flip_sequences_time_major_with_lengths
flip_sequences_time_major_with_lengths_f64
0.12.2
nnx.gelu Gelu gelu
gelu_1
gelu_2
gelu_2_f64
gelu_3_dynamic
gelu_3_dynamic_f64
gelu_3
gelu_3_f64
gelu_4
gelu_4_f64
gelu_5_dynamic
gelu_5_dynamic_f64
gelu_5
gelu_5_f64
0.1.0
nnx.glu Mul
Sigmoid
Split
glu_last_axis
glu_last_axis_f64
glu_axis_1_dynamic
glu_axis_1_dynamic_f64
glu_axis_1
glu_axis_1_f64
0.12.2
nnx.group_norm GroupNormalization group_norm
group_norm_rank2_dynamic
group_norm_rank2
group_norm_rank4
group_norm_no_bias_dynamic
group_norm_no_bias
group_norm_no_bias_no_scale_dynamic
group_norm_no_bias_no_scale
group_norm_bias_no_scale_dynamic
group_norm_bias_no_scale
group_norm_no_scale_dynamic
group_norm_no_scale
group_norm_no_bias_scale_dynamic
group_norm_no_bias_scale
group_norm_bias_scale_dynamic
group_norm_bias_scale
0.2.0
nnx.hard_tanh Clip hard_tanh_basic_dynamic
hard_tanh_basic_dynamic_f64
hard_tanh_basic
hard_tanh_basic_f64
0.12.2
nnx.layer_norm LayerNormalization layer_norm_dynamic
layer_norm
layer_norm_no_bias_no_scale_dynamic
layer_norm_no_bias_no_scale
layer_norm_bias_no_scale_dynamic
layer_norm_bias_no_scale
layer_norm_no_bias_scale_dynamic
layer_norm_no_bias_scale
layer_norm_bias_scale_dynamic
layer_norm_bias_scale
layer_norm_multiaxis_dynamic
layer_norm_multiaxis
layer_norm_symbolic_batch_dynamic
layer_norm_symbolic_batch
layer_norm_symbolic_batch_seq10_feat3_dynamic
layer_norm_symbolic_batch_seq10_feat3
layer_norm_symbolic_batch_seq10_feat3_2_dynamic
layer_norm_symbolic_batch_seq10_feat3_2
layer_norm_negative_axis_no_div_dynamic
layer_norm_negative_axis_no_div
0.1.0
nnx.leaky_relu LeakyRelu leaky_relu
leaky_relu_default_dynamic
leaky_relu_default
leaky_relu_custom
0.2.0
nnx.linear CastLike
Concat
Gemm
Reshape
Shape
Slice
linear_symbolic_batch_dynamic
linear_symbolic_batch_dynamic_f64
linear_symbolic_batch
linear_symbolic_batch_f64
linear_high_rank_dynamic
linear_high_rank_dynamic_f64
linear_high_rank_static
linear_high_rank_static_f64
linear_no_bias_dynamic
linear_no_bias_dynamic_f64
linear_no_bias
linear_no_bias_f64
linear_high_rank_no_bias_dynamic
linear_high_rank_no_bias_dynamic_f64
linear_high_rank_no_bias
linear_high_rank_no_bias_f64
linear_merge_symbolic_dim_dynamic
0.1.0
nnx.linear_general CastLike
Concat
Gemm
Reshape
Shape
Slice
linear_general_merge_symbolic_dim_dynamic
linear_general_dynamic
linear_general
linear_general_2
linear_general_3
linear_general_4
linear_general_abstract_eval_axes
linear_general_abstract_eval_axes_pair
dynamic_batch_and_feature_dims_dynamic
0.1.0
nnx.log_sigmoid Neg
Softplus
log_sigmoid_basic_dynamic
log_sigmoid_basic_dynamic_f64
log_sigmoid_basic
log_sigmoid_basic_f64
0.12.2
nnx.log_softmax LogSoftmax log_softmax
log_softmax_f64
log_softmax_default_axis_dynamic
log_softmax_default_axis_dynamic_f64
log_softmax_default_axis
log_softmax_default_axis_f64
log_softmax_axis0
log_softmax_axis0_f64
0.2.0
nnx.lora Add
MatMul
lora_basic_dynamic
lora_basic_dynamic_f64
lora_basic
lora_basic_f64
lora_static
lora_static_f64
0.12.2
nnx.lora_linear Add
MatMul
lora_linear_basic_dynamic
lora_linear_basic_dynamic_f64
lora_linear_basic
lora_linear_basic_f64
lora_linear_high_rank_dynamic
lora_linear_high_rank_dynamic_f64
lora_linear_high_rank
lora_linear_high_rank_f64
0.12.2
nnx.max_pool GlobalMaxPool
MaxPool
max_pool
max_pool_same_padding
max_pool_basic
max_pool_same_dynamic
max_pool_same
max_pool_global_window_dynamic
max_pool_global_window
0.2.0
nnx.prelu PRelu prelu_default
prelu_custom_slope_dynamic
prelu_custom_slope
0.12.1
nnx.relu Relu relu_1d
relu_1d_f64
relu_4d_dynamic
relu_4d_dynamic_f64
relu_4d
relu_4d_f64
0.2.0
nnx.rms_norm RMSNormalization rms_norm_basic
rms_norm_use_scale_false
rms_norm_4d_dynamic_dynamic
rms_norm_4d_dynamic
rms_norm_4d_dynamic_no_scale_dynamic
rms_norm_4d_dynamic_no_scale
0.2.0
nnx.sigmoid Sigmoid sigmoid_dynamic
sigmoid_dynamic_f64
sigmoid
sigmoid_f64
0.2.0
nnx.silu Mul
Sigmoid
Swish
silu_opset23
silu_opset23_f64
silu_opset24
silu_opset24_f64
swish_alias_opset24_dynamic
swish_alias_opset24_dynamic_f64
swish_alias_opset24
swish_alias_opset24_f64
0.14.0
nnx.softmax Softmax softmax_dynamic
softmax_dynamic_f64
softmax
softmax_f64
0.1.0
nnx.softplus Softplus softplus 0.1.0
nnx.tanh Tanh tanh
tanh_f64
0.1.0
random.bernoulli Bernoulli bernoulli_scalar_prob_shape
bernoulli_tensor_prob
0.12.1
random.random_bits Cast
Floor
RandomUniformLike
RandomUniform
random_bits_uint32
random_bits_uint32_f64
0.7.2
random.random_categorical Multinomial random_categorical_logits_batch
random_categorical_logits_batch_opset23
random_categorical_logits_rank3
random_categorical_logits_rank3_opset23
random_categorical_logits_symbolic_batch_dynamic
random_categorical_logits_symbolic_batch_opset23_dynamic
0.12.1
random.random_fold_in Identity random_fold_in_passthrough
random_fold_in_passthrough_f64
0.2.0
random.random_normal RandomNormalLike
RandomNormal
random_normal_f32_2x3 0.12.1
random.random_seed Cast
Concat
random_seed_basic
random_seed_basic_f64
0.2.0