Skip to content

Supported JAX/ONNX Components

JAX Component ONNX Components Testcases Since
core.custom_jvp_generic (Generic passthrough for custom JVP calls) custom_jvp_square
custom_jvp_square_f64
0.7.1
core.custom_vjp_generic (Generic passthrough for custom VJP calls) custom_vjp_square
custom_vjp_square_f64
0.7.1
core.dim_as_value Cast
Gather
Reshape
Shape
dim_as_value_dynamic
dim_as_value_dynamic_f64
dim_as_value
dim_as_value_f64
0.5.0
core.jit_inline jit_identity
jit_identity_f64
0.9.0
dm_pix.depth_to_space DepthToSpace depth_to_space_nhwc 0.12.0
dm_pix.space_to_depth SpaceToDepth space_to_depth_nhwc 0.12.1
eqx.adaptive_pool AveragePool
MaxPool
eqx_adaptive_avg_pool2d_divisible
eqx_adaptive_max_pool2d_divisible
eqx_adaptive_pool_generic_avg_vmap_dynamic
eqx_adaptive_pool_generic_avg_vmap
0.12.2
eqx.avg_pool AveragePool eqx_avg_pool2d_basic
eqx_avg_pool2d_batched_dynamic
eqx_avg_pool2d_batched
0.12.2
eqx.batch_norm BatchNormalization eqx_batch_norm_inference 0.12.2
eqx.conv Conv eqx_conv2d_nchw
eqx_conv2d_batched_nchw
0.10.0
eqx.dropout Dropout
Not
eqx_dropout_inference_mode
eqx_dropout_inference_mode_f64
eqx_dropout_training_mode
eqx_dropout_training_mode_f64
eqx_dropout_dynamic_inference
eqx_dropout_dynamic_inference_f64
eqx_dropout_batched_inference_dynamic
eqx_dropout_batched_inference_dynamic_f64
eqx_dropout_batched_inference
eqx_dropout_batched_inference_f64
0.8.0
eqx.embedding Gather eqx_embedding_scalar_index
eqx_embedding_scalar_index_f64
eqx_embedding_batched_vmap_dynamic
eqx_embedding_batched_vmap_dynamic_f64
eqx_embedding_batched_vmap
eqx_embedding_batched_vmap_f64
0.12.2
eqx.group_norm GroupNormalization eqx_group_norm_rank3
eqx_group_norm_no_affine
0.12.2
eqx.gru_cell Add
Gemm
Mul
Sigmoid
Tanh
eqx_gru_cell_basic 0.12.2
eqx.identity Identity eqx_identity_static
eqx_identity_static_f64
eqx_identity_symbolic_batch_dynamic
eqx_identity_symbolic_batch_dynamic_f64
eqx_identity_symbolic_batch
eqx_identity_symbolic_batch_f64
0.8.0
eqx.lambda Relu
Sigmoid
Tanh
eqx_lambda_relu
eqx_lambda_relu_f64
eqx_lambda_sigmoid
eqx_lambda_sigmoid_f64
0.12.2
eqx.layer_norm LayerNormalization layer_norm
layer_norm_f64
layer_norm_multiaxis
layer_norm_multiaxis_f64
batched_layer_norm_dynamic
batched_layer_norm_dynamic_f64
batched_layer_norm
batched_layer_norm_f64
layer_norm_no_bias_no_scale
layer_norm_no_bias_no_scale_f64
0.8.0
eqx.linear Gemm
Reshape
eqx_linear_symbolic_batch_dynamic
eqx_linear_symbolic_batch_dynamic_f64
eqx_linear_symbolic_batch
eqx_linear_symbolic_batch_f64
eqx_linear_no_bias_symbolic_batch_dynamic
eqx_linear_no_bias_symbolic_batch_dynamic_f64
eqx_linear_no_bias_symbolic_batch
eqx_linear_no_bias_symbolic_batch_f64
eqx_linear_no_bias_vector
eqx_linear_no_bias_vector_f64
eqx_linear_high_rank
eqx_linear_high_rank_f64
eqx_linear_vector
eqx_linear_vector_f64
0.8.0
eqx.lstm_cell Add
Gemm
Mul
Sigmoid
Tanh
eqx_lstm_cell_basic 0.12.2
eqx.max_pool MaxPool eqx_max_pool2d_basic
eqx_max_pool2d_basic_f64
eqx_max_pool2d_batched_dynamic
eqx_max_pool2d_batched_dynamic_f64
eqx_max_pool2d_batched
eqx_max_pool2d_batched_f64
0.12.2
eqx.multihead_attention Attention
Gemm
MatMul
Softmax
eqx_multihead_attention
eqx_multihead_attention_opset23
eqx_multihead_attention_core_dynamic
eqx_multihead_attention_core_opset23_dynamic
eqx_multihead_attention_core_static
eqx_multihead_attention_core_static_opset23
0.10.0
eqx.pool MaxPool eqx_pool_max2d_basic
eqx_pool_max2d_basic_f64
eqx_pool_max2d_batched_dynamic
eqx_pool_max2d_batched_dynamic_f64
eqx_pool_max2d_batched
eqx_pool_max2d_batched_f64
0.12.2
eqx.prelu PRelu eqx_prelu_default_dynamic
eqx_prelu_default
eqx_prelu_channelwise
0.12.2
eqx.rms_norm RMSNormalization rms_norm_vector
rms_norm_vector_f64
rms_norm_no_affine
rms_norm_no_affine_f64
0.10.2
eqx.rotary_positional_embedding Add
Concat
Mul
RotaryEmbedding
eqx_rotary_positional_embedding
eqx_rotary_positional_embedding_opset23
eqx_rotary_positional_embedding_heads
0.10.0
eqx.sequential Identity
Relu
Sigmoid
Tanh
eqx_sequential_relu_tanh_dynamic
eqx_sequential_relu_tanh_dynamic_f64
eqx_sequential_relu_tanh
eqx_sequential_relu_tanh_f64
eqx_sequential_identity_sigmoid
eqx_sequential_identity_sigmoid_f64
0.12.2
eqx.spectral_norm Div
MatMul
eqx_spectral_norm_linear_inference 0.12.2
eqx.weight_norm Div
Mul
eqx_weight_norm_linear_dynamic
eqx_weight_norm_linear
eqx_weight_norm_conv2d
0.12.2
jax_image.resize Resize
Upsample
resize_linear
resize_nearest
resize_nearest_antialias_ignored
resize_nearest_opset9_upsample
resize_linear_opset9_upsample
resize_nearest_rank3_opset9_upsample
0.10.0
jnp.abs Abs jnp_abs_basic
jnp_abs_basic_f64
jnp_abs_int
jnp_abs_int_f64
abs_vmap_batching
abs_vmap_batching_f64
abs_grad_issue_batch_diff_rules
0.12.2
jnp.acos Acos jnp_acos_basic
jnp_acos_int_promote
acos_vmap_batching
acos_grad_issue_batch_diff_rules
0.12.2
jnp.acosh Acosh jnp_acosh_basic
jnp_acosh_int_promote
acosh_vmap_batching
acosh_grad_issue_batch_diff_rules
0.12.2
jnp.add Add add
add_f64
jnp_add_vector
jnp_add_vector_f64
jnp_add_broadcast
jnp_add_broadcast_f64
add_vmap_batching
add_vmap_batching_f64
add_grad_issue_batch_diff_rules
0.8.0
jnp.all ReduceMin jnp_all_basic
jnp_all_basic_f64
jnp_all_axis_keepdims
jnp_all_axis_keepdims_f64
all_vmap_batching
all_vmap_batching_f64
0.12.2
jnp.amax ReduceMax jnp_amax_basic
jnp_amax_basic_f64
jnp_amax_axis
jnp_amax_axis_f64
jnp_amax_keepdims
jnp_amax_keepdims_f64
amax_vmap_batching
amax_vmap_batching_f64
0.12.2
jnp.amin ReduceMin jnp_amin_basic
jnp_amin_basic_f64
jnp_amin_axis
jnp_amin_axis_f64
jnp_amin_keepdims
jnp_amin_keepdims_f64
amin_vmap_batching
amin_vmap_batching_f64
0.12.2
jnp.any ReduceMax jnp_any_basic
jnp_any_basic_f64
jnp_any_axis_keepdims
jnp_any_axis_keepdims_f64
any_vmap_batching
any_vmap_batching_f64
0.12.2
jnp.arange Range arange_data_dependent_indices
arange_stop_only_concrete_input_val
arange_stop_only_concrete_input_val_f64
arange_start_stop_concrete_input_val
arange_start_stop_concrete_input_val_f64
arange_start_stop_step_concrete_input_val
arange_start_stop_step_concrete_input_val_f64
arange_float_concrete_input_val
arange_float_concrete_input_val_f64
arange_static_stop_only_int
arange_static_stop_only_int_f64
arange_static_stop_only_float
arange_static_stop_only_float_f64
arange_static_start_stop_int
arange_static_start_stop_int_f64
arange_static_start_stop_step_int
arange_static_start_stop_step_int_f64
arange_static_empty_result_pos_step
arange_static_empty_result_pos_step_f64
arange_static_empty_result_neg_step
arange_static_empty_result_neg_step_f64
arange_static_negative_step
arange_static_negative_step_f64
arange_static_float_step_explicit_dtype
arange_static_float_step_explicit_dtype_f64
arange_static_float_step_inferred_dtype
arange_static_float_step_inferred_dtype_f64
arange_static_stop_zero
arange_static_stop_zero_f64
arange_static_start_equals_stop
arange_static_start_equals_stop_f64
arange_static_large_numbers_int
arange_static_large_numbers_int_f64
0.5.2
jnp.argmax ArgMax jnp_argmax_axis1
jnp_argmax_axis0_bool
0.12.2
jnp.argmin ArgMin jnp_argmin_axis1
jnp_argmin_axis0
0.12.2
jnp.asin Asin jnp_asin_basic
jnp_asin_int_promote
asin_vmap_batching
asin_grad_issue_batch_diff_rules
0.12.2
jnp.asinh Asinh jnp_asinh_basic
jnp_asinh_int_promote
asinh_vmap_batching
asinh_grad_issue_batch_diff_rules
0.12.2
jnp.atan Atan jnp_atan_basic
jnp_atan_int_promote
atan_vmap_batching
atan_grad_issue_batch_diff_rules
0.12.2
jnp.atan2 Atan
Div
Where
jnp_atan2_quadrants
jnp_atan2_broadcast
atan2_vmap_batching
atan2_grad_issue_batch_diff_rules
0.12.2
jnp.atanh Atanh jnp_atanh_basic
jnp_atanh_int_promote
atanh_vmap_batching
atanh_grad_issue_batch_diff_rules
0.12.2
jnp.bartlett Constant jnp_bartlett_basic
jnp_bartlett_basic_f64
jnp_bartlett_empty
jnp_bartlett_empty_f64
0.12.1
jnp.bitwise_and And
BitwiseAnd
jnp_bitwise_and_int
jnp_bitwise_and_int_f64
jnp_bitwise_and_bool
jnp_bitwise_and_bool_f64
jnp_bitwise_and_mixed_dtype
jnp_bitwise_and_mixed_dtype_f64
jnp_bitwise_and_broadcast
jnp_bitwise_and_broadcast_f64
bitwise_and_vmap_batching
bitwise_and_vmap_batching_f64
0.12.2
jnp.bitwise_left_shift BitShift jnp_bitwise_left_shift_basic
jnp_bitwise_left_shift_basic_f64
jnp_bitwise_left_shift_bool_bool
jnp_bitwise_left_shift_bool_bool_f64
bitwise_left_shift_vmap_batching
bitwise_left_shift_vmap_batching_f64
0.12.2
jnp.bitwise_not BitwiseNot
Not
jnp_bitwise_not_bool
jnp_bitwise_not_bool_f64
jnp_bitwise_not_int
jnp_bitwise_not_int_f64
bitwise_not_vmap_batching
bitwise_not_vmap_batching_f64
0.12.2
jnp.bitwise_or BitwiseOr
Or
jnp_bitwise_or_int
jnp_bitwise_or_int_f64
jnp_bitwise_or_bool
jnp_bitwise_or_bool_f64
jnp_bitwise_or_mixed_dtype
jnp_bitwise_or_mixed_dtype_f64
jnp_bitwise_or_broadcast
jnp_bitwise_or_broadcast_f64
bitwise_or_vmap_batching
bitwise_or_vmap_batching_f64
0.12.2
jnp.bitwise_right_shift BitShift jnp_bitwise_right_shift_signed
jnp_bitwise_right_shift_signed_f64
jnp_bitwise_right_shift_bool_bool
jnp_bitwise_right_shift_bool_bool_f64
bitwise_right_shift_vmap_batching
bitwise_right_shift_vmap_batching_f64
0.12.2
jnp.bitwise_xor BitwiseXor
Xor
jnp_bitwise_xor_int
jnp_bitwise_xor_int_f64
jnp_bitwise_xor_bool
jnp_bitwise_xor_bool_f64
jnp_bitwise_xor_mixed_dtype
jnp_bitwise_xor_mixed_dtype_f64
jnp_bitwise_xor_broadcast
jnp_bitwise_xor_broadcast_f64
bitwise_xor_vmap_batching
bitwise_xor_vmap_batching_f64
0.12.2
jnp.blackman BlackmanWindow jnp_blackman_basic 0.12.1
jnp.block Concat jnp_block_basic 0.12.1
jnp.ceil Ceil
Identity
jnp_ceil_basic
jnp_ceil_basic_f64
jnp_ceil_int_identity
jnp_ceil_int_identity_f64
ceil_vmap_batching
ceil_vmap_batching_f64
ceil_grad_issue_batch_diff_rules
0.12.2
jnp.choose Concat
Gather
jnp_choose_wrap 0.12.1
jnp.clip Clip
Max
Min
clip_i32_scalar_bounds
clip_i32_scalar_bounds_f64
clip_f32_scalar_bounds_no_upcast_f64_mode
clip_only_upper
clip_only_upper_f64
clip_only_lower
clip_only_lower_f64
clip_broadcast_bounds
clip_grad_issue_batch_diff_rules
clip_keyword_min_alias
clip_keyword_min_alias_f64
0.8.0
jnp.compress Compress jnp_compress_axis0
jnp_compress_axis0_f64
jnp_compress_axis1
jnp_compress_axis1_f64
compress_jvp_issue_batch_diff_rules
0.12.1
jnp.concatenate Concat concatenate_basic
concatenate_basic_f64
concatenate_mixed_dtypes
concatenate_mixed_dtypes_f64
concatenate_with_explicit_dtype
concatenate_with_explicit_dtype_casts_inputs
concatenate_abstract_middle_dim_dynamic
concatenate_abstract_middle_dim_dynamic_f64
concatenate_abstract_middle_dim
concatenate_abstract_middle_dim_f64
concatenate_tile_and_symbolic_dynamic
concatenate_tile_and_symbolic_dynamic_f64
concatenate_tile_and_symbolic
concatenate_tile_and_symbolic_f64
concatenate_grad_issue191
0.8.0
jnp.conj Identity jnp_conj_real
jnp_conj_real_f64
jnp_conj_complex64
jnp_conj_complex64_f64
conj_vmap_batching
conj_vmap_batching_f64
conj_grad_issue_batch_diff_rules
0.10.1
jnp.copysign Abs
Less
Neg
Where
jnp_copysign_basic
jnp_copysign_basic_f64
jnp_copysign_int_promote
jnp_copysign_int_promote_f64
jnp_copysign_broadcast
jnp_copysign_broadcast_f64
copysign_vmap_batching
copysign_vmap_batching_f64
0.12.2
jnp.corrcoef Div
MatMul
ReduceMean
jnp_corrcoef_basic 0.12.1
jnp.cos Cos
Sin
jnp_cos_basic
jnp_cos_basic_f64
jnp_cos_int_promote
jnp_cos_int_promote_f64
jnp_cos_f64_via_sin
cos_vmap_batching
cos_vmap_batching_f64
cos_grad_issue_batch_diff_rules
0.12.2
jnp.cosh Cosh jnp_cosh_basic
jnp_cosh_basic_f64
jnp_cosh_int_promote
jnp_cosh_int_promote_f64
cosh_vmap_batching
cosh_vmap_batching_f64
cosh_grad_issue_batch_diff_rules
0.12.2
jnp.cov MatMul
ReduceMean
jnp_cov_basic
jnp_cov_basic_f64
0.12.1
jnp.cross Concat
Mul
Sub
jnp_cross_vectors
jnp_cross_vectors_f64
0.12.1
jnp.cumprod CumProd jnp_cumprod_axis1
jnp_cumprod_axis1_f64
jnp_cumprod_axis_none_flatten
jnp_cumprod_axis_none_flatten_f64
jnp_cumprod_dtype_cast
jnp_cumprod_dtype_cast_f64
0.12.1
jnp.cumsum CumSum jnp_cumsum_axis1
jnp_cumsum_axis1_f64
jnp_cumsum_reverse_dtype
cumsum_axis2_i32
cumsum_axis2_i32_f64
cumsum_axis2_reverse_i32
cumsum_axis2_reverse_i32_f64
cumsum_vmap_batching
cumsum_vmap_batching_f64
0.8.0
jnp.delete Concat
Slice
jnp_delete_axis0_index1
jnp_delete_axis0_index1_f64
0.12.1
jnp.diag Gather
Mul
Reshape
jnp_diag_from_vector
jnp_diag_from_vector_f64
jnp_diag_from_matrix
jnp_diag_from_matrix_f64
jnp_diag_from_matrix_k1
jnp_diag_from_matrix_k1_f64
0.12.2
jnp.diagonal Gather
Reshape
jnp_diagonal_basic
jnp_diagonal_basic_f64
jnp_diagonal_offset_neg1
jnp_diagonal_offset_neg1_f64
0.12.2
jnp.divide Div jnp_divide_basic
jnp_divide_basic_f64
jnp_divide_int_promote
jnp_divide_int_promote_f64
jnp_divide_broadcast
jnp_divide_broadcast_f64
divide_vmap_batching
divide_vmap_batching_f64
divide_grad_issue_batch_diff_rules
0.12.2
jnp.dot MatMul jnp_dot_vector_vector
jnp_dot_vector_vector_f64
jnp_dot_matrix_vector
jnp_dot_matrix_vector_f64
jnp_dot_matrix_matrix
jnp_dot_matrix_matrix_f64
0.12.2
jnp.einsum Einsum einsum_vector_dot
einsum_vector_dot_f64
einsum_matrix_vector
einsum_matrix_vector_f64
einsum_matrix_matrix_dynamic
einsum_matrix_matrix_dynamic_f64
einsum_matrix_matrix
einsum_matrix_matrix_f64
einsum_transpose
einsum_transpose_f64
einsum_batch_transpose_dynamic
einsum_batch_transpose_dynamic_f64
einsum_batch_transpose
einsum_batch_transpose_f64
einsum_diag
einsum_diag_f64
einsum_sum_reduce
einsum_sum_reduce_f64
einsum_multi_operand
einsum_multi_operand_f64
einsum_attention_logits_orig_dynamic
einsum_attention_logits_orig_dynamic_f64
einsum_attention_logits_orig
einsum_attention_logits_orig_f64
einsum_attention_output_orig_dynamic
einsum_attention_output_orig_dynamic_f64
einsum_attention_output_orig
einsum_attention_output_orig_f64
einsum_attention_logits_batched_dynamic
einsum_attention_logits_batched_dynamic_f64
einsum_attention_logits_batched
einsum_attention_logits_batched_f64
einsum_attention_output_batched_dynamic
einsum_attention_output_batched_dynamic_f64
einsum_attention_output_batched
einsum_attention_output_batched_f64
einsum_ellipsis_rank_mismatch
einsum_ellipsis_rank_mismatch_f64
einsum_attention_logits_batched_rank_mismatch
einsum_attention_logits_batched_rank_mismatch_f64
einsum_grad_issue_batch_diff_rules
0.1.0
jnp.equal Equal jnp_equal_basic
jnp_equal_basic_f64
jnp_equal_broadcast
jnp_equal_broadcast_f64
jnp_equal_mixed_dtype
jnp_equal_mixed_dtype_f64
0.12.2
jnp.exp Exp jnp_exp_basic
jnp_exp_basic_f64
jnp_exp_int_promote
jnp_exp_int_promote_f64
exp_vmap_batching
exp_vmap_batching_f64
exp_grad_issue_batch_diff_rules
0.12.2
jnp.exp2 Pow jnp_exp2_basic
jnp_exp2_basic_f64
jnp_exp2_int_promote
jnp_exp2_int_promote_f64
exp2_vmap_batching
exp2_vmap_batching_f64
exp2_grad_issue_batch_diff_rules
0.12.2
jnp.expm1 Exp
Sub
jnp_expm1_basic
jnp_expm1_basic_f64
jnp_expm1_int_promote
jnp_expm1_int_promote_f64
expm1_vmap_batching
expm1_vmap_batching_f64
expm1_grad_issue_batch_diff_rules
0.12.2
jnp.eye EyeLike jnp_eye_square
jnp_eye_rect_k1
0.12.1
jnp.fabs Abs jnp_fabs_basic
jnp_fabs_basic_f64
jnp_fabs_int_promote
jnp_fabs_int_promote_f64
fabs_vmap_batching
fabs_vmap_batching_f64
fabs_grad_issue_batch_diff_rules
0.12.2
jnp.fft DFT jnp_fft_complex64
jnp_fft_complex128
0.10.1
jnp.fft_fftfreq Div
Range
jnp_fft_fftfreq_n8
jnp_fft_fftfreq_n8_f64
0.12.1
jnp.fft_fftshift Concat
Slice
jnp_fft_fftshift_1d
jnp_fft_fftshift_1d_f64
0.12.1
jnp.fft_ifftshift Concat
Slice
jnp_fft_ifftshift_1d
jnp_fft_ifftshift_1d_f64
0.12.1
jnp.fft_irfft DFT jnp_irfft_complex64 0.12.1
jnp.fft_rfftfreq Div
Range
jnp_fft_rfftfreq_n8
jnp_fft_rfftfreq_n8_f64
0.12.1
jnp.fill_diagonal ScatterND jnp_fill_diagonal_basic
jnp_fill_diagonal_basic_f64
0.12.1
jnp.floor Floor
Identity
jnp_floor_basic
jnp_floor_basic_f64
jnp_floor_int_identity
jnp_floor_int_identity_f64
floor_vmap_batching
floor_vmap_batching_f64
floor_grad_issue_batch_diff_rules
0.12.2
jnp.floor_divide Div
Floor
Where
jnp_floor_divide_float
jnp_floor_divide_float_f64
jnp_floor_divide_int_neg
jnp_floor_divide_int_neg_f64
jnp_floor_divide_broadcast
jnp_floor_divide_broadcast_f64
floor_divide_vmap_batching
floor_divide_vmap_batching_f64
0.12.2
jnp.fmod Div
Mod
Sub
jnp_fmod_float
jnp_fmod_float_f64
jnp_fmod_int
jnp_fmod_int_f64
jnp_fmod_broadcast
jnp_fmod_broadcast_f64
fmod_vmap_batching
fmod_vmap_batching_f64
0.12.2
jnp.full ConstantOfShape jnp_full_const_scalar
jnp_full_const_scalar_f64
jnp_full_expand_scalar_input
jnp_full_expand_scalar_input_f64
jnp_full_expand_row_broadcast
jnp_full_expand_row_broadcast_f64
0.12.1
jnp.gcd Equal
Loop
jnp_gcd_basic
jnp_gcd_basic_f64
0.12.1
jnp.geomspace Exp
Log
Range
jnp_geomspace_scalar
jnp_geomspace_scalar_f64
0.12.1
jnp.gradient Div
Slice
Sub
jnp_gradient_1d
jnp_gradient_1d_f64
0.12.1
jnp.greater Greater jnp_greater_basic
jnp_greater_basic_f64
jnp_greater_broadcast
jnp_greater_broadcast_f64
jnp_greater_mixed_dtype
jnp_greater_mixed_dtype_f64
greater_vmap_batching
greater_vmap_batching_f64
0.12.2
jnp.greater_equal GreaterOrEqual jnp_greater_equal_basic
jnp_greater_equal_basic_f64
jnp_greater_equal_broadcast
jnp_greater_equal_broadcast_f64
jnp_greater_equal_mixed_dtype
jnp_greater_equal_mixed_dtype_f64
greater_equal_vmap_batching
greater_equal_vmap_batching_f64
0.12.2
jnp.hamming HammingWindow jnp_hamming_basic 0.12.1
jnp.hanning HannWindow jnp_hanning_basic 0.12.1
jnp.histogram_bin_edges Range
ReduceMax
ReduceMin
jnp_histogram_bin_edges_basic
jnp_histogram_bin_edges_basic_f64
0.12.1
jnp.i0 Add
Exp
Mul
jnp_i0_basic
jnp_i0_basic_f64
0.12.1
jnp.ifft DFT jnp_ifft_complex64 0.10.1
jnp.invert BitwiseNot
Not
jnp_invert_bool
jnp_invert_bool_f64
jnp_invert_int
jnp_invert_int_f64
invert_vmap_batching
invert_vmap_batching_f64
0.12.2
jnp.isfinite IsInf
IsNaN
Not
Or
jnp_isfinite_basic
jnp_isfinite_basic_f64
0.12.2
jnp.lcm Div
Loop
Mul
jnp_lcm_basic
jnp_lcm_basic_f64
0.12.1
jnp.left_shift BitShift jnp_left_shift_int
jnp_left_shift_int_f64
jnp_left_shift_unsigned
jnp_left_shift_unsigned_f64
jnp_left_shift_mixed_bool_rhs
jnp_left_shift_mixed_bool_rhs_f64
jnp_left_shift_broadcast
jnp_left_shift_broadcast_f64
left_shift_vmap_batching
left_shift_vmap_batching_f64
0.12.2
jnp.less Less jnp_less_basic
jnp_less_basic_f64
jnp_less_broadcast
jnp_less_broadcast_f64
jnp_less_mixed_dtype
jnp_less_mixed_dtype_f64
less_vmap_batching
less_vmap_batching_f64
0.12.2
jnp.less_equal LessOrEqual jnp_less_equal_basic
jnp_less_equal_basic_f64
jnp_less_equal_broadcast
jnp_less_equal_broadcast_f64
jnp_less_equal_mixed_dtype
jnp_less_equal_mixed_dtype_f64
less_equal_vmap_batching
less_equal_vmap_batching_f64
0.12.2
jnp.linalg_cholesky Cholesky jnp_linalg_cholesky_basic 0.12.1
jnp.linalg_cond Div
SVD
jnp_linalg_cond_basic 0.12.1
jnp.linalg_det Det linalg_det_2x2
linalg_det_batched
0.12.1
jnp.linalg_eig Eig jnp_linalg_eig_basic 0.12.1
jnp.linalg_eigh Eigh jnp_linalg_eigh_basic 0.12.1
jnp.linalg_eigvals Eig jnp_linalg_eigvals_basic 0.12.1
jnp.linalg_eigvalsh Eigh jnp_linalg_eigvalsh_basic 0.12.1
jnp.linalg_lstsq MatMul
SVD
jnp_linalg_lstsq_basic 0.12.1
jnp.linalg_matrix_norm Mul
ReduceSum
Sqrt
jnp_linalg_matrix_norm_basic 0.12.1
jnp.linalg_matrix_power MatMul jnp_linalg_matrix_power_basic 0.12.1
jnp.linalg_matrix_rank ReduceSum
SVD
jnp_linalg_matrix_rank_basic 0.12.1
jnp.linalg_multi_dot MatMul jnp_linalg_multi_dot_basic 0.12.1
jnp.linalg_norm GlobalLpPool
Transpose
linalg_norm_global_fro
linalg_norm_global_default
linalg_norm_global_ord1
linalg_norm_global_ord1_nokeepdims
0.12.1
jnp.linalg_pinv MatMul
SVD
jnp_linalg_pinv_basic 0.12.1
jnp.linalg_qr QR jnp_linalg_qr_basic 0.12.1
jnp.linalg_svd SVD jnp_linalg_svd_basic 0.12.1
jnp.linalg_svdvals SVD jnp_linalg_svdvals_basic 0.12.1
jnp.linalg_tensordot MatMul
Transpose
jnp_linalg_tensordot_basic 0.12.1
jnp.linalg_trace Gather
ReduceSum
jnp_linalg_trace_basic 0.12.1
jnp.linalg_vecdot Mul
ReduceSum
jnp_linalg_vecdot_basic 0.12.1
jnp.linalg_vector_norm Mul
ReduceSum
Sqrt
jnp_linalg_vector_norm_basic 0.12.1
jnp.linspace Add
Range
linspace_static_basic
linspace_static_basic_f64
linspace_static_endpoint_false
linspace_static_endpoint_false_f64
linspace_static_int_inputs_default_dtype
linspace_static_int_inputs_default_dtype_f64
linspace_basic_f32
linspace_basic_f32_f64
linspace_endpoint_false_i32
linspace_endpoint_false_i32_f64
linspace_num_zero
linspace_num_zero_f64
linspace_num_one
linspace_num_one_f64
linspace_static_num_0
linspace_static_num_0_f64
linspace_static_num_1
linspace_static_num_1_f64
linspace_vmap_batching
linspace_vmap_batching_f64
0.5.2
jnp.log Log
ReduceLogSumExp
ReduceLogSum
jnp_log_basic
jnp_log_basic_f64
jnp_log_int_promote
jnp_log_int_promote_f64
log_vmap_batching
log_vmap_batching_f64
log_grad_issue_batch_diff_rules
0.12.2
jnp.logspace Pow
Range
jnp_logspace_basic 0.12.1
jnp.matmul MatMul matmul_1d
matmul_1d_f64
matmul_1d_2d
matmul_1d_2d_f64
matmul_2d
matmul_2d_f64
matmul_2d_1d
matmul_2d_1d_f64
matmul_3d
matmul_3d_f64
matmul_dynamic_dynamic
matmul_dynamic_dynamic_f64
matmul_dynamic
matmul_dynamic_f64
matmul_dynamic_a_dynamic
matmul_dynamic_a_dynamic_f64
matmul_dynamic_a
matmul_dynamic_a_f64
matmul_complex64
matmul_complex64_f64
matmul_vmap_batching
matmul_vmap_batching_f64
matmul_grad_issue_batch_diff_rules
0.1.0
jnp.matvec MatMul jnp_matvec_basic 0.12.1
jnp.max ReduceMax jnp_max_basic
jnp_max_basic_f64
jnp_max_axis
jnp_max_axis_f64
jnp_max_keepdims
jnp_max_keepdims_f64
max_vmap_batching
max_vmap_batching_f64
0.12.2
jnp.maximum Max jnp_maximum_basic
jnp_maximum_basic_f64
jnp_maximum_broadcast_scalar
jnp_maximum_broadcast_scalar_f64
0.12.2
jnp.mean ReduceMean basic_mean
basic_mean_f64
mean_with_axis
mean_with_axis_f64
mean_with_keepdims
mean_with_keepdims_f64
jnp_mean_basic
jnp_mean_basic_f64
jnp_mean_axis
jnp_mean_axis_f64
mean_grad_issue_batch_diff_rules
0.12.0
jnp.median ReduceMean
TopK
jnp_median_basic 0.12.1
jnp.min ReduceMin jnp_min_basic
jnp_min_basic_f64
jnp_min_axis
jnp_min_axis_f64
jnp_min_keepdims
jnp_min_keepdims_f64
min_vmap_batching
min_vmap_batching_f64
0.12.2
jnp.minimum Min jnp_minimum_basic
jnp_minimum_basic_f64
jnp_minimum_broadcast_scalar
jnp_minimum_broadcast_scalar_f64
0.12.2
jnp.moveaxis Transpose jnp_moveaxis_2d
jnp_moveaxis_2d_f64
jnp_moveaxis_3d_tuple
jnp_moveaxis_3d_tuple_f64
jnp_moveaxis_vmap_batching
jnp_moveaxis_vmap_batching_f64
0.12.2
jnp.nanargmax ArgMax
IsNaN
Where
jnp_nanargmax_basic 0.12.1
jnp.nanargmin ArgMin
IsNaN
Where
jnp_nanargmin_basic 0.12.1
jnp.nancumsum CumSum
IsNaN
Where
jnp_nancumsum_basic 0.12.1
jnp.nanmax IsNaN
ReduceMax
Where
jnp_nanmax_basic 0.12.1
jnp.nanmean IsNaN
ReduceSum
Where
jnp_nanmean_basic 0.12.1
jnp.nanmedian TopK
Where
jnp_nanmedian_basic 0.12.1
jnp.nanmin IsNaN
ReduceMin
Where
jnp_nanmin_basic 0.12.1
jnp.nanpercentile TopK
Where
jnp_nanpercentile_basic 0.12.1
jnp.nanprod IsNaN
ReduceProd
Where
jnp_nanprod_basic 0.12.1
jnp.nanquantile TopK
Where
jnp_nanquantile_basic 0.12.1
jnp.nanstd ReduceSum
Sqrt
jnp_nanstd_basic 0.12.1
jnp.nansum IsNaN
ReduceSum
Where
jnp_nansum_basic 0.12.1
jnp.nanvar Mul
ReduceSum
jnp_nanvar_basic 0.12.1
jnp.ndarray_at ScatterND jnp_ndarray_at_set_basic 0.12.1
jnp.ones ConstantOfShape jnp_ones_2x3
jnp_ones_2x3_f64
jnp_ones_bool
jnp_ones_bool_f64
0.12.2
jnp.outer Mul outer_vector
outer_vector_f64
outer
outer_f64
outer_vmap_batching
outer_vmap_batching_f64
outer_grad_issue_batch_diff_rules
0.10.0
jnp.pad Pad jnp_pad_constant_1d
jnp_pad_constant_1d_f64
jnp_pad_constant_2d_tuple
jnp_pad_constant_2d_tuple_f64
jnp_pad_constant_vmap_batching
jnp_pad_constant_vmap_batching_f64
0.12.2
jnp.partition Concat
TopK
jnp_partition_basic 0.12.1
jnp.percentile ReduceMean
TopK
jnp_percentile_basic 0.12.1
jnp.polyadd Add
Pad
jnp_polyadd_basic 0.12.1
jnp.polyder Mul
Slice
jnp_polyder_basic 0.12.1
jnp.polydiv Div
Sub
jnp_polydiv_basic 0.12.1
jnp.polyint Concat
Div
jnp_polyint_basic 0.12.1
jnp.polysub Pad
Sub
jnp_polysub_basic 0.12.1
jnp.pow Pow jnp_pow_vector
jnp_pow_vector_f64
pow_jnp_pow
pow_jnp_pow_f64
pow_vmap_batching
pow_vmap_batching_f64
pow_grad_issue_batch_diff_rules
0.8.0
jnp.power Pow jnp_power_vector
jnp_power_vector_f64
pow_jnp_power
pow_jnp_power_f64
power_vmap_batching
power_vmap_batching_f64
power_grad_issue_batch_diff_rules
0.8.0
jnp.prod ReduceProd basic_prod
basic_prod_f64
prod_with_axis
prod_with_axis_f64
prod_with_keepdims
prod_with_keepdims_f64
jnp_prod_basic
jnp_prod_basic_f64
jnp_prod_axis
jnp_prod_axis_f64
jnp_prod_keepdims
jnp_prod_keepdims_f64
prod_grad_issue_batch_diff_rules
prod_vmap_batching
prod_vmap_batching_f64
0.8.0
jnp.put ScatterND jnp_put_basic 0.12.1
jnp.put_along_axis ScatterND jnp_put_along_axis_basic 0.12.1
jnp.quantile ReduceMean
TopK
jnp_quantile_basic 0.12.1
jnp.ravel_multi_index Add
Mod
Mul
jnp_ravel_multi_index_wrap 0.12.1
jnp.reshape Flatten
Reshape
reshape_1
reshape_1_f64
reshape_2
reshape_2_f64
reshape_flatten_static_uses_flatten
reshape_flatten_static_uses_flatten_f64
reshape_3
reshape_3_f64
reshape_4_dynamic
reshape_4_dynamic_f64
reshape_4
reshape_4_f64
reshape_to_scalar
reshape_to_scalar_f64
reshape_from_scalar
reshape_from_scalar_f64
reshape_cnn_dynamic
reshape_cnn_dynamic_f64
reshape_cnn
reshape_cnn_f64
reshape_valid_flatten_trailing
reshape_valid_flatten_trailing_f64
reshape_with_target_shape_from_symbolic_dim_computation
reshape_with_target_shape_from_symbolic_dim_computation_f64
reshape_basic
reshape_basic_f64
reshape_infer
reshape_infer_f64
reshape_symbolic_flatten_dynamic
reshape_symbolic_flatten_dynamic_f64
reshape_symbolic_flatten
reshape_symbolic_flatten_f64
reshape_vmap_batching_issue_144
reshape_vmap_batching_issue_144_f64
reshape_grad_issue_batch_diff_rules
0.1.0
jnp.rfft DFT jnp_rfft_float32 0.10.1
jnp.right_shift BitShift jnp_right_shift_signed_arithmetic
jnp_right_shift_signed_arithmetic_f64
jnp_right_shift_unsigned_logical
jnp_right_shift_unsigned_logical_f64
jnp_right_shift_mixed_bool_rhs
jnp_right_shift_mixed_bool_rhs_f64
jnp_right_shift_broadcast
jnp_right_shift_broadcast_f64
right_shift_vmap_batching
right_shift_vmap_batching_f64
0.12.2
jnp.roll Concat
Slice
jnp_roll_basic 0.12.1
jnp.rot90 ReverseSequence
Transpose
jnp_rot90_basic 0.12.1
jnp.select Where select_simple
select_simple_f64
select_broadcast
select_broadcast_f64
select_gpt2_attention_mask_dynamic
select_gpt2_attention_mask_dynamic_f64
select_gpt2_attention_mask
select_gpt2_attention_mask_f64
select_basic
select_basic_f64
select_vmap_batching
select_vmap_batching_f64
select_grad_issue_batch_diff_rules
0.7.1
jnp.shape Shape shape_basic
shape_basic_f64
shape_dynamic_dynamic
shape_dynamic_dynamic_f64
shape_dynamic
shape_dynamic_f64
shape_vmap_batching
shape_vmap_batching_f64
0.4.0
jnp.sign Sign jnp_sign_basic
jnp_sign_basic_f64
jnp_sign_int
jnp_sign_int_f64
sign_vmap_batching
sign_vmap_batching_f64
sign_grad_issue_batch_diff_rules
0.12.2
jnp.sin Sin jnp_sin_basic
jnp_sin_basic_f64
jnp_sin_int_promote
jnp_sin_int_promote_f64
sin_vmap_batching
sin_vmap_batching_f64
sin_grad_issue_batch_diff_rules
0.12.2
jnp.sinc Div
Sin
jnp_sinc_basic 0.12.1
jnp.sinh Sinh jnp_sinh_basic
jnp_sinh_basic_f64
jnp_sinh_int_promote
jnp_sinh_int_promote_f64
sinh_vmap_batching
sinh_vmap_batching_f64
sinh_grad_issue_batch_diff_rules
0.12.2
jnp.size Size jnp_size_all
jnp_size_all_f64
jnp_size_axis
jnp_size_axis_f64
jnp_size_axis_tuple
jnp_size_axis_tuple_f64
jnp_size_dynamic_dynamic
jnp_size_dynamic_dynamic_f64
jnp_size_dynamic
jnp_size_dynamic_f64
0.12.1
jnp.sort TopK sort_1d
sort_1d_f64
sort_2d_axis0
sort_2d_axis0_f64
sort_basic
sort_basic_f64
sort_vmap_batching
sort_vmap_batching_f64
0.5.2
jnp.split Split split_by_sections
split_by_sections_f64
split_by_indices
split_by_indices_f64
split_by_indices_symbolic_dynamic
split_by_indices_symbolic_dynamic_f64
split_by_indices_symbolic
split_by_indices_symbolic_f64
split_sections
split_sections_f64
split_indices_numpy
split_indices_numpy_f64
split_grad_issue_batch_diff_rules
0.7.2
jnp.sqrt ReduceL2
Sqrt
jnp_sqrt_basic
jnp_sqrt_basic_f64
jnp_sqrt_int_promote
jnp_sqrt_int_promote_f64
jnp_sqrt_reduce_sum_square_axis1
jnp_sqrt_reduce_sum_square_axis1_f64
sqrt_vmap_batching
sqrt_vmap_batching_f64
sqrt_grad_issue_batch_diff_rules
0.12.2
jnp.squeeze Squeeze squeeze_single_dim
squeeze_single_dim_f64
squeeze_multiple_dims
squeeze_multiple_dims_f64
squeeze_vit_output
squeeze_vit_output_f64
squeeze_dynamic_batch_dynamic
squeeze_dynamic_batch_dynamic_f64
squeeze_dynamic_batch
squeeze_dynamic_batch_f64
squeeze_all_dims
squeeze_all_dims_f64
squeeze_negative_axis
squeeze_negative_axis_f64
squeeze_negative_axis_tuple
squeeze_negative_axis_tuple_f64
squeeze_dynamic_and_negative_axis_dynamic
squeeze_dynamic_and_negative_axis_dynamic_f64
squeeze_dynamic_and_negative_axis
squeeze_dynamic_and_negative_axis_f64
squeeze_vmap_batching
squeeze_vmap_batching_f64
squeeze_grad_issue_batch_diff_rules
0.1.0
jnp.stack Concat
Unsqueeze
stack_axis_0
stack_axis_0_f64
stack_axis_1
stack_axis_1_f64
stack_negative_axis
stack_negative_axis_f64
stack_scalars
stack_scalars_f64
jnp_stack_axis0
jnp_stack_axis0_f64
jnp_stack_axis1
jnp_stack_axis1_f64
jnp_stack_negative_axis
jnp_stack_negative_axis_f64
jnp_stack_scalars
jnp_stack_scalars_f64
stack_grad_issue_batch_diff_rules
0.8.0
jnp.std ReduceMean
Sqrt
jnp_std_basic 0.12.1
jnp.sum ReduceSum jnp_sum_basic
jnp_sum_basic_f64
jnp_sum_axis
jnp_sum_axis_f64
jnp_sum_keepdims
jnp_sum_keepdims_f64
jnp_sum_int8_promote
sum_vmap_batching
sum_vmap_batching_f64
sum_grad_issue_batch_diff_rules
0.12.2
jnp.take Gather take_data_dependent_indices
take_basic_axis1
take_basic_axis1_f64
take_vmap_batching
take_vmap_batching_f64
take_grad_issue_batch_diff_rules
0.7.0
jnp.tan Tan jnp_tan_basic
jnp_tan_basic_f64
jnp_tan_int_promote
jnp_tan_int_promote_f64
tan_vmap_batching
tan_vmap_batching_f64
tan_grad_issue_batch_diff_rules
0.12.2
jnp.tanh Tanh jnp_tanh_basic
jnp_tanh_basic_f64
jnp_tanh_int_promote
jnp_tanh_int_promote_f64
tanh_vmap_batching
tanh_vmap_batching_f64
tanh_grad_issue_batch_diff_rules
0.12.2
jnp.tensordot MatMul
Transpose
jnp_tensordot_basic 0.12.1
jnp.tile Tile tile_repeats
tile_repeats_f64
tile_a
tile_a_f64
tile_b
tile_b_f64
tile_c
tile_c_f64
tile_d
tile_d_f64
tile_dynamic_input_static
tile_dynamic_input_static_f64
tile_dynamic_input_dynamic
tile_dynamic_input_dynamic_f64
tile_dynamic_input
tile_dynamic_input_f64
tile_pad
tile_pad_f64
tile_param_symbolic_dynamic
tile_param_symbolic_dynamic_f64
tile_param_symbolic
tile_param_symbolic_f64
tile_with_symbolic_repeats_static
tile_with_symbolic_repeats_static_f64
tile_with_symbolic_repeats_dynamic
tile_with_symbolic_repeats_dynamic_f64
tile_with_symbolic_repeats
tile_with_symbolic_repeats_f64
jnp_tile_basic
jnp_tile_basic_f64
jnp_tile_scalar_repeats
jnp_tile_scalar_repeats_f64
jnp_tile_pad_rank
jnp_tile_pad_rank_f64
jnp_tile_symbolic_dynamic
jnp_tile_symbolic_dynamic_f64
jnp_tile_symbolic
jnp_tile_symbolic_f64
tile_vmap_batching
tile_vmap_batching_f64
tile_grad_issue_batch_diff_rules
0.8.0
jnp.trace Gather
ReduceSum
jnp_trace_basic 0.12.1
jnp.transpose Transpose transpose_basic
transpose_basic_f64
transpose_reverse_default
transpose_reverse_default_f64
transpose_high_dim
transpose_high_dim_f64
transpose_3d
transpose_3d_f64
transpose_4d
transpose_4d_f64
transpose_no_axes
transpose_no_axes_f64
transpose_reverse
transpose_reverse_f64
transpose_square_matrix
transpose_square_matrix_f64
transpose_vmap_batching
transpose_vmap_batching_f64
transpose_grad_issue_batch_diff_rules
0.1.0
jnp.trapezoid Add
ReduceSum
Slice
jnp_trapezoid_basic 0.12.1
jnp.tri LessOrEqual
Range
jnp_tri_basic 0.12.1
jnp.trilu Trilu jnp_triu_basic
jnp_triu_basic_f64
jnp_tril_negative_k
jnp_tril_negative_k_f64
jnp_triu_symbolic_batch_dynamic
jnp_triu_symbolic_batch_dynamic_f64
jnp_triu_symbolic_batch
jnp_triu_symbolic_batch_f64
jnp_trilu_vmap_batching
jnp_trilu_vmap_batching_f64
0.12.1
jnp.unique Unique jnp_unique_f32_size_fill
jnp_unique_i32_size_fill
jnp_unique_symbolic_size_fill_dynamic
jnp_unique_symbolic_size_fill
0.12.1
jnp.unravel_index Concat
Div
Mod
jnp_unravel_index_basic
jnp_unravel_index_basic_f64
0.12.1
jnp.unstack Split
Squeeze
unstack_axis_0
unstack_axis_0_f64
unstack_axis_1
unstack_axis_1_f64
unstack_negative_axis
unstack_negative_axis_f64
unstack_vmap_batching
unstack_vmap_batching_f64
0.7.1
jnp.var Mul
ReduceMean
jnp_var_basic 0.12.1
jnp.vdot Mul
ReduceSum
jnp_vdot_basic 0.12.1
jnp.vecdot Mul
ReduceSum
jnp_vecdot_basic 0.12.1
jnp.vecmat MatMul jnp_vecmat_basic 0.12.1
jnp.vectorize Add jnp_vectorize_basic 0.12.1
jnp.where Where where_simple
where_simple_f64
where_broadcast
where_broadcast_f64
where_gpt_mask_scores_literal_else_dynamic
where_gpt_mask_scores_literal_else_dynamic_f64
where_gpt_mask_scores_literal_else
where_gpt_mask_scores_literal_else_f64
where_multidim_condition_scalar_branches_broadcast
where_multidim_condition_scalar_branches_broadcast_f64
where_A
where_A_f64
where_B
where_B_f64
where_gpt_mask_scores_scalar_else_dynamic
where_gpt_mask_scores_scalar_else_dynamic_f64
where_gpt_mask_scores_scalar_else
where_gpt_mask_scores_scalar_else_f64
where_int_condition_cast
where_int_condition_cast_f64
where_literal_else_pyfloat
where_literal_else_pyfloat_f64
where_jax_int_literals_broadcast_f64_mode
where_dtype_mismatch_f64_vs_i32_promote
jnp_where_basic
jnp_where_basic_f64
jnp_where_broadcast
jnp_where_broadcast_f64
jnp_where_scalar_else
jnp_where_scalar_else_f64
where_vmap_batching
where_vmap_batching_f64
where_grad_issue_batch_diff_rules
0.8.0
jnp.zeros ConstantOfShape jnp_zeros_2x3
jnp_zeros_2x3_f64
jnp_zeros_int32
jnp_zeros_int32_f64
0.12.2
lax.abs Abs abs
abs_f64
0.5.0
lax.acos Acos acos 0.12.1
lax.acosh Acosh acosh 0.12.1
lax.add Add add
add_f64
add_const
add_const_f64
add_complex64
add_complex64_f64
0.2.0
lax.add_any Add
Sum
add_any_via_jvp_on_mul
add_any_via_jvp_on_mul_f64
0.8.0
lax.and And
BitwiseAnd
and_bool
and_bool_f64
and_int
and_int_f64
0.6.5
lax.approx_top_k TopK approx_max_k_matrix
approx_max_k_matrix_f64
approx_min_k_axis0
approx_min_k_axis0_f64
0.12.1
lax.argmax ArgMax argmax_float_axis0
argmax_float_axis0_f64
argmax_float_axis1
argmax_float_axis1_f64
argmax_boolean_input_axis0_specific_values
argmax_boolean_input_axis0_specific_values_f64
argmax_boolean_input_axis1_specific_values
argmax_boolean_input_axis1_specific_values_f64
argmax_boolean_random_input_axis0
argmax_boolean_random_input_axis0_f64
0.2.0
lax.argmin ArgMin argmin_test1
argmin_test1_f64
argmin_test2
argmin_test2_f64
0.2.0
lax.asin Asin asin 0.12.1
lax.asinh Asinh asinh 0.12.1
lax.atan Atan atan 0.12.1
lax.atan2 Add
Atan
Div
Equal
Greater
Less
Or
Sub
Where
atan2_quadrants_and_zero
atan2_quadrants_and_zero_f64
atan2_broadcast
0.12.1
lax.atanh Atanh atanh 0.12.1
lax.bessel_i0e Abs
Exp
Where
bessel_i0e_basic
bessel_i0e_basic_f64
0.12.1
lax.bessel_i1e Abs
Sign
Where
bessel_i1e_basic
bessel_i1e_basic_f64
0.12.1
lax.betainc Exp
Pow
Where
betainc_basic
betainc_edge_x
0.12.1
lax.bitcast_convert_type BitCast bitcast_scalar_f32_to_i32
bitcast_scalar_f32_to_i32_f64
bitcast_tensor_i32_to_f32
bitcast_tensor_i32_to_f32_f64
0.7.2
lax.bitwise_not BitwiseNot
Not
bitwise_not_bool
bitwise_not_bool_f64
bitwise_not_i32
bitwise_not_i32_f64
0.7.5
lax.broadcast_in_dim Expand
Identity
Reshape
broadcast_in_dim
broadcast_in_dim_f64
broadcast_in_dim_2d_to_3d
broadcast_in_dim_2d_to_3d_f64
broadcast_in_dim_scalar
broadcast_in_dim_scalar_f64
broadcast_in_dim_batch_dynamic
broadcast_in_dim_batch_dynamic_f64
broadcast_in_dim_batch
broadcast_in_dim_batch_f64
broadcast_in_dim_dynamic_B_dynamic
broadcast_in_dim_dynamic_B_dynamic_f64
broadcast_in_dim_dynamic_B
broadcast_in_dim_dynamic_B_f64
0.2.0
lax.cbrt Abs
Mul
Pow
Sign
cbrt
cbrt_f64
0.12.1
lax.ceil Ceil ceil
ceil_f64
0.12.1
lax.cholesky Concat
Div
Gather
Mul
ScatterND
Sqrt
Squeeze
Sub
Unsqueeze
cholesky_spd_3x3
cholesky_spd_3x3_f64
cholesky_diagonal
cholesky_diagonal_f64
cholesky_batched_2x3x3
cholesky_batched_2x3x3_f64
0.12.1
lax.cholesky_update Add
Div
Mul
ScatterND
Sqrt
Sub
cholesky_update_upper_3x3
cholesky_update_upper_3x3_f64
cholesky_update_identity
cholesky_update_identity_f64
0.12.1
lax.clamp Max
Min
clamp_i32_scalar_bounds
clamp_i32_scalar_bounds_f64
clamp_scalar_float_bounds_match_x
clamp_scalar_float_bounds_match_x_f64
clamp_vector_bounds_match
clamp_pyint_bounds_promote_to_x_dtype
clamp_pyint_bounds_promote_to_x_dtype_f64
0.7.5
lax.clz BitShift
BitwiseAnd
clz_i32
clz_i32_f64
clz_u8
clz_u8_f64
0.12.1
lax.complex Concat
Unsqueeze
complex_scalar
complex_scalar_f64
complex_array
complex_array_f64
complex_double_precision
0.12.2
lax.concatenate Cast
Concat
concatenate
concatenate_f64
concatenate_axis1_dynamic
concatenate_axis1_dynamic_f64
concatenate_axis1
concatenate_axis1_f64
concatenate_axis0
concatenate_axis0_f64
concatenate_3d
concatenate_3d_f64
concatenate_internal_int32_then_cast_to_f32_zeroarg
0.2.0
lax.cond If cond_scalar
cond_scalar_f64
cond_multiple_operands_in_tuple
cond_multiple_operands_in_tuple_f64
cond_my_new_complex_scenario
cond_my_new_complex_scenario_f64
cond_nested_conditional
cond_nested_conditional_f64
cond_variables
cond_variables_f64
cond_internal_constant_f64
cond_passthrough_identity
cond_passthrough_identity_f64
cond_with_scatter
cond_with_scatter_f64
0.5.1
lax.conj Identity conj_real
conj_real_f64
conj_complex64
conj_complex64_f64
0.10.1
lax.conv Conv conv
conv2
conv_nchw
conv_nhwc
conv_general_dilated_nhwc_output
conv_complex64
conv_complex64_nhwc
conv_complex128_grouped
0.2.0
lax.convert_element_type Cast convert_element_type
convert_element_type_f64
0.2.0
lax.copy (Handles the JAX primitive lax.copy_p. The public API jax.lax.copy is deprecated, but the primitive still appears in transformed jaxprs.) Identity copy_float32_array
copy_int64_scalar
lax.cos Cos cos
cos_f64
0.4.4
lax.cosh Cosh cosh
cosh_f64
0.4.4
lax.cumlogsumexp CumSum
Exp
Log
cumlogsumexp_axis1
cumlogsumexp_axis1_f64
cumlogsumexp_reverse_last_axis
cumlogsumexp_reverse_last_axis_f64
0.12.1
lax.cummax MaxPool
Reshape
Transpose
cummax_axis1
cummax_axis1_f64
cummax_reverse_last_axis
cummax_reverse_last_axis_f64
0.12.1
lax.cummin MaxPool
Neg
Reshape
Transpose
cummin_axis1
cummin_axis1_f64
cummin_reverse_last_axis
cummin_reverse_last_axis_f64
0.12.1
lax.cumprod CumProd cumprod_i32_axis2
cumprod_i32_axis2_f64
cumprod_f32_axism1_reverse
cumprod_f32_axism1_reverse_f64
0.12.1
lax.cumsum CumSum cumsum_i32_axis2
cumsum_i32_axis2_f64
cumsum_f32_axism1_reverse
cumsum_f32_axism1_reverse_f64
0.7.4
lax.custom_linear_solve custom_linear_solve_via_matvec
custom_linear_solve_via_matvec_f64
0.12.1
lax.device_put Identity device_put_array
device_put_array_f64
device_put_scalar
device_put_scalar_f64
0.4.0
lax.digamma Cos
Log
Sin
Where
digamma_positive
digamma_mixed
0.12.1
lax.div Div
LpNormalization
Mean
div
div_f64
div_const
div_const_f64
div_add_half_fuses_to_mean
div_add_half_fuses_to_mean_symbolic_dynamic
div_add_half_fuses_to_mean_symbolic
div_add_third_no_mean
div_add_half_f64_no_mean
div_lpnorm_l2_axis1
div_lpnorm_l1_axis1
div_lpnorm_l2_axis2
div_sqrt_of_norm_no_lpnorm_fusion
div_complex64
div_complex64_f64
0.2.0
lax.dot_general Einsum
MatMul/Gemm
dot_contract_nm
dot_contract_nm_f64
dot_contract_min
dot_contract_min_f64
dot_general
dot_general_f64
dot_general_lhs1_rhs1
dot_general_lhs1_rhs1_f64
dot_double_contract
dot_double_contract_f64
dot_batched_double_contract
dot_batched_double_contract_f64
dot_highrank_batch
dot_highrank_batch_f64
dot_contract_inner_lhs_with_middle_rhs
dot_contract_inner_lhs_with_middle_rhs_f64
dot_outer_product
dot_outer_product_f64
dot_full_contract_scalar
dot_full_contract_scalar_f64
dot_general_complex_matmul
dot_general_complex_matmul_f64
0.2.0
lax.dynamic_slice Slice dynamic_slice_test1
dynamic_slice_test1_f64
dynamic_slice_2d
dynamic_slice_2d_f64
dynamic_slice_3d
dynamic_slice_3d_f64
dynamic_slice_vit_like_dynamic
dynamic_slice_vit_like_dynamic_f64
dynamic_slice_vit_like
dynamic_slice_vit_like_f64
0.1.0
lax.dynamic_update_slice ScatterND
TensorScatter
dus_1d_scalar_update
dus_1d_scalar_update_f64
dus_1d_block_update
dus_1d_block_update_f64
dus_2d_block_update
dus_2d_block_update_f64
dus_3d_block_update
dus_3d_block_update_f64
dus_4d_block_update
dus_4d_block_update_f64
dus_tensorscatter_axis1_opset24
0.8.1
lax.eig Concat
Gather
Identity
Reshape
Unsqueeze
eig_1x1_values_only
eig_1x1_values_only_f64
eig_1x1_left_only
eig_1x1_left_only_f64
eig_1x1_right_only
eig_1x1_right_only_f64
eig_1x1_full
eig_1x1_full_f64
eig_2x2_values_only_real
eig_2x2_values_only_real_f64
eig_2x2_values_only_complex128
0.12.1
lax.eigh Add
Cast
Concat
Div
Gather
Identity
LessOrEqual
Mul
Reshape
Slice
Sqrt
Sub
Where
eigh_1x1
eigh_1x1_f64
eigh_2x2_lower_true
eigh_2x2_lower_true_f64
eigh_2x2_lower_false
eigh_2x2_lower_false_f64
eigh_2x2_subset_top1
eigh_2x2_subset_top1_f64
0.12.1
lax.eq Equal eq
eq_f64
0.2.0
lax.erf Erf erf 0.4.4
lax.erf_inv Erf
Exp
Log
Sqrt
erf_inv_midrange
erf_inv_matrix
0.12.1
lax.erfc Erf
Sub
erfc 0.12.1
lax.exp Exp exp
exp_f64
0.2.0
lax.exp2 Pow exp2
exp2_f64
0.12.1
lax.expm1 Exp
Sub
expm1
expm1_f64
0.12.1
lax.fft DFT fft_complex64_1d
fft_complex64_len8
fft_complex64_batch
fft_complex128_1d
fft_complex128_len8
ifft_complex64_1d
ifft_complex64_len8
ifft_complex128_batch
rfft_real32_1d
rfft_real64_len8
irfft_complex64_1d
irfft_complex128_len8
0.10.1
lax.floor Floor floor
floor_f64
0.12.1
lax.fori_loop Loop fori_loop_counter
fori_loop_counter_f64
fori_loop_zero
fori_loop_zero_f64
fori_loop_vector
fori_loop_vector_f64
fori_loop_example
fori_loop_example_f64
fori_loop_test
fori_loop_test_f64
0.5.1
lax.gather GatherND gather_trig_where_pipeline_f64_indices_i64
gather_trig_where_pipeline_f64_indices_i32
gather_f64_data_i64_indices_output_is_f64
gather_f64_data_i32_indices_cast_and_output_is_f64
gather_static
gather_static_f64
gather_dynamic_batch_simple_index_dynamic
gather_dynamic_batch_simple_index_dynamic_f64
gather_dynamic_batch_simple_index
gather_dynamic_batch_simple_index_f64
0.2.0
lax.greater_equal GreaterOrEqual greater_equal
greater_equal_f64
0.7.5
lax.gt Greater gt
gt_f64
0.2.0
lax.hessenberg Div
MatMul
ScatterND
Sqrt
Where
hessenberg_square_4x4
hessenberg_square_4x4_f64
hessenberg_diagonal
hessenberg_diagonal_f64
0.12.1
lax.householder_product MatMul
Mul
ScatterND
Sub
Transpose
householder_product_basic
householder_product_basic_f64
householder_product_k0
householder_product_k0_f64
0.12.1
lax.igamma Exp
Pow
Where
igamma_basic
igamma_zero_x
0.12.1
lax.igamma_grad_a Add
Div
Sub
Where
igamma_grad_a_basic 0.12.1
lax.igammac Sub
Where
igammac_basic
igammac_zero_x
0.12.1
lax.imag Mul imag_complex64_input
imag_complex64_input_f64
0.10.2
lax.integer_pow Pow
Reciprocal
integer_pow
integer_pow_f64
integer_pow_reciprocal
integer_pow_reciprocal_f64
0.2.0
lax.iota Range iota_int32
iota_int32_f64
iota_float32
iota_float32_f64
broadcasted_iota
broadcasted_iota_f64
0.5.0
lax.is_finite IsInf
IsNaN
Not
Or
is_finite_vec
is_finite_vec_f64
0.12.1
lax.less_equal LessOrEqual less_equal
less_equal_f64
0.7.5
lax.lgamma Log
Sin
Where
lgamma_positive
lgamma_positive_f64
lgamma_negative_noninteger
lgamma_negative_noninteger_f64
0.12.1
lax.log Log
ReduceLogSumExp
ReduceLogSum
log
log_f64
log_of_reduce_sum_axis1
log_of_reduce_sum_axis1_f64
log_of_reduce_sum_exp_axis1
log_of_reduce_sum_exp_axis1_f64
0.2.0
lax.log1p Add
Log
log1p
log1p_f64
0.11.0
lax.logistic Sigmoid lax_logistic_basic
lax_logistic_basic_f64
0.7.2
lax.lt Less lt
lt_f64
0.2.0
lax.lu Abs
ArgMax
Div
Mul
ScatterElements
ScatterND
Sub
lu_square_3x3
lu_square_3x3_f64
lu_rectangular_4x3
lu_rectangular_4x3_f64
0.12.1
lax.lu_pivots_to_permutation Concat
Gather
ScatterElements
Squeeze
Unsqueeze
lu_pivots_to_permutation_basic
lu_pivots_to_permutation_basic_f64
lu_pivots_to_permutation_identity
lu_pivots_to_permutation_identity_f64
lu_pivots_to_permutation_batched
lu_pivots_to_permutation_batched_f64
0.12.1
lax.max Max max
max_f64
0.2.0
lax.min Min min_test1
min_test1_f64
0.1.0
lax.mul Mul mul_test1
mul_test1_f64
mul_test2
mul_test2_f64
mul_pyfloat_promotes_to_array_dtype_f64
mul_scalar_broadcast_promote_to_f64
mul_complex128
mul_complex64
0.1.0
lax.ne Equal
Not
ne
ne_f64
0.2.0
lax.neg Neg neg
neg_f64
0.2.0
lax.nextafter IsNaN
Log
Pow
Sign
Where
nextafter_vector
nextafter_special_values
0.12.1
lax.optimization_barrier Identity optimization_barrier_single
optimization_barrier_single_f64
optimization_barrier_tuple
optimization_barrier_tuple_f64
0.12.1
lax.or BitwiseOr
Or
or_bool_vec
or_bool_vec_f64
or_int_vec
or_int_vec_f64
0.7.2
lax.pad Pad pad_const_1d
pad_const_1d_f64
pad_const_2d
pad_const_2d_f64
pad_const_2d_cval
pad_const_2d_cval_f64
pad_inside_scan_smoke_f64
pad_inside_nested_scan_smoke_f64
0.8.0
lax.pjit pjit_inline_mul
pjit_inline_mul_f64
pjit_inline_tuple
pjit_inline_tuple_f64
0.1.0
lax.polygamma Exp
Floor
Pow
Round
Where
polygamma_orders
polygamma_zero_order
0.12.1
lax.population_count BitShift
BitwiseAnd
population_count_i32
population_count_i32_f64
population_count_u8
population_count_u8_f64
0.12.1
lax.pow Pow pow_basic
pow_basic_f64
pow_lax
pow_lax_f64
0.8.2
lax.qr MatMul
Mul
ScatterND
Sqrt
Sub
qr_reduced_tall
qr_reduced_tall_f64
qr_reduced_wide
qr_reduced_wide_f64
qr_full_tall
qr_full_tall_f64
qr_full_wide
qr_full_wide_f64
0.12.1
lax.real Identity real_complex64_input
real_complex64_input_f64
0.10.2
lax.reduce ReduceMax
ReduceMin
reduce_max_lambda
reduce_max_lambda_f64
reduce_min_lambda
reduce_min_lambda_f64
0.12.1
lax.reduce_and ReduceMin reduce_and_all_true
reduce_and_all_true_f64
reduce_and_one_false
reduce_and_one_false_f64
reduce_and_keepdims
reduce_and_keepdims_f64
0.6.1
lax.reduce_max ReduceMax reduce_max
reduce_max_f64
reduce_max_allaxes
reduce_max_allaxes_f64
reduce_max_axes_input
reduce_max_axes_input_f64
reduce_max_keepdims
reduce_max_keepdims_f64
0.2.0
lax.reduce_min ReduceMin reduce_min
reduce_min_f64
reduce_min_allaxes
reduce_min_allaxes_f64
reduce_min_keepdims
reduce_min_keepdims_f64
0.2.0
lax.reduce_or ReduceMax reduce_or_all_false
reduce_or_all_false_f64
reduce_or_one_true
reduce_or_one_true_f64
reduce_or_keepdims
reduce_or_keepdims_f64
0.6.1
lax.reduce_precision Floor
Log
Pow
Round
Where
reduce_precision_mantissa10
reduce_precision_underflow_overflow
0.12.1
lax.reduce_prod ReduceProd reduce_prod
reduce_prod_f64
reduce_prod_allaxes
reduce_prod_allaxes_f64
reduce_prod_dtype
reduce_prod_dtype_f64
reduce_prod_keepdims
reduce_prod_keepdims_f64
0.6.1
lax.reduce_sum ReduceL1
ReduceSumSquare
ReduceSum
reduce_sum
reduce_sum_f64
reduce_sum_allaxes
reduce_sum_allaxes_f64
reduce_sum_dtype
reduce_sum_dtype_f64
reduce_sum_keepdims
reduce_sum_keepdims_f64
reduce_sum_of_abs_axis1
reduce_sum_of_abs_axis1_f64
reduce_sum_of_square_axis1
reduce_sum_of_square_axis1_f64
0.2.0
lax.reduce_window Conv
MaxPool
Neg
reduce_window_add_lambda
reduce_window_max_lambda
reduce_window_min_lambda_stride
0.12.1
lax.reduce_window_max MaxPool reduce_window_max_primitive 0.12.2
lax.reduce_window_sum Conv
LpPool
reduce_window_sum_valid
reduce_window_sum_same_padding
reduce_window_sum_stride_dilate
reduce_window_sum_int32
reduce_window_sum_base_dilation
reduce_window_sum_abs_lppool_opset23
reduce_window_sum_abs_lppool_dilated_opset23
0.10.1
lax.reduce_xor Mod
ReduceSum
reduce_xor_all_false
reduce_xor_all_false_f64
reduce_xor_one_true
reduce_xor_one_true_f64
reduce_xor_two_true
reduce_xor_two_true_f64
reduce_xor_keepdims
reduce_xor_keepdims_f64
0.6.1
lax.rem Div
Mod
rem_int
rem_int_f64
rem_float
rem_float_f64
rem_int_neg
rem_int_neg_f64
rem_float_neg
rem_float_neg_f64
0.6.5
lax.remat2 remat2_scalar_sin_chain
remat2_scalar_sin_chain_f64
remat2_tuple_passthrough
remat2_tuple_passthrough_f64
0.6.5
lax.reshape Reshape reshape_after_transpose_folds_const_shape
reshape_after_transpose_folds_const_shape_f64
reshape_flatten_trailing_folds_const_shape
reshape_flatten_trailing_folds_const_shape_f64
reshape
reshape_f64
reshape_valid_squeeze_middle_dim_from_problematic_source
reshape_valid_squeeze_middle_dim_from_problematic_source_f64
reshape_valid_flatten_trailing
reshape_valid_flatten_trailing_f64
reshape_with_target_shape_from_symbolic_dim_computation
reshape_with_target_shape_from_symbolic_dim_computation_f64
reshape_with_inferred_dimension_from_input_dynamic_dynamic
reshape_with_inferred_dimension_from_input_dynamic_dynamic_f64
reshape_with_inferred_dimension_from_input_dynamic
reshape_with_inferred_dimension_from_input_dynamic_f64
reshape_with_inferred_dimension_from_input
reshape_with_inferred_dimension_from_input_f64
reshape_merge_symbolic_with_static_and_check_name_dynamic
reshape_merge_symbolic_with_static_and_check_name
0.2.0
lax.rev Gather
Range
rev_vector
rev_vector_f64
rev_matrix_axes01
rev_matrix_axes01_f64
0.7.5
lax.rng_bit_generator Cast
Identity
RandomUniform
rng_bit_generator_u32 0.12.1
lax.rng_uniform Add
Mul
RandomUniform
Sub
rng_uniform_f32 0.12.1
lax.round Round round
round_f64
0.12.1
lax.rsqrt Div
Sqrt
rsqrt
rsqrt_f64
0.10.2
lax.scan Loop scan_identity_slice_helper
scan_identity_slice_helper_f64
scan_cumsum
scan_cumsum_f64
scan_carry_only
scan_carry_only_f64
scan_multiple_sequences
scan_multiple_sequences_f64
scan_multiple_carry
scan_multiple_carry_f64
scan_matrix_carry_multidim_xs
scan_matrix_carry_multidim_xs_f64
scan_no_xs
scan_no_xs_f64
scan_fn
scan_fn_f64
scan_jit_no_xs
scan_jit_no_xs_f64
scan_captured_scalar
scan_captured_scalar_f64
scan_rank0_sequence_vectorized
scan_rank0_sequence_vectorized_f64
scan_two_diff_lengths
scan_two_diff_lengths_f64
scan_two_diff_lengths_broadcast
scan_two_diff_lengths_broadcast_f64
scan_two_diff_lengths_with_broadcast
scan_nested_len_mismatch
scan_nested_len_mismatch_f64
scan_captured_scalar_with_xs
scan_captured_vector_with_xs_f64
0.5.1
lax.scatter NonZero
ScatterElements
ScatterND
Scatter
scatter_set_axis0
scatter_set_axis0_f64
scatter_set_middle
scatter_set_middle_f64
scatter_set_single
scatter_set_single_f64
scatter_set_vector
scatter_set_vector_f64
scatter_elements_set_vector_promise_in_bounds
scatter_elements_set_vector_promise_in_bounds_f64
scatter_correct_axis_determination
scatter_correct_axis_determination_f64
scatter_updates_slice_needed_axis0
scatter_updates_slice_needed_axis0_f64
scatter_from_user_warning_shapes_valid_jax
scatter_from_user_warning_shapes_valid_jax_f64
scatter_user_error_scenario_precise
scatter_user_error_scenario_precise_f64
scatter_window_update_f64
scatter_window_update_depth3_shapes_ok
scatter_static_slice_set_f64
scatter_depth2_fp64_type_mismatch
scatter_clip_2d_window_at_edge
scatter_simple_2d_window_out_of_bounds
scatter_depth2_mixed_dtypes_fp_mismatch_f64
scatter_depth2_mixed_dtypes_fp_mismatch
0.4.4
lax.scatter_add ScatterND(reduction='add') scatter_add_vector
scatter_add_vector_f64
scatter_add_scalar
scatter_add_scalar_f64
scatter_add_simple_1d
scatter_add_simple_1d_f64
scatter_add_batch_updates_1d_operand
scatter_add_batch_updates_1d_operand_f64
scatter_add_window_2d_operand_1d_indices
scatter_add_window_2d_operand_1d_indices_f64
scatter_add_mismatched_window_dims_from_user_report
scatter_add_mismatched_window_dims_from_user_report2
scatter_add_mismatched_window_dims_from_user_report3
scatter_add_fluids_pattern_updates_5_4_1_1
scatter_add_in_cond_float64
scatter_add_fp64_dtype_mismatch
scatter_add_depth2_depth2_helper_regression
scatter_depth2_fp64_type_mismatch
0.5.3
lax.scatter_max ScatterND(reduction='max') scatter_max_simple_1d
scatter_max_simple_1d_f64
scatter_max_batch_updates_1d_operand
scatter_max_batch_updates_1d_operand_f64
scatter_max_window_2d_operand_1d_indices
scatter_max_window_2d_operand_1d_indices_f64
scatter_max_fp64_dtype_path_check
scatter_max_depth2_helper_regression_fp64
0.7.5
lax.scatter_min ScatterND(reduction='min') scatter_min_simple_1d
scatter_min_simple_1d_f64
scatter_min_batch_updates_1d_operand
scatter_min_batch_updates_1d_operand_f64
scatter_min_window_2d_operand_1d_indices
scatter_min_window_2d_operand_1d_indices_f64
scatter_min_fp64_dtype_path_check
scatter_min_depth2_helper_regression_fp64
0.7.5
lax.scatter_mul ScatterND(reduction='mul') scatter_mul_simple_1d
scatter_mul_simple_1d_f64
scatter_mul_batch_updates_1d_operand
scatter_mul_batch_updates_1d_operand_f64
scatter_mul_window_2d_operand_1d_indices
scatter_mul_window_2d_operand_1d_indices_f64
scatter_mul_mismatched_window_dims_from_user_report
scatter_mul_mismatched_window_dims_from_user_report2
scatter_mul_mismatched_window_dims_from_user_report3
scatter_mul_fluids_pattern_updates_5_4_1_1
scatter_mul_in_cond_float64
0.6.4
lax.scatter_sub Neg
ScatterND(reduction='add')
scatter_sub_simple_1d
scatter_sub_simple_1d_f64
scatter_sub_window_2d_operand_1d_indices
scatter_sub_window_2d_operand_1d_indices_f64
0.12.1
lax.schur Identity schur_1x1_default
schur_1x1_default_f64
schur_1x1_no_vectors
schur_1x1_no_vectors_f64
0.12.1
lax.select Where select_simple
select_simple_f64
select_basic
select_basic_f64
select_mask_scores_tensor_else_dynamic
select_mask_scores_tensor_else_dynamic_f64
select_mask_scores_tensor_else
select_mask_scores_tensor_else_f64
0.7.1
lax.select_n Equal
Where
select_n_bool_predicate_two_cases_float
select_n_bool_predicate_two_cases_float_f64
select_n_bool_predicate_two_cases_int
select_n_bool_predicate_two_cases_int_f64
select_n_bool_predicate_scalar_broadcast
select_n_bool_predicate_scalar_broadcast_f64
select_n_int_indices_three_cases
select_n_int_indices_three_cases_f64
select_n_int_indices_four_cases
select_n_int_indices_four_cases_f64
0.2.0
lax.shard_map shard_map_inline_add
shard_map_inline_add_f64
0.10.2
lax.shift_left BitShift shift_left_vec
shift_left_vec_f64
shift_left_scalar
shift_left_scalar_f64
0.12.1
lax.shift_right_arithmetic BitShift shift_right_arithmetic_vec
shift_right_arithmetic_vec_f64
shift_right_arithmetic_scalar
shift_right_arithmetic_scalar_f64
0.12.1
lax.shift_right_logical BitShift shift_right_logical_vec
shift_right_logical_vec_f64
shift_right_logical_scalar
shift_right_logical_scalar_f64
0.7.2
lax.sign Sign sign
sign_f64
0.5.0
lax.sin Sin sin
sin_f64
0.4.4
lax.sinh Sinh sinh
sinh_f64
0.4.4
lax.slice Slice slice_test1
slice_test1_f64
slice_3d_none_strides
slice_3d_none_strides_f64
slice_scan_axis_drop
slice_scan_axis_drop_f64
0.1.0
lax.sort GatherElements
TopK
sort_1d
sort_1d_f64
sort_2d
sort_2d_f64
0.2.0
lax.split Split lax_split_equal_parts
lax_split_equal_parts_f64
lax_split_unequal_parts
lax_split_unequal_parts_f64
0.7.2
lax.sqrt ReduceL2
Sqrt
sqrt
sqrt_f64
sqrt_reduce_sum_square_axis1
sqrt_reduce_sum_square_axis1_f64
0.2.0
lax.square Mul square
square_f64
0.2.0
lax.squeeze Squeeze squeeze_single_axis
squeeze_single_axis_f64
squeeze_all_unit_dims_default
squeeze_all_unit_dims_default_f64
lax_squeeze_specific_axis_0
lax_squeeze_specific_axis_0_f64
lax_squeeze_multiple_axes
lax_squeeze_multiple_axes_f64
lax_squeeze_no_op_empty_dims
lax_squeeze_no_op_empty_dims_f64
lax_squeeze_problem_case_input_squeeze_only_axis_0
lax_squeeze_problem_case_input_squeeze_only_axis_0_f64
lax_squeeze_problem_case_input_squeeze_axes_0_2
lax_squeeze_problem_case_input_squeeze_axes_0_2_f64
lax_squeeze_problem_case_input_squeeze_all_dims_explicitly
lax_squeeze_problem_case_input_squeeze_all_dims_explicitly_f64
0.2.0
lax.stop_gradient Identity stop_gradient
stop_gradient_f64
stop_gradient_basic
stop_gradient_basic_f64
0.2.0
lax.sub Sub sub_test1
sub_test1_f64
sub_test2
sub_test2_f64
sub_const
sub_const_f64
sub_complex64
sub_complex64_f64
0.1.0
lax.svd Abs
Add
Div
Gather
Identity
Max
Mul
Reshape
Sign
Sub
Where
svd_1x1_default
svd_1x1_default_f64
svd_1x1_values_only
svd_1x1_values_only_f64
svd_2x2_values_only
svd_2x2_values_only_f64
svd_3x2_values_only
svd_3x2_values_only_f64
svd_2x4_values_only
svd_2x4_values_only_f64
0.12.1
lax.symmetric_product Add
MatMul
Mul
Transpose
symmetric_product_default
symmetric_product_default_f64
symmetric_product_alpha_beta
symmetric_product_alpha_beta_f64
0.12.1
lax.tan Tan tan 0.12.1
lax.tanh Tanh tanh
tanh_f64
0.2.0
lax.top_k TopK top_k_last_axis
top_k_last_axis_f64
top_k_matrix
top_k_matrix_f64
0.10.2
lax.transpose Transpose transpose_basic
transpose_basic_f64
transpose_square_matrix
transpose_square_matrix_f64
transpose_3d
transpose_3d_f64
transpose_4d
transpose_4d_f64
transpose_reverse
transpose_reverse_f64
transpose_no_axes
transpose_no_axes_f64
transpose_nhwc_to_nchw
transpose_nhwc_to_nchw_f64
0.2.0
lax.triangular_solve Div
Gather
triangular_solve_left_basic
triangular_solve_left_basic_f64
triangular_solve_batched_unit_diag
0.12.1
lax.tridiagonal Gather
Identity
ScatterND
tridiagonal_2x2_lower_true
tridiagonal_2x2_lower_true_f64
tridiagonal_2x2_lower_false
tridiagonal_2x2_lower_false_f64
0.12.1
lax.tridiagonal_solve Concat
Div
Gather
Mul
Sub
tridiagonal_solve_single_rhs
tridiagonal_solve_single_rhs_f64
tridiagonal_solve_multi_rhs
tridiagonal_solve_multi_rhs_f64
0.12.1
lax.while_loop Loop while_scalar_counter
while_scalar_counter_f64
while_tuple_state
while_tuple_state_f64
while_loop_counter
while_loop_counter_f64
while_loop_vector
while_loop_vector_f64
while_loop_f64
while_loop_multi_state_f32
while_loop_multi_state_f64
while_loop_with_closure
while_loop_with_closure_f64
while_loop_basic
while_loop_two_state
while_loop_captured_tracer
while_loop_with_scalar_state
while_loop_renamed_passthrough
while_loop_closure_topo
while_loop_mixed_rank
while_loop_tracer_passthrough
while_loop_no_loop_output_reused_as_input
while_loop_4d_and_scalar_state
while_loop_4d_and_scalar_state_f64
while_loop_cnn_scalar_state_bug
while_loop_cnn_scalar_state_bug_f64
while_loop_nnx_repro
while_loop_nnx_repro_f64
0.5.1
lax.xor BitwiseXor
Xor
xor_bool_vec
xor_bool_vec_f64
xor_int_vec
xor_int_vec_f64
0.12.1
lax.zeta Add
Div
Pow
Where
zeta_positive
zeta_broadcast
0.12.1
linen.activation activation_glu_basic
activation_glu_basic_f64
activation_hard_sigmoid_basic
activation_hard_silu_basic
activation_hard_silu_basic_f64
activation_hard_swish_basic
activation_hard_tanh_basic
activation_hard_tanh_basic_f64
activation_log_sigmoid_basic
activation_log_sigmoid_basic_f64
activation_log_softmax_basic
activation_log_softmax_basic_f64
activation_relu6_basic
activation_relu6_basic_f64
activation_silu_basic
activation_silu_basic_f64
activation_swish_basic
activation_swish_basic_f64
activation_tanh_basic
activation_tanh_basic_f64
activation_normalize_basic
activation_normalize_basic_f64
activation_one_hot_basic
0.11.0
linen.avg_pool AveragePool
GlobalAveragePool
Transpose
avg_pool_dynamic
avg_pool
avg_pool_same_padding_dynamic
avg_pool_same_padding
avg_pool_default_padding_dynamic
avg_pool_default_padding
avg_pool_stride1_dynamic
avg_pool_stride1
avg_pool_win3x3_stride2_dynamic
avg_pool_win3x3_stride2
avg_pool_stride_none_dynamic
avg_pool_stride_none
avg_pool_count_include_pad_false_dynamic
avg_pool_count_include_pad_false
avg_pool_global_window_dynamic
avg_pool_global_window
0.11.0
linen.batch_norm BatchNormalization batch_norm_no_bias_no_scale_dynamic
batch_norm_no_bias_no_scale
batch_norm_bias_no_scale_dynamic
batch_norm_bias_no_scale
batch_norm_no_bias_scale_dynamic
batch_norm_no_bias_scale
batch_norm_bias_scale_dynamic
batch_norm_bias_scale
batch_norm_3d_dynamic
batch_norm_3d
batch_norm_4d_dynamic
batch_norm_4d
batch_norm_4d_no_bias_no_scale_dynamic
batch_norm_4d_no_bias_no_scale
0.11.0
linen.bidirectional Concat
Loop
bidirectional_basic_dynamic
bidirectional_basic_dynamic_f64
bidirectional_basic
bidirectional_basic_f64
0.11.0
linen.conv CastLike
Conv
Reshape
Transpose
conv_basic_dynamic
conv_basic
conv_no_bias
conv_stride
0.11.0
linen.conv_local Conv
Gemm
Reshape
Transpose
conv_local_valid
conv_local_same
0.11.0
linen.conv_lstm_cell Add
Conv
Gemm
Mul
Sigmoid
Tanh
conv_lstm_cell_basic_dynamic
conv_lstm_cell_basic
0.11.0
linen.conv_transpose ConvTranspose conv_transpose_basic_dynamic
conv_transpose_basic
conv_transpose_valid_stride
0.11.0
linen.dense Gemm dense_basic_dynamic
dense_basic_dynamic_f64
dense_basic
dense_basic_f64
dense_high_rank_dynamic_dynamic
dense_high_rank_dynamic_dynamic_f64
dense_high_rank_static
dense_high_rank_static_f64
dense_high_rank_no_bias
dense_high_rank_no_bias_f64
dense_no_bias_dynamic
dense_no_bias_dynamic_f64
dense_no_bias
dense_no_bias_f64
0.11.0
linen.dense_general CastLike
Concat
Gemm
Reshape
Shape
Slice
dense_general_basic_dynamic
dense_general_basic_dynamic_f64
dense_general_basic
dense_general_basic_f64
dense_general_multi_out
dense_general_multi_out_f64
dense_general_contract_last_two
dense_general_contract_last_two_f64
dense_general_dynamic_batch_dynamic
dense_general_dynamic_batch_dynamic_f64
dense_general_no_bias
dense_general_no_bias_f64
0.11.0
linen.dot_product_attention MatMul
Mul
Softmax
Transpose
Where
dot_product_attention_basic
dot_product_attention_basic_f64
0.11.0
linen.dot_product_attention_weights Add
Div
MatMul
Mul
ReduceSum
Softmax
Where
dot_product_attention_weights_basic
dot_product_attention_weights_basic_f64
0.11.0
linen.dropout Dropout dropout_init_params_dynamic
dropout_init_params_dynamic_f64
dropout_init_params
dropout_init_params_f64
dropout_call_params_dynamic
dropout_call_params_dynamic_f64
dropout_call_params
dropout_call_params_f64
0.11.0
linen.einsum Add
Einsum
Reshape
einsum_with_bias
einsum_with_bias_f64
einsum_no_bias
einsum_no_bias_f64
0.11.0
linen.embed Gather token_embedding_dynamic
token_embedding_dynamic_f64
token_embedding
token_embedding_f64
positional_embedding_dynamic
positional_embedding_dynamic_f64
positional_embedding
positional_embedding_f64
0.11.0
linen.group_norm GroupNormalization group_norm_rank4
group_norm_rank2_dynamic
group_norm_rank2
group_norm_no_bias_no_scale_dynamic
group_norm_no_bias_no_scale
0.11.0
linen.gru_cell Add
Gemm
Mul
Sigmoid
Tanh
gru_cell_basic_dynamic
gru_cell_basic_dynamic_f64
gru_cell_basic
gru_cell_basic_f64
0.11.0
linen.instance_norm GroupNormalization
InstanceNormalization
instance_norm_rank4_dynamic
instance_norm_rank4
instance_norm_rank2_dynamic
instance_norm_rank2
0.11.0
linen.layer_norm LayerNormalization layer_norm_dynamic
layer_norm
layer_norm_no_bias_no_scale_dynamic
layer_norm_no_bias_no_scale
layer_norm_multiaxis_dynamic
layer_norm_multiaxis
layer_norm_default_epsilon_dynamic
layer_norm_default_epsilon
0.11.0
linen.lstm_cell Add
Gemm
Mul
Sigmoid
Tanh
lstm_cell_basic_dynamic
lstm_cell_basic_dynamic_f64
lstm_cell_basic
lstm_cell_basic_f64
0.11.0
linen.make_attention_mask Cast
Mul
make_attention_mask_basic
make_attention_mask_basic_f64
0.11.0
linen.make_causal_mask Cast
Less
make_causal_mask_basic
make_causal_mask_basic_f64
0.11.0
linen.max_pool GlobalMaxPool
MaxPool
max_pool
max_pool_same_padding
max_pool_basic
max_pool_same_dynamic
max_pool_same
max_pool_global_window_dynamic
max_pool_global_window
0.11.0
linen.mgu_cell Add
Gemm
Mul
Sigmoid
Tanh
mgu_cell_basic_dynamic
mgu_cell_basic_dynamic_f64
mgu_cell_basic
mgu_cell_basic_f64
0.11.0
linen.min_pool MaxPool
Neg
min_pool
min_pool_same_padding
min_pool_basic
min_pool_same_dynamic
min_pool_same
0.11.0
linen.multi_head_attention Add
Gemm
MatMul
Mul
Reshape
Softmax
Transpose
multi_head_attention_basic_dynamic
multi_head_attention_basic_dynamic_f64
multi_head_attention_basic
multi_head_attention_basic_f64
multi_head_attention_no_bias_dynamic
multi_head_attention_no_bias_dynamic_f64
multi_head_attention_no_bias
multi_head_attention_no_bias_f64
0.11.0
linen.multi_head_dot_product_attention Add
Gemm
MatMul
Mul
Reshape
Softmax
Transpose
multi_head_dot_product_attention_basic_dynamic
multi_head_dot_product_attention_basic_dynamic_f64
multi_head_dot_product_attention_basic
multi_head_dot_product_attention_basic_f64
multi_head_dot_product_attention_no_bias_dynamic
multi_head_dot_product_attention_no_bias_dynamic_f64
multi_head_dot_product_attention_no_bias
multi_head_dot_product_attention_no_bias_f64
0.11.0
linen.optimized_lstm_cell Add
Gemm
Mul
Sigmoid
Tanh
optimized_lstm_cell_basic_dynamic
optimized_lstm_cell_basic_dynamic_f64
optimized_lstm_cell_basic
optimized_lstm_cell_basic_f64
0.11.0
linen.pool AveragePool
MaxPool
Mul
Neg
Transpose
pool_max_basic_dynamic
pool_max_basic
pool_min_basic_dynamic
pool_min_basic
pool_sum_basic_dynamic
pool_sum_basic
0.11.0
linen.prelu PRelu linen_prelu_default_dynamic
linen_prelu_default
linen_prelu_custom_slope
0.12.1
linen.rms_norm RMSNormalization rms_norm_basic
rms_norm_use_scale_false
rms_norm_4d_dynamic_dynamic
rms_norm_4d_dynamic
0.11.0
linen.rnn Loop rnn_basic_dynamic
rnn_basic_dynamic_f64
rnn_basic
rnn_basic_f64
0.11.0
linen.self_attention Add
Gemm
MatMul
Mul
Reshape
Softmax
Transpose
self_attention_basic_dynamic
self_attention_basic_dynamic_f64
self_attention_basic
self_attention_basic_f64
self_attention_no_bias_dynamic
self_attention_no_bias_dynamic_f64
self_attention_no_bias
self_attention_no_bias_f64
0.11.0
linen.simple_cell Add
Gemm
Mul
Sigmoid
Tanh
simple_cell_basic_dynamic
simple_cell_basic_dynamic_f64
simple_cell_basic
simple_cell_basic_f64
0.11.0
linen.spectral_norm Div
MatMul
spectral_norm_dense_dynamic
spectral_norm_dense
0.11.0
linen.weight_norm Div
Mul
weight_norm_dense_dynamic
weight_norm_dense
0.11.0
nn.celu Celu jaxnn_celu
jaxnn_celu_1
jaxnn_celu_alpha_default
jaxnn_celu_alpha_custom_dynamic
jaxnn_celu_alpha_custom
celu_grad_issue_batch_diff_rules
0.7.1
nn.dot_product_attention Add
MatMul
Mul
Softmax
Transpose
Where
dpa_basic
dpa_basic_f64
dpa_positional_bias_mask
dpa_positional_bias_mask_f64
dpa_diff_heads_embed
dpa_diff_heads_embed_f64
dpa_batch4_seq16
dpa_batch4_seq16_f64
dpa_float64
dpa_heads1_embed4
dpa_heads1_embed4_f64
dpa_heads8_embed8
dpa_heads8_embed8_f64
dpa_batch1_seq2
dpa_batch1_seq2_f64
dpa_batch8_seq4
dpa_batch8_seq4_f64
dpa_axis1
dpa_axis1_f64
dpa_with_tensor_mask
dpa_with_tensor_mask_f64
dpa_tiny_mask_all_valid
dpa_tiny_mask_all_valid_f64
dpa_tiny_mask_mixed
dpa_tiny_mask_mixed_f64
dpa_one_false
dpa_one_false_f64
dpa_mostly_false
dpa_mostly_false_f64
dpa_with_causal_mask
dpa_with_causal_mask_f64
dpa_with_padding_mask
dpa_with_padding_mask_f64
dpa_with_local_window_mask
dpa_with_local_window_mask_f64
dpa_vmap_tnh_issue190
dpa_mask_none
0.8.0
nn.elu Elu jaxnn_elu
jaxnn_elu_1
jaxnn_elu_default
jaxnn_elu_custom_alpha_dynamic
jaxnn_elu_custom_alpha
elu_grad_issue_batch_diff_rules
0.7.1
nn.gelu Gelu jaxnn_gelu
jaxnn_gelu_1
jaxnn_gelu_approx
jaxnn_gelu_exact
jaxnn_gelu_tanh_dynamic
jaxnn_gelu_tanh
gelu_grad_issue_batch_diff_rules
0.7.1
nn.glu Mul
Sigmoid
Split
jaxnn_glu_axis_last
jaxnn_glu_axis_last_f64
jaxnn_glu_axis1_dynamic_dynamic
jaxnn_glu_axis1_dynamic_dynamic_f64
jaxnn_glu_axis1_dynamic
jaxnn_glu_axis1_dynamic_f64
0.12.1
nn.hard_sigmoid HardSigmoid jaxnn_hard_sigmoid
jaxnn_hard_sigmoid_dynamic_dynamic
jaxnn_hard_sigmoid_dynamic
0.12.1
nn.hard_swish HardSwish jaxnn_hard_swish
jaxnn_hard_swish_dynamic_dynamic
jaxnn_hard_swish_dynamic
0.12.1
nn.hard_tanh Max
Min
jaxnn_hard_tanh
jaxnn_hard_tanh_f64
jaxnn_hard_tanh_dynamic_dynamic
jaxnn_hard_tanh_dynamic_dynamic_f64
jaxnn_hard_tanh_dynamic
jaxnn_hard_tanh_dynamic_f64
0.12.1
nn.hardmax Hardmax jaxnn_hardmax_default_axis
jaxnn_hardmax_axis0
jaxnn_hardmax_dynamic_dynamic
jaxnn_hardmax_dynamic
jaxnn_hardmax_vmap_batching
0.12.1
nn.identity Identity jaxnn_identity
jaxnn_identity_f64
jaxnn_identity_1
jaxnn_identity_1_f64
jaxnn_identity_basic
jaxnn_identity_basic_f64
jaxnn_identity_dynamic_dynamic
jaxnn_identity_dynamic_dynamic_f64
jaxnn_identity_dynamic
jaxnn_identity_dynamic_f64
0.7.1
nn.leaky_relu LeakyRelu jaxnn_leaky_relu
jaxnn_leaky_relu_1
jaxnn_leaky_relu_default_dynamic
jaxnn_leaky_relu_default
jaxnn_leaky_relu_custom
leaky_relu_grad_issue_batch_diff_rules
0.7.1
nn.log1mexp Exp
Less
Log
Neg
Sub
Where
jaxnn_log1mexp_basic
jaxnn_log1mexp_basic_f64
jaxnn_log1mexp_dynamic_dynamic
jaxnn_log1mexp_dynamic_dynamic_f64
jaxnn_log1mexp_dynamic
jaxnn_log1mexp_dynamic_f64
0.12.1
nn.log_sigmoid Log
Sigmoid
jaxnn_log_sigmoid
jaxnn_log_sigmoid_f64
jaxnn_log_sigmoid_dynamic_dynamic
jaxnn_log_sigmoid_dynamic_dynamic_f64
jaxnn_log_sigmoid_dynamic
jaxnn_log_sigmoid_dynamic_f64
0.12.1
nn.log_softmax LogSoftmax jaxnn_log_softmax_default_axis_dynamic
jaxnn_log_softmax_default_axis_dynamic_f64
jaxnn_log_softmax_default_axis
jaxnn_log_softmax_default_axis_f64
jaxnn_log_softmax_axis0
jaxnn_log_softmax_axis0_f64
jaxnn_log_softmax_axis_last_3d
jaxnn_log_softmax_axis_last_3d_f64
log_softmax_grad_issue_batch_diff_rules
0.12.1
nn.logmeanexp ReduceLogSumExp
Sub
jaxnn_logmeanexp_axis1
jaxnn_logmeanexp_axis1_f64
jaxnn_logmeanexp_axis12_keepdims
jaxnn_logmeanexp_axis12_keepdims_f64
jaxnn_logmeanexp_dynamic_static_reduction_axis_dynamic
jaxnn_logmeanexp_dynamic_static_reduction_axis_dynamic_f64
jaxnn_logmeanexp_dynamic_static_reduction_axis
jaxnn_logmeanexp_dynamic_static_reduction_axis_f64
0.12.1
nn.logsumexp ReduceLogSumExp jaxnn_logsumexp_axis1
jaxnn_logsumexp_axis1_f64
jaxnn_logsumexp_axis12_keepdims
jaxnn_logsumexp_axis12_keepdims_f64
jaxscipy_logsumexp_axis_last_dynamic
jaxscipy_logsumexp_axis_last_dynamic_f64
jaxscipy_logsumexp_axis_last
jaxscipy_logsumexp_axis_last_f64
logsumexp_grad_issue_batch_diff_rules
0.12.1
nn.mish Mish jaxnn_mish
jaxnn_mish_1
jaxnn_mish_basic
mish_grad_issue_batch_diff_rules
0.7.1
nn.one_hot OneHot jaxnn_one_hot_default
jaxnn_one_hot_axis0
0.12.1
nn.relu Relu jaxnn_relu
jaxnn_relu_f64
jaxnn_relu_1
jaxnn_relu_1_f64
jaxnn_relu_basic
jaxnn_relu_basic_f64
jaxnn_relu_dynamic_dynamic
jaxnn_relu_dynamic_dynamic_f64
jaxnn_relu_dynamic
jaxnn_relu_dynamic_f64
relu_grad_issue_batch_diff_rules
0.7.1
nn.relu6 Max
Min
jaxnn_relu6
jaxnn_relu6_f64
jaxnn_relu6_dynamic_dynamic
jaxnn_relu6_dynamic_dynamic_f64
jaxnn_relu6_dynamic
jaxnn_relu6_dynamic_f64
0.12.1
nn.scaled_dot_general Einsum
Gemm
MatMul
jaxnn_scaled_dot_general_basic
jaxnn_scaled_dot_general_batched
jaxnn_scaled_dot_general_batched_f64
0.12.1
nn.scaled_matmul MatMul
Mul
Reshape
Tile
Transpose
Unsqueeze
jaxnn_scaled_matmul_basic 0.12.1
nn.selu Selu jaxnn_selu
jaxnn_selu_1
jaxnn_selu_basic_dynamic
jaxnn_selu_basic
selu_grad_issue_batch_diff_rules
0.7.1
nn.sigmoid Sigmoid jaxnn_sigmoid
jaxnn_sigmoid_f64
jaxnn_sigmoid_1
jaxnn_sigmoid_1_f64
sigmoid_grad_issue_batch_diff_rules
0.7.1
nn.silu Mul
Sigmoid
Swish
jaxnn_silu
jaxnn_silu_f64
jaxnn_silu_dynamic_dynamic
jaxnn_silu_dynamic_dynamic_f64
jaxnn_silu_dynamic
jaxnn_silu_dynamic_f64
jaxnn_swish_alias
jaxnn_swish_alias_f64
silu_grad_issue_batch_diff_rules
0.12.1
nn.soft_sign Softsign jaxnn_soft_sign
jaxnn_soft_sign_1
jaxnn_softsign_basic
softsign_grad_issue_batch_diff_rules
0.7.1
nn.softmax Softmax softmax
softmax_f64
softmax_2d
softmax_2d_f64
softmax_3d
softmax_3d_f64
softmax_mask_where
softmax_mask_where_f64
softmax_grad_issue_batch_diff_rules
0.7.1
nn.softplus Softplus jaxnn_softplus
jaxnn_softplus_1
jaxnn_softplus_basic
softplus_grad_issue_batch_diff_rules
0.7.1
nn.sparse_plus Add
Greater
Less
Mul
Where
jaxnn_sparse_plus
jaxnn_sparse_plus_f64
jaxnn_sparse_plus_dynamic_dynamic
jaxnn_sparse_plus_dynamic_dynamic_f64
jaxnn_sparse_plus_dynamic
jaxnn_sparse_plus_dynamic_f64
0.12.1
nn.sparse_sigmoid Add
Greater
Less
Mul
Where
jaxnn_sparse_sigmoid
jaxnn_sparse_sigmoid_f64
jaxnn_sparse_sigmoid_dynamic_dynamic
jaxnn_sparse_sigmoid_dynamic_dynamic_f64
jaxnn_sparse_sigmoid_dynamic
jaxnn_sparse_sigmoid_dynamic_f64
0.12.1
nn.squareplus Add
Mul
Sqrt
jaxnn_squareplus_default_b
jaxnn_squareplus_default_b_f64
jaxnn_squareplus_broadcast_b
jaxnn_squareplus_broadcast_b_f64
0.12.1
nn.standardize MeanVarianceNormalization jaxnn_standardize_axis_last_eps0
jaxnn_standardize_axis_tuple_eps0_dynamic
jaxnn_standardize_axis_tuple_eps0
standardize_grad_issue_batch_diff_rules
0.12.1
nn.tanh Tanh jaxnn_tanh
jaxnn_tanh_f64
jaxnn_tanh_dynamic_dynamic
jaxnn_tanh_dynamic_dynamic_f64
jaxnn_tanh_dynamic
jaxnn_tanh_dynamic_f64
0.12.1
nn.thresholded_relu ThresholdedRelu jaxnn_thresholded_relu_default
jaxnn_thresholded_relu_custom_dynamic
jaxnn_thresholded_relu_custom
0.12.1
nn.truncated_normal initializer
random_truncated_normal_positional
flax_dense_like_init
0.7.1
nnx.avg_pool AveragePool
GlobalAveragePool
Transpose
avg_pool_dynamic
avg_pool
avg_pool_same_padding_dynamic
avg_pool_same_padding
avg_pool_default_padding_dynamic
avg_pool_default_padding
avg_pool_stride1_dynamic
avg_pool_stride1
avg_pool_win3x3_stride2_dynamic
avg_pool_win3x3_stride2
avg_pool_stride_none_dynamic
avg_pool_stride_none
avg_pool_count_include_pad_false_dynamic
avg_pool_count_include_pad_false
avg_pool_global_window_dynamic
avg_pool_global_window
0.1.0
nnx.batch_norm BatchNormalization batch_norm_no_bias_no_scale_dynamic
batch_norm_no_bias_no_scale
batch_norm_bias_no_scale_dynamic
batch_norm_bias_no_scale
batch_norm_no_bias_scale_dynamic
batch_norm_no_bias_scale
batch_norm_bias_scale_dynamic
batch_norm_bias_scale
batch_norm_3d_dynamic
batch_norm_3d
batch_norm_4d_dynamic
batch_norm_4d
batch_norm_4d_no_bias_no_scale_dynamic
batch_norm_4d_no_bias_no_scale
0.1.0
nnx.combine_masks And
Cast
combine_masks_two
combine_masks_with_none
0.12.2
nnx.conv CastLike
Conv
Reshape
Transpose
conv_basic_bias_dynamic
conv_basic_bias
conv_basic_bias_2
conv_basic_bias_3
conv_stride2_bias
conv_no_bias_dynamic
conv_no_bias
conv_valid_padding
conv_stride1
conv_stride2
conv_2d_reflect_padding
conv_2d_circular_padding
conv_different_kernel
conv_float64
conv_single_batch
conv_large_batch
conv_1d_causal_padding
conv_1d
conv_1d_more_1d_inputs
conv_1d_more_2d_inputs
conv_1d_large_kernel
conv_1d_dilation
conv_1d_stride_dilation
conv_2d_asymmetric_kernel
conv_2d_asymmetric_stride
conv_2d_asymmetric_dilation
conv_2d_large_dilation
conv_2d_large_stride
conv_2d_mixed_params
conv_2d_same_padding_mixed_dilation
conv_3d_basic
conv_3d_stride
conv_3d_asymmetric
conv_3d_dilation
conv_2d_small_input
conv_2d_many_channels
conv_1d_wide_input
conv_2d_kernel_1x1
conv_1d_kernel_1
conv_2d_group_conv
conv_1d_group_conv_more_dims
conv_2d_depthwise
conv_1d_complex_on_4d
conv_2d_complex_on_5d
conv_2d_asymmetric_on_5d
conv_1d_high_dilation_on_3d
conv_1d_large_kernel_on_4d
conv_2d_group_stride_dilation
conv_1d_group_on_higher_dim
conv_1d_same_padding_on_3d
conv_3d_group_complex
conv_1d_unit_group_on_multi_dim
0.1.0
nnx.dot_product_attention Add
Attention
MatMul
Mul
Softmax
Transpose
Where
dpa_basic
dpa_basic_f64
dpa_with_tensor_mask
dpa_with_bias
dpa_with_causal_mask
dpa_with_causal_mask_f64
dpa_with_mask_and_bias
dpa_gqa_basic
dpa_gqa_basic_opset23
0.1.0
nnx.dropout Constant
Dropout
dropout_init_params_dynamic
dropout_init_params_dynamic_f64
dropout_init_params
dropout_init_params_f64
dropout_call_params_dynamic
dropout_call_params_dynamic_f64
dropout_call_params
dropout_call_params_f64
0.1.0
nnx.einsum Add
Einsum
einsum_module_with_bias
einsum_module_with_bias_f64
einsum_module_no_bias
einsum_module_no_bias_f64
0.4.2
nnx.elu Elu elu
elu_default_dynamic
elu_default
elu_alpha
0.2.0
nnx.embed Gather token_embedding_dynamic
token_embedding_dynamic_f64
token_embedding
token_embedding_f64
positional_embedding_dynamic
positional_embedding_dynamic_f64
positional_embedding
positional_embedding_f64
0.7.0
nnx.flip_sequences GatherElements
Slice
flip_sequences_batch_major_with_lengths
flip_sequences_batch_major_with_lengths_f64
flip_sequences_batch_major_no_lengths_dynamic
flip_sequences_batch_major_no_lengths_dynamic_f64
flip_sequences_batch_major_no_lengths
flip_sequences_batch_major_no_lengths_f64
flip_sequences_time_major_with_lengths
flip_sequences_time_major_with_lengths_f64
0.12.2
nnx.gelu Gelu gelu
gelu_1
gelu_2
gelu_2_f64
gelu_3_dynamic
gelu_3_dynamic_f64
gelu_3
gelu_3_f64
gelu_4
gelu_4_f64
gelu_5_dynamic
gelu_5_dynamic_f64
gelu_5
gelu_5_f64
0.1.0
nnx.glu Mul
Sigmoid
Split
glu_last_axis
glu_last_axis_f64
glu_axis_1_dynamic
glu_axis_1_dynamic_f64
glu_axis_1
glu_axis_1_f64
0.12.2
nnx.group_norm GroupNormalization group_norm
group_norm_rank2_dynamic
group_norm_rank2
group_norm_rank4
group_norm_no_bias_dynamic
group_norm_no_bias
group_norm_no_bias_no_scale_dynamic
group_norm_no_bias_no_scale
group_norm_bias_no_scale_dynamic
group_norm_bias_no_scale
group_norm_no_scale_dynamic
group_norm_no_scale
group_norm_no_bias_scale_dynamic
group_norm_no_bias_scale
group_norm_bias_scale_dynamic
group_norm_bias_scale
0.2.0
nnx.hard_tanh Clip hard_tanh_basic_dynamic
hard_tanh_basic_dynamic_f64
hard_tanh_basic
hard_tanh_basic_f64
0.12.2
nnx.layer_norm LayerNormalization layer_norm_dynamic
layer_norm
layer_norm_no_bias_no_scale_dynamic
layer_norm_no_bias_no_scale
layer_norm_bias_no_scale_dynamic
layer_norm_bias_no_scale
layer_norm_no_bias_scale_dynamic
layer_norm_no_bias_scale
layer_norm_bias_scale_dynamic
layer_norm_bias_scale
layer_norm_multiaxis_dynamic
layer_norm_multiaxis
layer_norm_symbolic_batch_dynamic
layer_norm_symbolic_batch
layer_norm_symbolic_batch_seq10_feat3_dynamic
layer_norm_symbolic_batch_seq10_feat3
layer_norm_symbolic_batch_seq10_feat3_2_dynamic
layer_norm_symbolic_batch_seq10_feat3_2
layer_norm_negative_axis_no_div_dynamic
layer_norm_negative_axis_no_div
0.1.0
nnx.leaky_relu LeakyRelu leaky_relu
leaky_relu_default_dynamic
leaky_relu_default
leaky_relu_custom
0.2.0
nnx.linear CastLike
Concat
Gemm
Reshape
Shape
Slice
linear_symbolic_batch_dynamic
linear_symbolic_batch_dynamic_f64
linear_symbolic_batch
linear_symbolic_batch_f64
linear_high_rank_dynamic
linear_high_rank_dynamic_f64
linear_high_rank_static
linear_high_rank_static_f64
linear_no_bias_dynamic
linear_no_bias_dynamic_f64
linear_no_bias
linear_no_bias_f64
linear_high_rank_no_bias_dynamic
linear_high_rank_no_bias_dynamic_f64
linear_high_rank_no_bias
linear_high_rank_no_bias_f64
linear_merge_symbolic_dim_dynamic
0.1.0
nnx.linear_general CastLike
Concat
Gemm
Reshape
Shape
Slice
linear_general_merge_symbolic_dim_dynamic
linear_general_dynamic
linear_general
linear_general_2
linear_general_3
linear_general_4
linear_general_abstract_eval_axes
linear_general_abstract_eval_axes_pair
dynamic_batch_and_feature_dims_dynamic
0.1.0
nnx.log_sigmoid Neg
Softplus
log_sigmoid_basic_dynamic
log_sigmoid_basic_dynamic_f64
log_sigmoid_basic
log_sigmoid_basic_f64
0.12.2
nnx.log_softmax LogSoftmax log_softmax
log_softmax_f64
log_softmax_default_axis_dynamic
log_softmax_default_axis_dynamic_f64
log_softmax_default_axis
log_softmax_default_axis_f64
log_softmax_axis0
log_softmax_axis0_f64
0.2.0
nnx.lora Add
MatMul
lora_basic_dynamic
lora_basic_dynamic_f64
lora_basic
lora_basic_f64
lora_static
lora_static_f64
0.12.2
nnx.lora_linear Add
MatMul
lora_linear_basic_dynamic
lora_linear_basic_dynamic_f64
lora_linear_basic
lora_linear_basic_f64
lora_linear_high_rank_dynamic
lora_linear_high_rank_dynamic_f64
lora_linear_high_rank
lora_linear_high_rank_f64
0.12.2
nnx.max_pool GlobalMaxPool
MaxPool
max_pool
max_pool_same_padding
max_pool_basic
max_pool_same_dynamic
max_pool_same
max_pool_global_window_dynamic
max_pool_global_window
0.2.0
nnx.prelu PRelu prelu_default
prelu_custom_slope_dynamic
prelu_custom_slope
0.12.1
nnx.relu Relu relu_1d
relu_1d_f64
relu_4d_dynamic
relu_4d_dynamic_f64
relu_4d
relu_4d_f64
0.2.0
nnx.rms_norm RMSNormalization rms_norm_basic
rms_norm_use_scale_false
rms_norm_4d_dynamic_dynamic
rms_norm_4d_dynamic
rms_norm_4d_dynamic_no_scale_dynamic
rms_norm_4d_dynamic_no_scale
0.2.0
nnx.sigmoid Sigmoid sigmoid_dynamic
sigmoid_dynamic_f64
sigmoid
sigmoid_f64
0.2.0
nnx.softmax Softmax softmax_dynamic
softmax_dynamic_f64
softmax
softmax_f64
0.1.0
nnx.softplus Softplus softplus 0.1.0
nnx.tanh Tanh tanh
tanh_f64
0.1.0
random.bernoulli Bernoulli bernoulli_scalar_prob_shape
bernoulli_tensor_prob
0.12.1
random.random_bits Cast
Floor
RandomUniformLike
RandomUniform
random_bits_uint32
random_bits_uint32_f64
0.7.2
random.random_categorical Multinomial random_categorical_logits_batch
random_categorical_logits_batch_opset23
random_categorical_logits_rank3
random_categorical_logits_rank3_opset23
random_categorical_logits_symbolic_batch_dynamic
random_categorical_logits_symbolic_batch_opset23_dynamic
0.12.1
random.random_fold_in Identity random_fold_in_passthrough
random_fold_in_passthrough_f64
0.2.0
random.random_normal RandomNormalLike
RandomNormal
random_normal_f32_2x3 0.12.1
random.random_seed Cast
Concat
random_seed_basic
random_seed_basic_f64
0.2.0