| core.custom_jvp_generic (Generic passthrough for custom JVP calls) |
➖ |
custom_jvp_square ✅
custom_jvp_square_f64 ✅ |
0.7.1 |
| core.custom_vjp_generic (Generic passthrough for custom VJP calls) |
➖ |
custom_vjp_square ✅
custom_vjp_square_f64 ✅ |
0.7.1 |
| core.dim_as_value |
Cast Gather Reshape Shape |
dim_as_value_dynamic ✅
dim_as_value_dynamic_f64 ✅
dim_as_value ✅
dim_as_value_f64 ✅ |
0.5.0 |
| core.jit_inline |
➖ |
jit_identity ✅
jit_identity_f64 ✅ |
0.9.0 |
| dm_pix.depth_to_space |
DepthToSpace |
depth_to_space_nhwc ✅ |
0.12.0 |
| dm_pix.space_to_depth |
SpaceToDepth |
space_to_depth_nhwc ✅ |
0.12.1 |
| eqx.adaptive_pool |
AveragePool MaxPool |
eqx_adaptive_avg_pool2d_divisible ✅
eqx_adaptive_max_pool2d_divisible ✅
eqx_adaptive_pool_generic_avg_vmap_dynamic ✅
eqx_adaptive_pool_generic_avg_vmap ✅ |
0.12.2 |
| eqx.avg_pool |
AveragePool |
eqx_avg_pool2d_basic ✅
eqx_avg_pool2d_batched_dynamic ✅
eqx_avg_pool2d_batched ✅ |
0.12.2 |
| eqx.batch_norm |
BatchNormalization |
eqx_batch_norm_inference ✅ |
0.12.2 |
| eqx.conv |
Conv |
eqx_conv2d_nchw ✅
eqx_conv2d_batched_nchw ✅ |
0.10.0 |
| eqx.dropout |
Dropout Not |
eqx_dropout_inference_mode ✅
eqx_dropout_inference_mode_f64 ✅
eqx_dropout_training_mode ✅
eqx_dropout_training_mode_f64 ✅
eqx_dropout_dynamic_inference ✅
eqx_dropout_dynamic_inference_f64 ✅
eqx_dropout_batched_inference_dynamic ✅
eqx_dropout_batched_inference_dynamic_f64 ✅
eqx_dropout_batched_inference ✅
eqx_dropout_batched_inference_f64 ✅ |
0.8.0 |
| eqx.embedding |
Gather |
eqx_embedding_scalar_index ✅
eqx_embedding_scalar_index_f64 ✅
eqx_embedding_batched_vmap_dynamic ✅
eqx_embedding_batched_vmap_dynamic_f64 ✅
eqx_embedding_batched_vmap ✅
eqx_embedding_batched_vmap_f64 ✅ |
0.12.2 |
| eqx.group_norm |
GroupNormalization |
eqx_group_norm_rank3 ✅
eqx_group_norm_no_affine ✅ |
0.12.2 |
| eqx.gru_cell |
Add Gemm Mul Sigmoid Tanh |
eqx_gru_cell_basic ✅ |
0.12.2 |
| eqx.identity |
Identity |
eqx_identity_static ✅
eqx_identity_static_f64 ✅
eqx_identity_symbolic_batch_dynamic ✅
eqx_identity_symbolic_batch_dynamic_f64 ✅
eqx_identity_symbolic_batch ✅
eqx_identity_symbolic_batch_f64 ✅ |
0.8.0 |
| eqx.lambda |
Relu Sigmoid Tanh |
eqx_lambda_relu ✅
eqx_lambda_relu_f64 ✅
eqx_lambda_sigmoid ✅
eqx_lambda_sigmoid_f64 ✅ |
0.12.2 |
| eqx.layer_norm |
LayerNormalization |
layer_norm ✅
layer_norm_f64 ✅
layer_norm_multiaxis ✅
layer_norm_multiaxis_f64 ✅
batched_layer_norm_dynamic ✅
batched_layer_norm_dynamic_f64 ✅
batched_layer_norm ✅
batched_layer_norm_f64 ✅
layer_norm_no_bias_no_scale ✅
layer_norm_no_bias_no_scale_f64 ✅ |
0.8.0 |
| eqx.linear |
Gemm Reshape |
eqx_linear_symbolic_batch_dynamic ✅
eqx_linear_symbolic_batch_dynamic_f64 ✅
eqx_linear_symbolic_batch ✅
eqx_linear_symbolic_batch_f64 ✅
eqx_linear_no_bias_symbolic_batch_dynamic ✅
eqx_linear_no_bias_symbolic_batch_dynamic_f64 ✅
eqx_linear_no_bias_symbolic_batch ✅
eqx_linear_no_bias_symbolic_batch_f64 ✅
eqx_linear_no_bias_vector ✅
eqx_linear_no_bias_vector_f64 ✅
eqx_linear_high_rank ✅
eqx_linear_high_rank_f64 ✅
eqx_linear_vector ✅
eqx_linear_vector_f64 ✅ |
0.8.0 |
| eqx.lstm_cell |
Add Gemm Mul Sigmoid Tanh |
eqx_lstm_cell_basic ✅ |
0.12.2 |
| eqx.max_pool |
MaxPool |
eqx_max_pool2d_basic ✅
eqx_max_pool2d_basic_f64 ✅
eqx_max_pool2d_batched_dynamic ✅
eqx_max_pool2d_batched_dynamic_f64 ✅
eqx_max_pool2d_batched ✅
eqx_max_pool2d_batched_f64 ✅ |
0.12.2 |
| eqx.multihead_attention |
Attention Gemm MatMul Softmax |
eqx_multihead_attention ✅
eqx_multihead_attention_opset23 ✅
eqx_multihead_attention_core_dynamic ✅
eqx_multihead_attention_core_opset23_dynamic ✅
eqx_multihead_attention_core_static ✅
eqx_multihead_attention_core_static_opset23 ✅ |
0.10.0 |
| eqx.pool |
MaxPool |
eqx_pool_max2d_basic ✅
eqx_pool_max2d_basic_f64 ✅
eqx_pool_max2d_batched_dynamic ✅
eqx_pool_max2d_batched_dynamic_f64 ✅
eqx_pool_max2d_batched ✅
eqx_pool_max2d_batched_f64 ✅ |
0.12.2 |
| eqx.prelu |
PRelu |
eqx_prelu_default_dynamic ✅
eqx_prelu_default ✅
eqx_prelu_channelwise ✅ |
0.12.2 |
| eqx.rms_norm |
RMSNormalization |
rms_norm_vector ✅
rms_norm_vector_f64 ✅
rms_norm_no_affine ✅
rms_norm_no_affine_f64 ✅ |
0.10.2 |
| eqx.rotary_positional_embedding |
Add Concat Mul RotaryEmbedding |
eqx_rotary_positional_embedding ✅
eqx_rotary_positional_embedding_opset23 ✅
eqx_rotary_positional_embedding_heads ✅ |
0.10.0 |
| eqx.sequential |
Identity Relu Sigmoid Tanh |
eqx_sequential_relu_tanh_dynamic ✅
eqx_sequential_relu_tanh_dynamic_f64 ✅
eqx_sequential_relu_tanh ✅
eqx_sequential_relu_tanh_f64 ✅
eqx_sequential_identity_sigmoid ✅
eqx_sequential_identity_sigmoid_f64 ✅ |
0.12.2 |
| eqx.spectral_norm |
Div MatMul |
eqx_spectral_norm_linear_inference ✅ |
0.12.2 |
| eqx.weight_norm |
Div Mul |
eqx_weight_norm_linear_dynamic ✅
eqx_weight_norm_linear ✅
eqx_weight_norm_conv2d ✅ |
0.12.2 |
| jax_image.resize |
Resize Upsample |
resize_linear ✅
resize_nearest ✅
resize_nearest_antialias_ignored ✅
resize_nearest_opset9_upsample ✅
resize_linear_opset9_upsample ✅
resize_nearest_rank3_opset9_upsample ✅ |
0.10.0 |
| jnp.abs |
Abs |
jnp_abs_basic ✅
jnp_abs_basic_f64 ✅
jnp_abs_int ✅
jnp_abs_int_f64 ✅
abs_vmap_batching ✅
abs_vmap_batching_f64 ✅
abs_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.acos |
Acos |
jnp_acos_basic ✅
jnp_acos_int_promote ✅
acos_vmap_batching ✅
acos_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.acosh |
Acosh |
jnp_acosh_basic ✅
jnp_acosh_int_promote ✅
acosh_vmap_batching ✅
acosh_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.add |
Add |
add ✅
add_f64 ✅
jnp_add_vector ✅
jnp_add_vector_f64 ✅
jnp_add_broadcast ✅
jnp_add_broadcast_f64 ✅
add_vmap_batching ✅
add_vmap_batching_f64 ✅
add_grad_issue_batch_diff_rules ✅ |
0.8.0 |
| jnp.all |
ReduceMin |
jnp_all_basic ✅
jnp_all_basic_f64 ✅
jnp_all_axis_keepdims ✅
jnp_all_axis_keepdims_f64 ✅
all_vmap_batching ✅
all_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.amax |
ReduceMax |
jnp_amax_basic ✅
jnp_amax_basic_f64 ✅
jnp_amax_axis ✅
jnp_amax_axis_f64 ✅
jnp_amax_keepdims ✅
jnp_amax_keepdims_f64 ✅
amax_vmap_batching ✅
amax_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.amin |
ReduceMin |
jnp_amin_basic ✅
jnp_amin_basic_f64 ✅
jnp_amin_axis ✅
jnp_amin_axis_f64 ✅
jnp_amin_keepdims ✅
jnp_amin_keepdims_f64 ✅
amin_vmap_batching ✅
amin_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.any |
ReduceMax |
jnp_any_basic ✅
jnp_any_basic_f64 ✅
jnp_any_axis_keepdims ✅
jnp_any_axis_keepdims_f64 ✅
any_vmap_batching ✅
any_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.arange |
Range |
arange_data_dependent_indices ✅
arange_stop_only_concrete_input_val ✅
arange_stop_only_concrete_input_val_f64 ✅
arange_start_stop_concrete_input_val ✅
arange_start_stop_concrete_input_val_f64 ✅
arange_start_stop_step_concrete_input_val ✅
arange_start_stop_step_concrete_input_val_f64 ✅
arange_float_concrete_input_val ✅
arange_float_concrete_input_val_f64 ✅
arange_static_stop_only_int ✅
arange_static_stop_only_int_f64 ✅
arange_static_stop_only_float ✅
arange_static_stop_only_float_f64 ✅
arange_static_start_stop_int ✅
arange_static_start_stop_int_f64 ✅
arange_static_start_stop_step_int ✅
arange_static_start_stop_step_int_f64 ✅
arange_static_empty_result_pos_step ✅
arange_static_empty_result_pos_step_f64 ✅
arange_static_empty_result_neg_step ✅
arange_static_empty_result_neg_step_f64 ✅
arange_static_negative_step ✅
arange_static_negative_step_f64 ✅
arange_static_float_step_explicit_dtype ✅
arange_static_float_step_explicit_dtype_f64 ✅
arange_static_float_step_inferred_dtype ✅
arange_static_float_step_inferred_dtype_f64 ✅
arange_static_stop_zero ✅
arange_static_stop_zero_f64 ✅
arange_static_start_equals_stop ✅
arange_static_start_equals_stop_f64 ✅
arange_static_large_numbers_int ✅
arange_static_large_numbers_int_f64 ✅ |
0.5.2 |
| jnp.argmax |
ArgMax |
jnp_argmax_axis1 ✅
jnp_argmax_axis0_bool ✅ |
0.12.2 |
| jnp.argmin |
ArgMin |
jnp_argmin_axis1 ✅
jnp_argmin_axis0 ✅ |
0.12.2 |
| jnp.asin |
Asin |
jnp_asin_basic ✅
jnp_asin_int_promote ✅
asin_vmap_batching ✅
asin_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.asinh |
Asinh |
jnp_asinh_basic ✅
jnp_asinh_int_promote ✅
asinh_vmap_batching ✅
asinh_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.atan |
Atan |
jnp_atan_basic ✅
jnp_atan_int_promote ✅
atan_vmap_batching ✅
atan_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.atan2 |
Atan Div Where |
jnp_atan2_quadrants ✅
jnp_atan2_broadcast ✅
atan2_vmap_batching ✅
atan2_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.atanh |
Atanh |
jnp_atanh_basic ✅
jnp_atanh_int_promote ✅
atanh_vmap_batching ✅
atanh_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.bartlett |
Constant |
jnp_bartlett_basic ✅
jnp_bartlett_basic_f64 ✅
jnp_bartlett_empty ✅
jnp_bartlett_empty_f64 ✅ |
0.12.1 |
| jnp.bitwise_and |
And BitwiseAnd |
jnp_bitwise_and_int ✅
jnp_bitwise_and_int_f64 ✅
jnp_bitwise_and_bool ✅
jnp_bitwise_and_bool_f64 ✅
jnp_bitwise_and_mixed_dtype ✅
jnp_bitwise_and_mixed_dtype_f64 ✅
jnp_bitwise_and_broadcast ✅
jnp_bitwise_and_broadcast_f64 ✅
bitwise_and_vmap_batching ✅
bitwise_and_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.bitwise_left_shift |
BitShift |
jnp_bitwise_left_shift_basic ✅
jnp_bitwise_left_shift_basic_f64 ✅
jnp_bitwise_left_shift_bool_bool ✅
jnp_bitwise_left_shift_bool_bool_f64 ✅
bitwise_left_shift_vmap_batching ✅
bitwise_left_shift_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.bitwise_not |
BitwiseNot Not |
jnp_bitwise_not_bool ✅
jnp_bitwise_not_bool_f64 ✅
jnp_bitwise_not_int ✅
jnp_bitwise_not_int_f64 ✅
bitwise_not_vmap_batching ✅
bitwise_not_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.bitwise_or |
BitwiseOr Or |
jnp_bitwise_or_int ✅
jnp_bitwise_or_int_f64 ✅
jnp_bitwise_or_bool ✅
jnp_bitwise_or_bool_f64 ✅
jnp_bitwise_or_mixed_dtype ✅
jnp_bitwise_or_mixed_dtype_f64 ✅
jnp_bitwise_or_broadcast ✅
jnp_bitwise_or_broadcast_f64 ✅
bitwise_or_vmap_batching ✅
bitwise_or_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.bitwise_right_shift |
BitShift |
jnp_bitwise_right_shift_signed ✅
jnp_bitwise_right_shift_signed_f64 ✅
jnp_bitwise_right_shift_bool_bool ✅
jnp_bitwise_right_shift_bool_bool_f64 ✅
bitwise_right_shift_vmap_batching ✅
bitwise_right_shift_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.bitwise_xor |
BitwiseXor Xor |
jnp_bitwise_xor_int ✅
jnp_bitwise_xor_int_f64 ✅
jnp_bitwise_xor_bool ✅
jnp_bitwise_xor_bool_f64 ✅
jnp_bitwise_xor_mixed_dtype ✅
jnp_bitwise_xor_mixed_dtype_f64 ✅
jnp_bitwise_xor_broadcast ✅
jnp_bitwise_xor_broadcast_f64 ✅
bitwise_xor_vmap_batching ✅
bitwise_xor_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.blackman |
BlackmanWindow |
jnp_blackman_basic ✅ |
0.12.1 |
| jnp.block |
Concat |
jnp_block_basic ✅ |
0.12.1 |
| jnp.ceil |
Ceil Identity |
jnp_ceil_basic ✅
jnp_ceil_basic_f64 ✅
jnp_ceil_int_identity ✅
jnp_ceil_int_identity_f64 ✅
ceil_vmap_batching ✅
ceil_vmap_batching_f64 ✅
ceil_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.choose |
Concat Gather |
jnp_choose_wrap ✅ |
0.12.1 |
| jnp.clip |
Clip Max Min |
clip_i32_scalar_bounds ✅
clip_i32_scalar_bounds_f64 ✅
clip_f32_scalar_bounds_no_upcast_f64_mode ✅
clip_only_upper ✅
clip_only_upper_f64 ✅
clip_only_lower ✅
clip_only_lower_f64 ✅
clip_broadcast_bounds ✅
clip_grad_issue_batch_diff_rules ✅
clip_keyword_min_alias ✅
clip_keyword_min_alias_f64 ✅ |
0.8.0 |
| jnp.compress |
Compress |
jnp_compress_axis0 ✅
jnp_compress_axis0_f64 ✅
jnp_compress_axis1 ✅
jnp_compress_axis1_f64 ✅
compress_jvp_issue_batch_diff_rules ✅ |
0.12.1 |
| jnp.concatenate |
Concat |
concatenate_basic ✅
concatenate_basic_f64 ✅
concatenate_mixed_dtypes ✅
concatenate_mixed_dtypes_f64 ✅
concatenate_with_explicit_dtype ✅
concatenate_with_explicit_dtype_casts_inputs ✅
concatenate_abstract_middle_dim_dynamic ✅
concatenate_abstract_middle_dim_dynamic_f64 ✅
concatenate_abstract_middle_dim ✅
concatenate_abstract_middle_dim_f64 ✅
concatenate_tile_and_symbolic_dynamic ✅
concatenate_tile_and_symbolic_dynamic_f64 ✅
concatenate_tile_and_symbolic ✅
concatenate_tile_and_symbolic_f64 ✅
concatenate_grad_issue191 ✅ |
0.8.0 |
| jnp.conj |
Identity |
jnp_conj_real ✅
jnp_conj_real_f64 ✅
jnp_conj_complex64 ✅
jnp_conj_complex64_f64 ✅
conj_vmap_batching ✅
conj_vmap_batching_f64 ✅
conj_grad_issue_batch_diff_rules ✅ |
0.10.1 |
| jnp.copysign |
Abs Less Neg Where |
jnp_copysign_basic ✅
jnp_copysign_basic_f64 ✅
jnp_copysign_int_promote ✅
jnp_copysign_int_promote_f64 ✅
jnp_copysign_broadcast ✅
jnp_copysign_broadcast_f64 ✅
copysign_vmap_batching ✅
copysign_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.corrcoef |
Div MatMul ReduceMean |
jnp_corrcoef_basic ✅ |
0.12.1 |
| jnp.cos |
Cos Sin |
jnp_cos_basic ✅
jnp_cos_basic_f64 ✅
jnp_cos_int_promote ✅
jnp_cos_int_promote_f64 ✅
jnp_cos_f64_via_sin ✅
cos_vmap_batching ✅
cos_vmap_batching_f64 ✅
cos_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.cosh |
Cosh |
jnp_cosh_basic ✅
jnp_cosh_basic_f64 ✅
jnp_cosh_int_promote ✅
jnp_cosh_int_promote_f64 ✅
cosh_vmap_batching ✅
cosh_vmap_batching_f64 ✅
cosh_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.cov |
MatMul ReduceMean |
jnp_cov_basic ✅
jnp_cov_basic_f64 ✅ |
0.12.1 |
| jnp.cross |
Concat Mul Sub |
jnp_cross_vectors ✅
jnp_cross_vectors_f64 ✅ |
0.12.1 |
| jnp.cumprod |
CumProd |
jnp_cumprod_axis1 ✅
jnp_cumprod_axis1_f64 ✅
jnp_cumprod_axis_none_flatten ✅
jnp_cumprod_axis_none_flatten_f64 ✅
jnp_cumprod_dtype_cast ✅
jnp_cumprod_dtype_cast_f64 ✅ |
0.12.1 |
| jnp.cumsum |
CumSum |
jnp_cumsum_axis1 ✅
jnp_cumsum_axis1_f64 ✅
jnp_cumsum_reverse_dtype ✅
cumsum_axis2_i32 ✅
cumsum_axis2_i32_f64 ✅
cumsum_axis2_reverse_i32 ✅
cumsum_axis2_reverse_i32_f64 ✅
cumsum_vmap_batching ✅
cumsum_vmap_batching_f64 ✅ |
0.8.0 |
| jnp.delete |
Concat Slice |
jnp_delete_axis0_index1 ✅
jnp_delete_axis0_index1_f64 ✅ |
0.12.1 |
| jnp.diag |
Gather Mul Reshape |
jnp_diag_from_vector ✅
jnp_diag_from_vector_f64 ✅
jnp_diag_from_matrix ✅
jnp_diag_from_matrix_f64 ✅
jnp_diag_from_matrix_k1 ✅
jnp_diag_from_matrix_k1_f64 ✅ |
0.12.2 |
| jnp.diagonal |
Gather Reshape |
jnp_diagonal_basic ✅
jnp_diagonal_basic_f64 ✅
jnp_diagonal_offset_neg1 ✅
jnp_diagonal_offset_neg1_f64 ✅ |
0.12.2 |
| jnp.divide |
Div |
jnp_divide_basic ✅
jnp_divide_basic_f64 ✅
jnp_divide_int_promote ✅
jnp_divide_int_promote_f64 ✅
jnp_divide_broadcast ✅
jnp_divide_broadcast_f64 ✅
divide_vmap_batching ✅
divide_vmap_batching_f64 ✅
divide_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.dot |
MatMul |
jnp_dot_vector_vector ✅
jnp_dot_vector_vector_f64 ✅
jnp_dot_matrix_vector ✅
jnp_dot_matrix_vector_f64 ✅
jnp_dot_matrix_matrix ✅
jnp_dot_matrix_matrix_f64 ✅ |
0.12.2 |
| jnp.einsum |
Einsum |
einsum_vector_dot ✅
einsum_vector_dot_f64 ✅
einsum_matrix_vector ✅
einsum_matrix_vector_f64 ✅
einsum_matrix_matrix_dynamic ✅
einsum_matrix_matrix_dynamic_f64 ✅
einsum_matrix_matrix ✅
einsum_matrix_matrix_f64 ✅
einsum_transpose ✅
einsum_transpose_f64 ✅
einsum_batch_transpose_dynamic ✅
einsum_batch_transpose_dynamic_f64 ✅
einsum_batch_transpose ✅
einsum_batch_transpose_f64 ✅
einsum_diag ✅
einsum_diag_f64 ✅
einsum_sum_reduce ✅
einsum_sum_reduce_f64 ✅
einsum_multi_operand ✅
einsum_multi_operand_f64 ✅
einsum_attention_logits_orig_dynamic ✅
einsum_attention_logits_orig_dynamic_f64 ✅
einsum_attention_logits_orig ✅
einsum_attention_logits_orig_f64 ✅
einsum_attention_output_orig_dynamic ✅
einsum_attention_output_orig_dynamic_f64 ✅
einsum_attention_output_orig ✅
einsum_attention_output_orig_f64 ✅
einsum_attention_logits_batched_dynamic ✅
einsum_attention_logits_batched_dynamic_f64 ✅
einsum_attention_logits_batched ✅
einsum_attention_logits_batched_f64 ✅
einsum_attention_output_batched_dynamic ✅
einsum_attention_output_batched_dynamic_f64 ✅
einsum_attention_output_batched ✅
einsum_attention_output_batched_f64 ✅
einsum_ellipsis_rank_mismatch ✅
einsum_ellipsis_rank_mismatch_f64 ✅
einsum_attention_logits_batched_rank_mismatch ✅
einsum_attention_logits_batched_rank_mismatch_f64 ✅
einsum_grad_issue_batch_diff_rules ✅ |
0.1.0 |
| jnp.equal |
Equal |
jnp_equal_basic ✅
jnp_equal_basic_f64 ✅
jnp_equal_broadcast ✅
jnp_equal_broadcast_f64 ✅
jnp_equal_mixed_dtype ✅
jnp_equal_mixed_dtype_f64 ✅ |
0.12.2 |
| jnp.exp |
Exp |
jnp_exp_basic ✅
jnp_exp_basic_f64 ✅
jnp_exp_int_promote ✅
jnp_exp_int_promote_f64 ✅
exp_vmap_batching ✅
exp_vmap_batching_f64 ✅
exp_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.exp2 |
Pow |
jnp_exp2_basic ✅
jnp_exp2_basic_f64 ✅
jnp_exp2_int_promote ✅
jnp_exp2_int_promote_f64 ✅
exp2_vmap_batching ✅
exp2_vmap_batching_f64 ✅
exp2_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.expm1 |
Exp Sub |
jnp_expm1_basic ✅
jnp_expm1_basic_f64 ✅
jnp_expm1_int_promote ✅
jnp_expm1_int_promote_f64 ✅
expm1_vmap_batching ✅
expm1_vmap_batching_f64 ✅
expm1_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.eye |
EyeLike |
jnp_eye_square ✅
jnp_eye_rect_k1 ✅ |
0.12.1 |
| jnp.fabs |
Abs |
jnp_fabs_basic ✅
jnp_fabs_basic_f64 ✅
jnp_fabs_int_promote ✅
jnp_fabs_int_promote_f64 ✅
fabs_vmap_batching ✅
fabs_vmap_batching_f64 ✅
fabs_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.fft |
DFT |
jnp_fft_complex64 ✅
jnp_fft_complex128 ✅ |
0.10.1 |
| jnp.fft_fftfreq |
Div Range |
jnp_fft_fftfreq_n8 ✅
jnp_fft_fftfreq_n8_f64 ✅ |
0.12.1 |
| jnp.fft_fftshift |
Concat Slice |
jnp_fft_fftshift_1d ✅
jnp_fft_fftshift_1d_f64 ✅ |
0.12.1 |
| jnp.fft_ifftshift |
Concat Slice |
jnp_fft_ifftshift_1d ✅
jnp_fft_ifftshift_1d_f64 ✅ |
0.12.1 |
| jnp.fft_irfft |
DFT |
jnp_irfft_complex64 ✅ |
0.12.1 |
| jnp.fft_rfftfreq |
Div Range |
jnp_fft_rfftfreq_n8 ✅
jnp_fft_rfftfreq_n8_f64 ✅ |
0.12.1 |
| jnp.fill_diagonal |
ScatterND |
jnp_fill_diagonal_basic ✅
jnp_fill_diagonal_basic_f64 ✅ |
0.12.1 |
| jnp.floor |
Floor Identity |
jnp_floor_basic ✅
jnp_floor_basic_f64 ✅
jnp_floor_int_identity ✅
jnp_floor_int_identity_f64 ✅
floor_vmap_batching ✅
floor_vmap_batching_f64 ✅
floor_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.floor_divide |
Div Floor Where |
jnp_floor_divide_float ✅
jnp_floor_divide_float_f64 ✅
jnp_floor_divide_int_neg ✅
jnp_floor_divide_int_neg_f64 ✅
jnp_floor_divide_broadcast ✅
jnp_floor_divide_broadcast_f64 ✅
floor_divide_vmap_batching ✅
floor_divide_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.fmod |
Div Mod Sub |
jnp_fmod_float ✅
jnp_fmod_float_f64 ✅
jnp_fmod_int ✅
jnp_fmod_int_f64 ✅
jnp_fmod_broadcast ✅
jnp_fmod_broadcast_f64 ✅
fmod_vmap_batching ✅
fmod_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.full |
ConstantOfShape |
jnp_full_const_scalar ✅
jnp_full_const_scalar_f64 ✅
jnp_full_expand_scalar_input ✅
jnp_full_expand_scalar_input_f64 ✅
jnp_full_expand_row_broadcast ✅
jnp_full_expand_row_broadcast_f64 ✅ |
0.12.1 |
| jnp.gcd |
Equal Loop |
jnp_gcd_basic ✅
jnp_gcd_basic_f64 ✅ |
0.12.1 |
| jnp.geomspace |
Exp Log Range |
jnp_geomspace_scalar ✅
jnp_geomspace_scalar_f64 ✅ |
0.12.1 |
| jnp.gradient |
Div Slice Sub |
jnp_gradient_1d ✅
jnp_gradient_1d_f64 ✅ |
0.12.1 |
| jnp.greater |
Greater |
jnp_greater_basic ✅
jnp_greater_basic_f64 ✅
jnp_greater_broadcast ✅
jnp_greater_broadcast_f64 ✅
jnp_greater_mixed_dtype ✅
jnp_greater_mixed_dtype_f64 ✅
greater_vmap_batching ✅
greater_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.greater_equal |
GreaterOrEqual |
jnp_greater_equal_basic ✅
jnp_greater_equal_basic_f64 ✅
jnp_greater_equal_broadcast ✅
jnp_greater_equal_broadcast_f64 ✅
jnp_greater_equal_mixed_dtype ✅
jnp_greater_equal_mixed_dtype_f64 ✅
greater_equal_vmap_batching ✅
greater_equal_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.hamming |
HammingWindow |
jnp_hamming_basic ✅ |
0.12.1 |
| jnp.hanning |
HannWindow |
jnp_hanning_basic ✅ |
0.12.1 |
| jnp.histogram_bin_edges |
Range ReduceMax ReduceMin |
jnp_histogram_bin_edges_basic ✅
jnp_histogram_bin_edges_basic_f64 ✅ |
0.12.1 |
| jnp.i0 |
Add Exp Mul |
jnp_i0_basic ✅
jnp_i0_basic_f64 ✅ |
0.12.1 |
| jnp.ifft |
DFT |
jnp_ifft_complex64 ✅ |
0.10.1 |
| jnp.invert |
BitwiseNot Not |
jnp_invert_bool ✅
jnp_invert_bool_f64 ✅
jnp_invert_int ✅
jnp_invert_int_f64 ✅
invert_vmap_batching ✅
invert_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.isfinite |
IsInf IsNaN Not Or |
jnp_isfinite_basic ✅
jnp_isfinite_basic_f64 ✅ |
0.12.2 |
| jnp.lcm |
Div Loop Mul |
jnp_lcm_basic ✅
jnp_lcm_basic_f64 ✅ |
0.12.1 |
| jnp.left_shift |
BitShift |
jnp_left_shift_int ✅
jnp_left_shift_int_f64 ✅
jnp_left_shift_unsigned ✅
jnp_left_shift_unsigned_f64 ✅
jnp_left_shift_mixed_bool_rhs ✅
jnp_left_shift_mixed_bool_rhs_f64 ✅
jnp_left_shift_broadcast ✅
jnp_left_shift_broadcast_f64 ✅
left_shift_vmap_batching ✅
left_shift_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.less |
Less |
jnp_less_basic ✅
jnp_less_basic_f64 ✅
jnp_less_broadcast ✅
jnp_less_broadcast_f64 ✅
jnp_less_mixed_dtype ✅
jnp_less_mixed_dtype_f64 ✅
less_vmap_batching ✅
less_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.less_equal |
LessOrEqual |
jnp_less_equal_basic ✅
jnp_less_equal_basic_f64 ✅
jnp_less_equal_broadcast ✅
jnp_less_equal_broadcast_f64 ✅
jnp_less_equal_mixed_dtype ✅
jnp_less_equal_mixed_dtype_f64 ✅
less_equal_vmap_batching ✅
less_equal_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.linalg_cholesky |
Cholesky |
jnp_linalg_cholesky_basic ✅ |
0.12.1 |
| jnp.linalg_cond |
Div SVD |
jnp_linalg_cond_basic ✅ |
0.12.1 |
| jnp.linalg_det |
Det |
linalg_det_2x2 ✅
linalg_det_batched ✅ |
0.12.1 |
| jnp.linalg_eig |
Eig |
jnp_linalg_eig_basic ✅ |
0.12.1 |
| jnp.linalg_eigh |
Eigh |
jnp_linalg_eigh_basic ✅ |
0.12.1 |
| jnp.linalg_eigvals |
Eig |
jnp_linalg_eigvals_basic ✅ |
0.12.1 |
| jnp.linalg_eigvalsh |
Eigh |
jnp_linalg_eigvalsh_basic ✅ |
0.12.1 |
| jnp.linalg_lstsq |
MatMul SVD |
jnp_linalg_lstsq_basic ✅ |
0.12.1 |
| jnp.linalg_matrix_norm |
Mul ReduceSum Sqrt |
jnp_linalg_matrix_norm_basic ✅ |
0.12.1 |
| jnp.linalg_matrix_power |
MatMul |
jnp_linalg_matrix_power_basic ✅ |
0.12.1 |
| jnp.linalg_matrix_rank |
ReduceSum SVD |
jnp_linalg_matrix_rank_basic ✅ |
0.12.1 |
| jnp.linalg_multi_dot |
MatMul |
jnp_linalg_multi_dot_basic ✅ |
0.12.1 |
| jnp.linalg_norm |
GlobalLpPool Transpose |
linalg_norm_global_fro ✅
linalg_norm_global_default ✅
linalg_norm_global_ord1 ✅
linalg_norm_global_ord1_nokeepdims ✅ |
0.12.1 |
| jnp.linalg_pinv |
MatMul SVD |
jnp_linalg_pinv_basic ✅ |
0.12.1 |
| jnp.linalg_qr |
QR |
jnp_linalg_qr_basic ✅ |
0.12.1 |
| jnp.linalg_svd |
SVD |
jnp_linalg_svd_basic ✅ |
0.12.1 |
| jnp.linalg_svdvals |
SVD |
jnp_linalg_svdvals_basic ✅ |
0.12.1 |
| jnp.linalg_tensordot |
MatMul Transpose |
jnp_linalg_tensordot_basic ✅ |
0.12.1 |
| jnp.linalg_trace |
Gather ReduceSum |
jnp_linalg_trace_basic ✅ |
0.12.1 |
| jnp.linalg_vecdot |
Mul ReduceSum |
jnp_linalg_vecdot_basic ✅ |
0.12.1 |
| jnp.linalg_vector_norm |
Mul ReduceSum Sqrt |
jnp_linalg_vector_norm_basic ✅ |
0.12.1 |
| jnp.linspace |
Add Range |
linspace_static_basic ✅
linspace_static_basic_f64 ✅
linspace_static_endpoint_false ✅
linspace_static_endpoint_false_f64 ✅
linspace_static_int_inputs_default_dtype ✅
linspace_static_int_inputs_default_dtype_f64 ✅
linspace_basic_f32 ✅
linspace_basic_f32_f64 ✅
linspace_endpoint_false_i32 ✅
linspace_endpoint_false_i32_f64 ✅
linspace_num_zero ✅
linspace_num_zero_f64 ✅
linspace_num_one ✅
linspace_num_one_f64 ✅
linspace_static_num_0 ✅
linspace_static_num_0_f64 ✅
linspace_static_num_1 ✅
linspace_static_num_1_f64 ✅
linspace_vmap_batching ✅
linspace_vmap_batching_f64 ✅ |
0.5.2 |
| jnp.log |
Log ReduceLogSumExp ReduceLogSum |
jnp_log_basic ✅
jnp_log_basic_f64 ✅
jnp_log_int_promote ✅
jnp_log_int_promote_f64 ✅
log_vmap_batching ✅
log_vmap_batching_f64 ✅
log_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.logspace |
Pow Range |
jnp_logspace_basic ✅ |
0.12.1 |
| jnp.matmul |
MatMul |
matmul_1d ✅
matmul_1d_f64 ✅
matmul_1d_2d ✅
matmul_1d_2d_f64 ✅
matmul_2d ✅
matmul_2d_f64 ✅
matmul_2d_1d ✅
matmul_2d_1d_f64 ✅
matmul_3d ✅
matmul_3d_f64 ✅
matmul_dynamic_dynamic ✅
matmul_dynamic_dynamic_f64 ✅
matmul_dynamic ✅
matmul_dynamic_f64 ✅
matmul_dynamic_a_dynamic ✅
matmul_dynamic_a_dynamic_f64 ✅
matmul_dynamic_a ✅
matmul_dynamic_a_f64 ✅
matmul_complex64 ✅
matmul_complex64_f64 ✅
matmul_vmap_batching ✅
matmul_vmap_batching_f64 ✅
matmul_grad_issue_batch_diff_rules ✅ |
0.1.0 |
| jnp.matvec |
MatMul |
jnp_matvec_basic ✅ |
0.12.1 |
| jnp.max |
ReduceMax |
jnp_max_basic ✅
jnp_max_basic_f64 ✅
jnp_max_axis ✅
jnp_max_axis_f64 ✅
jnp_max_keepdims ✅
jnp_max_keepdims_f64 ✅
max_vmap_batching ✅
max_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.maximum |
Max |
jnp_maximum_basic ✅
jnp_maximum_basic_f64 ✅
jnp_maximum_broadcast_scalar ✅
jnp_maximum_broadcast_scalar_f64 ✅ |
0.12.2 |
| jnp.mean |
ReduceMean |
basic_mean ✅
basic_mean_f64 ✅
mean_with_axis ✅
mean_with_axis_f64 ✅
mean_with_keepdims ✅
mean_with_keepdims_f64 ✅
jnp_mean_basic ✅
jnp_mean_basic_f64 ✅
jnp_mean_axis ✅
jnp_mean_axis_f64 ✅
mean_grad_issue_batch_diff_rules ✅ |
0.12.0 |
| jnp.median |
ReduceMean TopK |
jnp_median_basic ✅ |
0.12.1 |
| jnp.min |
ReduceMin |
jnp_min_basic ✅
jnp_min_basic_f64 ✅
jnp_min_axis ✅
jnp_min_axis_f64 ✅
jnp_min_keepdims ✅
jnp_min_keepdims_f64 ✅
min_vmap_batching ✅
min_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.minimum |
Min |
jnp_minimum_basic ✅
jnp_minimum_basic_f64 ✅
jnp_minimum_broadcast_scalar ✅
jnp_minimum_broadcast_scalar_f64 ✅ |
0.12.2 |
| jnp.moveaxis |
Transpose |
jnp_moveaxis_2d ✅
jnp_moveaxis_2d_f64 ✅
jnp_moveaxis_3d_tuple ✅
jnp_moveaxis_3d_tuple_f64 ✅
jnp_moveaxis_vmap_batching ✅
jnp_moveaxis_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.nanargmax |
ArgMax IsNaN Where |
jnp_nanargmax_basic ✅ |
0.12.1 |
| jnp.nanargmin |
ArgMin IsNaN Where |
jnp_nanargmin_basic ✅ |
0.12.1 |
| jnp.nancumsum |
CumSum IsNaN Where |
jnp_nancumsum_basic ✅ |
0.12.1 |
| jnp.nanmax |
IsNaN ReduceMax Where |
jnp_nanmax_basic ✅ |
0.12.1 |
| jnp.nanmean |
IsNaN ReduceSum Where |
jnp_nanmean_basic ✅ |
0.12.1 |
| jnp.nanmedian |
TopK Where |
jnp_nanmedian_basic ✅ |
0.12.1 |
| jnp.nanmin |
IsNaN ReduceMin Where |
jnp_nanmin_basic ✅ |
0.12.1 |
| jnp.nanpercentile |
TopK Where |
jnp_nanpercentile_basic ✅ |
0.12.1 |
| jnp.nanprod |
IsNaN ReduceProd Where |
jnp_nanprod_basic ✅ |
0.12.1 |
| jnp.nanquantile |
TopK Where |
jnp_nanquantile_basic ✅ |
0.12.1 |
| jnp.nanstd |
ReduceSum Sqrt |
jnp_nanstd_basic ✅ |
0.12.1 |
| jnp.nansum |
IsNaN ReduceSum Where |
jnp_nansum_basic ✅ |
0.12.1 |
| jnp.nanvar |
Mul ReduceSum |
jnp_nanvar_basic ✅ |
0.12.1 |
| jnp.ndarray_at |
ScatterND |
jnp_ndarray_at_set_basic ✅ |
0.12.1 |
| jnp.ones |
ConstantOfShape |
jnp_ones_2x3 ✅
jnp_ones_2x3_f64 ✅
jnp_ones_bool ✅
jnp_ones_bool_f64 ✅ |
0.12.2 |
| jnp.outer |
Mul |
outer_vector ✅
outer_vector_f64 ✅
outer ✅
outer_f64 ✅
outer_vmap_batching ✅
outer_vmap_batching_f64 ✅
outer_grad_issue_batch_diff_rules ✅ |
0.10.0 |
| jnp.pad |
Pad |
jnp_pad_constant_1d ✅
jnp_pad_constant_1d_f64 ✅
jnp_pad_constant_2d_tuple ✅
jnp_pad_constant_2d_tuple_f64 ✅
jnp_pad_constant_vmap_batching ✅
jnp_pad_constant_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.partition |
Concat TopK |
jnp_partition_basic ✅ |
0.12.1 |
| jnp.percentile |
ReduceMean TopK |
jnp_percentile_basic ✅ |
0.12.1 |
| jnp.polyadd |
Add Pad |
jnp_polyadd_basic ✅ |
0.12.1 |
| jnp.polyder |
Mul Slice |
jnp_polyder_basic ✅ |
0.12.1 |
| jnp.polydiv |
Div Sub |
jnp_polydiv_basic ✅ |
0.12.1 |
| jnp.polyint |
Concat Div |
jnp_polyint_basic ✅ |
0.12.1 |
| jnp.polysub |
Pad Sub |
jnp_polysub_basic ✅ |
0.12.1 |
| jnp.pow |
Pow |
jnp_pow_vector ✅
jnp_pow_vector_f64 ✅
pow_jnp_pow ✅
pow_jnp_pow_f64 ✅
pow_vmap_batching ✅
pow_vmap_batching_f64 ✅
pow_grad_issue_batch_diff_rules ✅ |
0.8.0 |
| jnp.power |
Pow |
jnp_power_vector ✅
jnp_power_vector_f64 ✅
pow_jnp_power ✅
pow_jnp_power_f64 ✅
power_vmap_batching ✅
power_vmap_batching_f64 ✅
power_grad_issue_batch_diff_rules ✅ |
0.8.0 |
| jnp.prod |
ReduceProd |
basic_prod ✅
basic_prod_f64 ✅
prod_with_axis ✅
prod_with_axis_f64 ✅
prod_with_keepdims ✅
prod_with_keepdims_f64 ✅
jnp_prod_basic ✅
jnp_prod_basic_f64 ✅
jnp_prod_axis ✅
jnp_prod_axis_f64 ✅
jnp_prod_keepdims ✅
jnp_prod_keepdims_f64 ✅
prod_grad_issue_batch_diff_rules ✅
prod_vmap_batching ✅
prod_vmap_batching_f64 ✅ |
0.8.0 |
| jnp.put |
ScatterND |
jnp_put_basic ✅ |
0.12.1 |
| jnp.put_along_axis |
ScatterND |
jnp_put_along_axis_basic ✅ |
0.12.1 |
| jnp.quantile |
ReduceMean TopK |
jnp_quantile_basic ✅ |
0.12.1 |
| jnp.ravel_multi_index |
Add Mod Mul |
jnp_ravel_multi_index_wrap ✅ |
0.12.1 |
| jnp.reshape |
Flatten Reshape |
reshape_1 ✅
reshape_1_f64 ✅
reshape_2 ✅
reshape_2_f64 ✅
reshape_flatten_static_uses_flatten ✅
reshape_flatten_static_uses_flatten_f64 ✅
reshape_3 ✅
reshape_3_f64 ✅
reshape_4_dynamic ✅
reshape_4_dynamic_f64 ✅
reshape_4 ✅
reshape_4_f64 ✅
reshape_to_scalar ✅
reshape_to_scalar_f64 ✅
reshape_from_scalar ✅
reshape_from_scalar_f64 ✅
reshape_cnn_dynamic ✅
reshape_cnn_dynamic_f64 ✅
reshape_cnn ✅
reshape_cnn_f64 ✅
reshape_valid_flatten_trailing ✅
reshape_valid_flatten_trailing_f64 ✅
reshape_with_target_shape_from_symbolic_dim_computation ✅
reshape_with_target_shape_from_symbolic_dim_computation_f64 ✅
reshape_basic ✅
reshape_basic_f64 ✅
reshape_infer ✅
reshape_infer_f64 ✅
reshape_symbolic_flatten_dynamic ✅
reshape_symbolic_flatten_dynamic_f64 ✅
reshape_symbolic_flatten ✅
reshape_symbolic_flatten_f64 ✅
reshape_vmap_batching_issue_144 ✅
reshape_vmap_batching_issue_144_f64 ✅
reshape_grad_issue_batch_diff_rules ✅ |
0.1.0 |
| jnp.rfft |
DFT |
jnp_rfft_float32 ✅ |
0.10.1 |
| jnp.right_shift |
BitShift |
jnp_right_shift_signed_arithmetic ✅
jnp_right_shift_signed_arithmetic_f64 ✅
jnp_right_shift_unsigned_logical ✅
jnp_right_shift_unsigned_logical_f64 ✅
jnp_right_shift_mixed_bool_rhs ✅
jnp_right_shift_mixed_bool_rhs_f64 ✅
jnp_right_shift_broadcast ✅
jnp_right_shift_broadcast_f64 ✅
right_shift_vmap_batching ✅
right_shift_vmap_batching_f64 ✅ |
0.12.2 |
| jnp.roll |
Concat Slice |
jnp_roll_basic ✅ |
0.12.1 |
| jnp.rot90 |
ReverseSequence Transpose |
jnp_rot90_basic ✅ |
0.12.1 |
| jnp.select |
Where |
select_simple ✅
select_simple_f64 ✅
select_broadcast ✅
select_broadcast_f64 ✅
select_gpt2_attention_mask_dynamic ✅
select_gpt2_attention_mask_dynamic_f64 ✅
select_gpt2_attention_mask ✅
select_gpt2_attention_mask_f64 ✅
select_basic ✅
select_basic_f64 ✅
select_vmap_batching ✅
select_vmap_batching_f64 ✅
select_grad_issue_batch_diff_rules ✅ |
0.7.1 |
| jnp.shape |
Shape |
shape_basic ✅
shape_basic_f64 ✅
shape_dynamic_dynamic ✅
shape_dynamic_dynamic_f64 ✅
shape_dynamic ✅
shape_dynamic_f64 ✅
shape_vmap_batching ✅
shape_vmap_batching_f64 ✅ |
0.4.0 |
| jnp.sign |
Sign |
jnp_sign_basic ✅
jnp_sign_basic_f64 ✅
jnp_sign_int ✅
jnp_sign_int_f64 ✅
sign_vmap_batching ✅
sign_vmap_batching_f64 ✅
sign_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.sin |
Sin |
jnp_sin_basic ✅
jnp_sin_basic_f64 ✅
jnp_sin_int_promote ✅
jnp_sin_int_promote_f64 ✅
sin_vmap_batching ✅
sin_vmap_batching_f64 ✅
sin_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.sinc |
Div Sin |
jnp_sinc_basic ✅ |
0.12.1 |
| jnp.sinh |
Sinh |
jnp_sinh_basic ✅
jnp_sinh_basic_f64 ✅
jnp_sinh_int_promote ✅
jnp_sinh_int_promote_f64 ✅
sinh_vmap_batching ✅
sinh_vmap_batching_f64 ✅
sinh_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.size |
Size |
jnp_size_all ✅
jnp_size_all_f64 ✅
jnp_size_axis ✅
jnp_size_axis_f64 ✅
jnp_size_axis_tuple ✅
jnp_size_axis_tuple_f64 ✅
jnp_size_dynamic_dynamic ✅
jnp_size_dynamic_dynamic_f64 ✅
jnp_size_dynamic ✅
jnp_size_dynamic_f64 ✅ |
0.12.1 |
| jnp.sort |
TopK |
sort_1d ✅
sort_1d_f64 ✅
sort_2d_axis0 ✅
sort_2d_axis0_f64 ✅
sort_basic ✅
sort_basic_f64 ✅
sort_vmap_batching ✅
sort_vmap_batching_f64 ✅ |
0.5.2 |
| jnp.split |
Split |
split_by_sections ✅
split_by_sections_f64 ✅
split_by_indices ✅
split_by_indices_f64 ✅
split_by_indices_symbolic_dynamic ✅
split_by_indices_symbolic_dynamic_f64 ✅
split_by_indices_symbolic ✅
split_by_indices_symbolic_f64 ✅
split_sections ✅
split_sections_f64 ✅
split_indices_numpy ✅
split_indices_numpy_f64 ✅
split_grad_issue_batch_diff_rules ✅ |
0.7.2 |
| jnp.sqrt |
ReduceL2 Sqrt |
jnp_sqrt_basic ✅
jnp_sqrt_basic_f64 ✅
jnp_sqrt_int_promote ✅
jnp_sqrt_int_promote_f64 ✅
jnp_sqrt_reduce_sum_square_axis1 ✅
jnp_sqrt_reduce_sum_square_axis1_f64 ✅
sqrt_vmap_batching ✅
sqrt_vmap_batching_f64 ✅
sqrt_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.squeeze |
Squeeze |
squeeze_single_dim ✅
squeeze_single_dim_f64 ✅
squeeze_multiple_dims ✅
squeeze_multiple_dims_f64 ✅
squeeze_vit_output ✅
squeeze_vit_output_f64 ✅
squeeze_dynamic_batch_dynamic ✅
squeeze_dynamic_batch_dynamic_f64 ✅
squeeze_dynamic_batch ✅
squeeze_dynamic_batch_f64 ✅
squeeze_all_dims ✅
squeeze_all_dims_f64 ✅
squeeze_negative_axis ✅
squeeze_negative_axis_f64 ✅
squeeze_negative_axis_tuple ✅
squeeze_negative_axis_tuple_f64 ✅
squeeze_dynamic_and_negative_axis_dynamic ✅
squeeze_dynamic_and_negative_axis_dynamic_f64 ✅
squeeze_dynamic_and_negative_axis ✅
squeeze_dynamic_and_negative_axis_f64 ✅
squeeze_vmap_batching ✅
squeeze_vmap_batching_f64 ✅
squeeze_grad_issue_batch_diff_rules ✅ |
0.1.0 |
| jnp.stack |
Concat Unsqueeze |
stack_axis_0 ✅
stack_axis_0_f64 ✅
stack_axis_1 ✅
stack_axis_1_f64 ✅
stack_negative_axis ✅
stack_negative_axis_f64 ✅
stack_scalars ✅
stack_scalars_f64 ✅
jnp_stack_axis0 ✅
jnp_stack_axis0_f64 ✅
jnp_stack_axis1 ✅
jnp_stack_axis1_f64 ✅
jnp_stack_negative_axis ✅
jnp_stack_negative_axis_f64 ✅
jnp_stack_scalars ✅
jnp_stack_scalars_f64 ✅
stack_grad_issue_batch_diff_rules ✅ |
0.8.0 |
| jnp.std |
ReduceMean Sqrt |
jnp_std_basic ✅ |
0.12.1 |
| jnp.sum |
ReduceSum |
jnp_sum_basic ✅
jnp_sum_basic_f64 ✅
jnp_sum_axis ✅
jnp_sum_axis_f64 ✅
jnp_sum_keepdims ✅
jnp_sum_keepdims_f64 ✅
jnp_sum_int8_promote ✅
sum_vmap_batching ✅
sum_vmap_batching_f64 ✅
sum_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.take |
Gather |
take_data_dependent_indices ✅
take_basic_axis1 ✅
take_basic_axis1_f64 ✅
take_vmap_batching ✅
take_vmap_batching_f64 ✅
take_grad_issue_batch_diff_rules ✅ |
0.7.0 |
| jnp.tan |
Tan |
jnp_tan_basic ✅
jnp_tan_basic_f64 ✅
jnp_tan_int_promote ✅
jnp_tan_int_promote_f64 ✅
tan_vmap_batching ✅
tan_vmap_batching_f64 ✅
tan_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.tanh |
Tanh |
jnp_tanh_basic ✅
jnp_tanh_basic_f64 ✅
jnp_tanh_int_promote ✅
jnp_tanh_int_promote_f64 ✅
tanh_vmap_batching ✅
tanh_vmap_batching_f64 ✅
tanh_grad_issue_batch_diff_rules ✅ |
0.12.2 |
| jnp.tensordot |
MatMul Transpose |
jnp_tensordot_basic ✅ |
0.12.1 |
| jnp.tile |
Tile |
tile_repeats ✅
tile_repeats_f64 ✅
tile_a ✅
tile_a_f64 ✅
tile_b ✅
tile_b_f64 ✅
tile_c ✅
tile_c_f64 ✅
tile_d ✅
tile_d_f64 ✅
tile_dynamic_input_static ✅
tile_dynamic_input_static_f64 ✅
tile_dynamic_input_dynamic ✅
tile_dynamic_input_dynamic_f64 ✅
tile_dynamic_input ✅
tile_dynamic_input_f64 ✅
tile_pad ✅
tile_pad_f64 ✅
tile_param_symbolic_dynamic ✅
tile_param_symbolic_dynamic_f64 ✅
tile_param_symbolic ✅
tile_param_symbolic_f64 ✅
tile_with_symbolic_repeats_static ✅
tile_with_symbolic_repeats_static_f64 ✅
tile_with_symbolic_repeats_dynamic ✅
tile_with_symbolic_repeats_dynamic_f64 ✅
tile_with_symbolic_repeats ✅
tile_with_symbolic_repeats_f64 ✅
jnp_tile_basic ✅
jnp_tile_basic_f64 ✅
jnp_tile_scalar_repeats ✅
jnp_tile_scalar_repeats_f64 ✅
jnp_tile_pad_rank ✅
jnp_tile_pad_rank_f64 ✅
jnp_tile_symbolic_dynamic ✅
jnp_tile_symbolic_dynamic_f64 ✅
jnp_tile_symbolic ✅
jnp_tile_symbolic_f64 ✅
tile_vmap_batching ✅
tile_vmap_batching_f64 ✅
tile_grad_issue_batch_diff_rules ✅ |
0.8.0 |
| jnp.trace |
Gather ReduceSum |
jnp_trace_basic ✅ |
0.12.1 |
| jnp.transpose |
Transpose |
transpose_basic ✅
transpose_basic_f64 ✅
transpose_reverse_default ✅
transpose_reverse_default_f64 ✅
transpose_high_dim ✅
transpose_high_dim_f64 ✅
transpose_3d ✅
transpose_3d_f64 ✅
transpose_4d ✅
transpose_4d_f64 ✅
transpose_no_axes ✅
transpose_no_axes_f64 ✅
transpose_reverse ✅
transpose_reverse_f64 ✅
transpose_square_matrix ✅
transpose_square_matrix_f64 ✅
transpose_vmap_batching ✅
transpose_vmap_batching_f64 ✅
transpose_grad_issue_batch_diff_rules ✅ |
0.1.0 |
| jnp.trapezoid |
Add ReduceSum Slice |
jnp_trapezoid_basic ✅ |
0.12.1 |
| jnp.tri |
LessOrEqual Range |
jnp_tri_basic ✅ |
0.12.1 |
| jnp.trilu |
Trilu |
jnp_triu_basic ✅
jnp_triu_basic_f64 ✅
jnp_tril_negative_k ✅
jnp_tril_negative_k_f64 ✅
jnp_triu_symbolic_batch_dynamic ✅
jnp_triu_symbolic_batch_dynamic_f64 ✅
jnp_triu_symbolic_batch ✅
jnp_triu_symbolic_batch_f64 ✅
jnp_trilu_vmap_batching ✅
jnp_trilu_vmap_batching_f64 ✅ |
0.12.1 |
| jnp.unique |
Unique |
jnp_unique_f32_size_fill ✅
jnp_unique_i32_size_fill ✅
jnp_unique_symbolic_size_fill_dynamic ✅
jnp_unique_symbolic_size_fill ✅ |
0.12.1 |
| jnp.unravel_index |
Concat Div Mod |
jnp_unravel_index_basic ✅
jnp_unravel_index_basic_f64 ✅ |
0.12.1 |
| jnp.unstack |
Split Squeeze |
unstack_axis_0 ✅
unstack_axis_0_f64 ✅
unstack_axis_1 ✅
unstack_axis_1_f64 ✅
unstack_negative_axis ✅
unstack_negative_axis_f64 ✅
unstack_vmap_batching ✅
unstack_vmap_batching_f64 ✅ |
0.7.1 |
| jnp.var |
Mul ReduceMean |
jnp_var_basic ✅ |
0.12.1 |
| jnp.vdot |
Mul ReduceSum |
jnp_vdot_basic ✅ |
0.12.1 |
| jnp.vecdot |
Mul ReduceSum |
jnp_vecdot_basic ✅ |
0.12.1 |
| jnp.vecmat |
MatMul |
jnp_vecmat_basic ✅ |
0.12.1 |
| jnp.vectorize |
Add |
jnp_vectorize_basic ✅ |
0.12.1 |
| jnp.where |
Where |
where_simple ✅
where_simple_f64 ✅
where_broadcast ✅
where_broadcast_f64 ✅
where_gpt_mask_scores_literal_else_dynamic ✅
where_gpt_mask_scores_literal_else_dynamic_f64 ✅
where_gpt_mask_scores_literal_else ✅
where_gpt_mask_scores_literal_else_f64 ✅
where_multidim_condition_scalar_branches_broadcast ✅
where_multidim_condition_scalar_branches_broadcast_f64 ✅
where_A ✅
where_A_f64 ✅
where_B ✅
where_B_f64 ✅
where_gpt_mask_scores_scalar_else_dynamic ✅
where_gpt_mask_scores_scalar_else_dynamic_f64 ✅
where_gpt_mask_scores_scalar_else ✅
where_gpt_mask_scores_scalar_else_f64 ✅
where_int_condition_cast ✅
where_int_condition_cast_f64 ✅
where_literal_else_pyfloat ✅
where_literal_else_pyfloat_f64 ✅
where_jax_int_literals_broadcast_f64_mode ✅
where_dtype_mismatch_f64_vs_i32_promote ✅
jnp_where_basic ✅
jnp_where_basic_f64 ✅
jnp_where_broadcast ✅
jnp_where_broadcast_f64 ✅
jnp_where_scalar_else ✅
jnp_where_scalar_else_f64 ✅
where_vmap_batching ✅
where_vmap_batching_f64 ✅
where_grad_issue_batch_diff_rules ✅ |
0.8.0 |
| jnp.zeros |
ConstantOfShape |
jnp_zeros_2x3 ✅
jnp_zeros_2x3_f64 ✅
jnp_zeros_int32 ✅
jnp_zeros_int32_f64 ✅ |
0.12.2 |
| lax.abs |
Abs |
abs ✅
abs_f64 ✅ |
0.5.0 |
| lax.acos |
Acos |
acos ✅ |
0.12.1 |
| lax.acosh |
Acosh |
acosh ✅ |
0.12.1 |
| lax.add |
Add |
add ✅
add_f64 ✅
add_const ✅
add_const_f64 ✅
add_complex64 ✅
add_complex64_f64 ✅ |
0.2.0 |
| lax.add_any |
Add Sum |
add_any_via_jvp_on_mul ✅
add_any_via_jvp_on_mul_f64 ✅ |
0.8.0 |
| lax.and |
And BitwiseAnd |
and_bool ✅
and_bool_f64 ✅
and_int ✅
and_int_f64 ✅ |
0.6.5 |
| lax.approx_top_k |
TopK |
approx_max_k_matrix ✅
approx_max_k_matrix_f64 ✅
approx_min_k_axis0 ✅
approx_min_k_axis0_f64 ✅ |
0.12.1 |
| lax.argmax |
ArgMax |
argmax_float_axis0 ✅
argmax_float_axis0_f64 ✅
argmax_float_axis1 ✅
argmax_float_axis1_f64 ✅
argmax_boolean_input_axis0_specific_values ✅
argmax_boolean_input_axis0_specific_values_f64 ✅
argmax_boolean_input_axis1_specific_values ✅
argmax_boolean_input_axis1_specific_values_f64 ✅
argmax_boolean_random_input_axis0 ✅
argmax_boolean_random_input_axis0_f64 ✅ |
0.2.0 |
| lax.argmin |
ArgMin |
argmin_test1 ✅
argmin_test1_f64 ✅
argmin_test2 ✅
argmin_test2_f64 ✅ |
0.2.0 |
| lax.asin |
Asin |
asin ✅ |
0.12.1 |
| lax.asinh |
Asinh |
asinh ✅ |
0.12.1 |
| lax.atan |
Atan |
atan ✅ |
0.12.1 |
| lax.atan2 |
Add Atan Div Equal Greater Less Or Sub Where |
atan2_quadrants_and_zero ✅
atan2_quadrants_and_zero_f64 ✅
atan2_broadcast ✅ |
0.12.1 |
| lax.atanh |
Atanh |
atanh ✅ |
0.12.1 |
| lax.bessel_i0e |
Abs Exp Where |
bessel_i0e_basic ✅
bessel_i0e_basic_f64 ✅ |
0.12.1 |
| lax.bessel_i1e |
Abs Sign Where |
bessel_i1e_basic ✅
bessel_i1e_basic_f64 ✅ |
0.12.1 |
| lax.betainc |
Exp Pow Where |
betainc_basic ✅
betainc_edge_x ✅ |
0.12.1 |
| lax.bitcast_convert_type |
BitCast |
bitcast_scalar_f32_to_i32 ✅
bitcast_scalar_f32_to_i32_f64 ✅
bitcast_tensor_i32_to_f32 ✅
bitcast_tensor_i32_to_f32_f64 ✅ |
0.7.2 |
| lax.bitwise_not |
BitwiseNot Not |
bitwise_not_bool ✅
bitwise_not_bool_f64 ✅
bitwise_not_i32 ✅
bitwise_not_i32_f64 ✅ |
0.7.5 |
| lax.broadcast_in_dim |
Expand Identity Reshape |
broadcast_in_dim ✅
broadcast_in_dim_f64 ✅
broadcast_in_dim_2d_to_3d ✅
broadcast_in_dim_2d_to_3d_f64 ✅
broadcast_in_dim_scalar ✅
broadcast_in_dim_scalar_f64 ✅
broadcast_in_dim_batch_dynamic ✅
broadcast_in_dim_batch_dynamic_f64 ✅
broadcast_in_dim_batch ✅
broadcast_in_dim_batch_f64 ✅
broadcast_in_dim_dynamic_B_dynamic ✅
broadcast_in_dim_dynamic_B_dynamic_f64 ✅
broadcast_in_dim_dynamic_B ✅
broadcast_in_dim_dynamic_B_f64 ✅ |
0.2.0 |
| lax.cbrt |
Abs Mul Pow Sign |
cbrt ✅
cbrt_f64 ✅ |
0.12.1 |
| lax.ceil |
Ceil |
ceil ✅
ceil_f64 ✅ |
0.12.1 |
| lax.cholesky |
Concat Div Gather Mul ScatterND Sqrt Squeeze Sub Unsqueeze |
cholesky_spd_3x3 ✅
cholesky_spd_3x3_f64 ✅
cholesky_diagonal ✅
cholesky_diagonal_f64 ✅
cholesky_batched_2x3x3 ✅
cholesky_batched_2x3x3_f64 ✅ |
0.12.1 |
| lax.cholesky_update |
Add Div Mul ScatterND Sqrt Sub |
cholesky_update_upper_3x3 ✅
cholesky_update_upper_3x3_f64 ✅
cholesky_update_identity ✅
cholesky_update_identity_f64 ✅ |
0.12.1 |
| lax.clamp |
Max Min |
clamp_i32_scalar_bounds ✅
clamp_i32_scalar_bounds_f64 ✅
clamp_scalar_float_bounds_match_x ✅
clamp_scalar_float_bounds_match_x_f64 ✅
clamp_vector_bounds_match ✅
clamp_pyint_bounds_promote_to_x_dtype ✅
clamp_pyint_bounds_promote_to_x_dtype_f64 ✅ |
0.7.5 |
| lax.clz |
BitShift BitwiseAnd |
clz_i32 ✅
clz_i32_f64 ✅
clz_u8 ✅
clz_u8_f64 ✅ |
0.12.1 |
| lax.complex |
Concat Unsqueeze |
complex_scalar ✅
complex_scalar_f64 ✅
complex_array ✅
complex_array_f64 ✅
complex_double_precision ✅ |
0.12.2 |
| lax.concatenate |
Cast Concat |
concatenate ✅
concatenate_f64 ✅
concatenate_axis1_dynamic ✅
concatenate_axis1_dynamic_f64 ✅
concatenate_axis1 ✅
concatenate_axis1_f64 ✅
concatenate_axis0 ✅
concatenate_axis0_f64 ✅
concatenate_3d ✅
concatenate_3d_f64 ✅
concatenate_internal_int32_then_cast_to_f32_zeroarg ✅ |
0.2.0 |
| lax.cond |
If |
cond_scalar ✅
cond_scalar_f64 ✅
cond_multiple_operands_in_tuple ✅
cond_multiple_operands_in_tuple_f64 ✅
cond_my_new_complex_scenario ✅
cond_my_new_complex_scenario_f64 ✅
cond_nested_conditional ✅
cond_nested_conditional_f64 ✅
cond_variables ✅
cond_variables_f64 ✅
cond_internal_constant_f64 ✅
cond_passthrough_identity ✅
cond_passthrough_identity_f64 ✅
cond_with_scatter ✅
cond_with_scatter_f64 ✅ |
0.5.1 |
| lax.conj |
Identity |
conj_real ✅
conj_real_f64 ✅
conj_complex64 ✅
conj_complex64_f64 ✅ |
0.10.1 |
| lax.conv |
Conv |
conv ✅
conv2 ✅
conv_nchw ✅
conv_nhwc ✅
conv_general_dilated_nhwc_output ✅
conv_complex64 ✅
conv_complex64_nhwc ✅
conv_complex128_grouped ✅ |
0.2.0 |
| lax.convert_element_type |
Cast |
convert_element_type ✅
convert_element_type_f64 ✅ |
0.2.0 |
| lax.copy (Handles the JAX primitive lax.copy_p. The public API jax.lax.copy is deprecated, but the primitive still appears in transformed jaxprs.) |
Identity |
copy_float32_array ✅
copy_int64_scalar ✅ |
|
| lax.cos |
Cos |
cos ✅
cos_f64 ✅ |
0.4.4 |
| lax.cosh |
Cosh |
cosh ✅
cosh_f64 ✅ |
0.4.4 |
| lax.cumlogsumexp |
CumSum Exp Log |
cumlogsumexp_axis1 ✅
cumlogsumexp_axis1_f64 ✅
cumlogsumexp_reverse_last_axis ✅
cumlogsumexp_reverse_last_axis_f64 ✅ |
0.12.1 |
| lax.cummax |
MaxPool Reshape Transpose |
cummax_axis1 ✅
cummax_axis1_f64 ✅
cummax_reverse_last_axis ✅
cummax_reverse_last_axis_f64 ✅ |
0.12.1 |
| lax.cummin |
MaxPool Neg Reshape Transpose |
cummin_axis1 ✅
cummin_axis1_f64 ✅
cummin_reverse_last_axis ✅
cummin_reverse_last_axis_f64 ✅ |
0.12.1 |
| lax.cumprod |
CumProd |
cumprod_i32_axis2 ✅
cumprod_i32_axis2_f64 ✅
cumprod_f32_axism1_reverse ✅
cumprod_f32_axism1_reverse_f64 ✅ |
0.12.1 |
| lax.cumsum |
CumSum |
cumsum_i32_axis2 ✅
cumsum_i32_axis2_f64 ✅
cumsum_f32_axism1_reverse ✅
cumsum_f32_axism1_reverse_f64 ✅ |
0.7.4 |
| lax.custom_linear_solve |
➖ |
custom_linear_solve_via_matvec ✅
custom_linear_solve_via_matvec_f64 ✅ |
0.12.1 |
| lax.device_put |
Identity |
device_put_array ✅
device_put_array_f64 ✅
device_put_scalar ✅
device_put_scalar_f64 ✅ |
0.4.0 |
| lax.digamma |
Cos Log Sin Where |
digamma_positive ✅
digamma_mixed ✅ |
0.12.1 |
| lax.div |
Div LpNormalization Mean |
div ✅
div_f64 ✅
div_const ✅
div_const_f64 ✅
div_add_half_fuses_to_mean ✅
div_add_half_fuses_to_mean_symbolic_dynamic ✅
div_add_half_fuses_to_mean_symbolic ✅
div_add_third_no_mean ✅
div_add_half_f64_no_mean ✅
div_lpnorm_l2_axis1 ✅
div_lpnorm_l1_axis1 ✅
div_lpnorm_l2_axis2 ✅
div_sqrt_of_norm_no_lpnorm_fusion ✅
div_complex64 ✅
div_complex64_f64 ✅ |
0.2.0 |
| lax.dot_general |
Einsum MatMul/Gemm |
dot_contract_nm ✅
dot_contract_nm_f64 ✅
dot_contract_min ✅
dot_contract_min_f64 ✅
dot_general ✅
dot_general_f64 ✅
dot_general_lhs1_rhs1 ✅
dot_general_lhs1_rhs1_f64 ✅
dot_double_contract ✅
dot_double_contract_f64 ✅
dot_batched_double_contract ✅
dot_batched_double_contract_f64 ✅
dot_highrank_batch ✅
dot_highrank_batch_f64 ✅
dot_contract_inner_lhs_with_middle_rhs ✅
dot_contract_inner_lhs_with_middle_rhs_f64 ✅
dot_outer_product ✅
dot_outer_product_f64 ✅
dot_full_contract_scalar ✅
dot_full_contract_scalar_f64 ✅
dot_general_complex_matmul ✅
dot_general_complex_matmul_f64 ✅ |
0.2.0 |
| lax.dynamic_slice |
Slice |
dynamic_slice_test1 ✅
dynamic_slice_test1_f64 ✅
dynamic_slice_2d ✅
dynamic_slice_2d_f64 ✅
dynamic_slice_3d ✅
dynamic_slice_3d_f64 ✅
dynamic_slice_vit_like_dynamic ✅
dynamic_slice_vit_like_dynamic_f64 ✅
dynamic_slice_vit_like ✅
dynamic_slice_vit_like_f64 ✅ |
0.1.0 |
| lax.dynamic_update_slice |
ScatterND TensorScatter |
dus_1d_scalar_update ✅
dus_1d_scalar_update_f64 ✅
dus_1d_block_update ✅
dus_1d_block_update_f64 ✅
dus_2d_block_update ✅
dus_2d_block_update_f64 ✅
dus_3d_block_update ✅
dus_3d_block_update_f64 ✅
dus_4d_block_update ✅
dus_4d_block_update_f64 ✅
dus_tensorscatter_axis1_opset24 ✅ |
0.8.1 |
| lax.eig |
Concat Gather Identity Reshape Unsqueeze |
eig_1x1_values_only ✅
eig_1x1_values_only_f64 ✅
eig_1x1_left_only ✅
eig_1x1_left_only_f64 ✅
eig_1x1_right_only ✅
eig_1x1_right_only_f64 ✅
eig_1x1_full ✅
eig_1x1_full_f64 ✅
eig_2x2_values_only_real ✅
eig_2x2_values_only_real_f64 ✅
eig_2x2_values_only_complex128 ✅ |
0.12.1 |
| lax.eigh |
Add Cast Concat Div Gather Identity LessOrEqual Mul Reshape Slice Sqrt Sub Where |
eigh_1x1 ✅
eigh_1x1_f64 ✅
eigh_2x2_lower_true ✅
eigh_2x2_lower_true_f64 ✅
eigh_2x2_lower_false ✅
eigh_2x2_lower_false_f64 ✅
eigh_2x2_subset_top1 ✅
eigh_2x2_subset_top1_f64 ✅ |
0.12.1 |
| lax.eq |
Equal |
eq ✅
eq_f64 ✅ |
0.2.0 |
| lax.erf |
Erf |
erf ✅ |
0.4.4 |
| lax.erf_inv |
Erf Exp Log Sqrt |
erf_inv_midrange ✅
erf_inv_matrix ✅ |
0.12.1 |
| lax.erfc |
Erf Sub |
erfc ✅ |
0.12.1 |
| lax.exp |
Exp |
exp ✅
exp_f64 ✅ |
0.2.0 |
| lax.exp2 |
Pow |
exp2 ✅
exp2_f64 ✅ |
0.12.1 |
| lax.expm1 |
Exp Sub |
expm1 ✅
expm1_f64 ✅ |
0.12.1 |
| lax.fft |
DFT |
fft_complex64_1d ✅
fft_complex64_len8 ✅
fft_complex64_batch ✅
fft_complex128_1d ✅
fft_complex128_len8 ✅
ifft_complex64_1d ✅
ifft_complex64_len8 ✅
ifft_complex128_batch ✅
rfft_real32_1d ✅
rfft_real64_len8 ✅
irfft_complex64_1d ✅
irfft_complex128_len8 ✅ |
0.10.1 |
| lax.floor |
Floor |
floor ✅
floor_f64 ✅ |
0.12.1 |
| lax.fori_loop |
Loop |
fori_loop_counter ✅
fori_loop_counter_f64 ✅
fori_loop_zero ✅
fori_loop_zero_f64 ✅
fori_loop_vector ✅
fori_loop_vector_f64 ✅
fori_loop_example ✅
fori_loop_example_f64 ✅
fori_loop_test ✅
fori_loop_test_f64 ✅ |
0.5.1 |
| lax.gather |
GatherND |
gather_trig_where_pipeline_f64_indices_i64 ✅
gather_trig_where_pipeline_f64_indices_i32 ✅
gather_f64_data_i64_indices_output_is_f64 ✅
gather_f64_data_i32_indices_cast_and_output_is_f64 ✅
gather_static ✅
gather_static_f64 ✅
gather_dynamic_batch_simple_index_dynamic ✅
gather_dynamic_batch_simple_index_dynamic_f64 ✅
gather_dynamic_batch_simple_index ✅
gather_dynamic_batch_simple_index_f64 ✅ |
0.2.0 |
| lax.greater_equal |
GreaterOrEqual |
greater_equal ✅
greater_equal_f64 ✅ |
0.7.5 |
| lax.gt |
Greater |
gt ✅
gt_f64 ✅ |
0.2.0 |
| lax.hessenberg |
Div MatMul ScatterND Sqrt Where |
hessenberg_square_4x4 ✅
hessenberg_square_4x4_f64 ✅
hessenberg_diagonal ✅
hessenberg_diagonal_f64 ✅ |
0.12.1 |
| lax.householder_product |
MatMul Mul ScatterND Sub Transpose |
householder_product_basic ✅
householder_product_basic_f64 ✅
householder_product_k0 ✅
householder_product_k0_f64 ✅ |
0.12.1 |
| lax.igamma |
Exp Pow Where |
igamma_basic ✅
igamma_zero_x ✅ |
0.12.1 |
| lax.igamma_grad_a |
Add Div Sub Where |
igamma_grad_a_basic ✅ |
0.12.1 |
| lax.igammac |
Sub Where |
igammac_basic ✅
igammac_zero_x ✅ |
0.12.1 |
| lax.imag |
Mul |
imag_complex64_input ✅
imag_complex64_input_f64 ✅ |
0.10.2 |
| lax.integer_pow |
Pow Reciprocal |
integer_pow ✅
integer_pow_f64 ✅
integer_pow_reciprocal ✅
integer_pow_reciprocal_f64 ✅ |
0.2.0 |
| lax.iota |
Range |
iota_int32 ✅
iota_int32_f64 ✅
iota_float32 ✅
iota_float32_f64 ✅
broadcasted_iota ✅
broadcasted_iota_f64 ✅ |
0.5.0 |
| lax.is_finite |
IsInf IsNaN Not Or |
is_finite_vec ✅
is_finite_vec_f64 ✅ |
0.12.1 |
| lax.less_equal |
LessOrEqual |
less_equal ✅
less_equal_f64 ✅ |
0.7.5 |
| lax.lgamma |
Log Sin Where |
lgamma_positive ✅
lgamma_positive_f64 ✅
lgamma_negative_noninteger ✅
lgamma_negative_noninteger_f64 ✅ |
0.12.1 |
| lax.log |
Log ReduceLogSumExp ReduceLogSum |
log ✅
log_f64 ✅
log_of_reduce_sum_axis1 ✅
log_of_reduce_sum_axis1_f64 ✅
log_of_reduce_sum_exp_axis1 ✅
log_of_reduce_sum_exp_axis1_f64 ✅ |
0.2.0 |
| lax.log1p |
Add Log |
log1p ✅
log1p_f64 ✅ |
0.11.0 |
| lax.logistic |
Sigmoid |
lax_logistic_basic ✅
lax_logistic_basic_f64 ✅ |
0.7.2 |
| lax.lt |
Less |
lt ✅
lt_f64 ✅ |
0.2.0 |
| lax.lu |
Abs ArgMax Div Mul ScatterElements ScatterND Sub |
lu_square_3x3 ✅
lu_square_3x3_f64 ✅
lu_rectangular_4x3 ✅
lu_rectangular_4x3_f64 ✅ |
0.12.1 |
| lax.lu_pivots_to_permutation |
Concat Gather ScatterElements Squeeze Unsqueeze |
lu_pivots_to_permutation_basic ✅
lu_pivots_to_permutation_basic_f64 ✅
lu_pivots_to_permutation_identity ✅
lu_pivots_to_permutation_identity_f64 ✅
lu_pivots_to_permutation_batched ✅
lu_pivots_to_permutation_batched_f64 ✅ |
0.12.1 |
| lax.max |
Max |
max ✅
max_f64 ✅ |
0.2.0 |
| lax.min |
Min |
min_test1 ✅
min_test1_f64 ✅ |
0.1.0 |
| lax.mul |
Mul |
mul_test1 ✅
mul_test1_f64 ✅
mul_test2 ✅
mul_test2_f64 ✅
mul_pyfloat_promotes_to_array_dtype_f64 ✅
mul_scalar_broadcast_promote_to_f64 ✅
mul_complex128 ✅
mul_complex64 ✅ |
0.1.0 |
| lax.ne |
Equal Not |
ne ✅
ne_f64 ✅ |
0.2.0 |
| lax.neg |
Neg |
neg ✅
neg_f64 ✅ |
0.2.0 |
| lax.nextafter |
IsNaN Log Pow Sign Where |
nextafter_vector ✅
nextafter_special_values ✅ |
0.12.1 |
| lax.optimization_barrier |
Identity |
optimization_barrier_single ✅
optimization_barrier_single_f64 ✅
optimization_barrier_tuple ✅
optimization_barrier_tuple_f64 ✅ |
0.12.1 |
| lax.or |
BitwiseOr Or |
or_bool_vec ✅
or_bool_vec_f64 ✅
or_int_vec ✅
or_int_vec_f64 ✅ |
0.7.2 |
| lax.pad |
Pad |
pad_const_1d ✅
pad_const_1d_f64 ✅
pad_const_2d ✅
pad_const_2d_f64 ✅
pad_const_2d_cval ✅
pad_const_2d_cval_f64 ✅
pad_inside_scan_smoke_f64 ✅
pad_inside_nested_scan_smoke_f64 ✅ |
0.8.0 |
| lax.pjit |
➖ |
pjit_inline_mul ✅
pjit_inline_mul_f64 ✅
pjit_inline_tuple ✅
pjit_inline_tuple_f64 ✅ |
0.1.0 |
| lax.polygamma |
Exp Floor Pow Round Where |
polygamma_orders ✅
polygamma_zero_order ✅ |
0.12.1 |
| lax.population_count |
BitShift BitwiseAnd |
population_count_i32 ✅
population_count_i32_f64 ✅
population_count_u8 ✅
population_count_u8_f64 ✅ |
0.12.1 |
| lax.pow |
Pow |
pow_basic ✅
pow_basic_f64 ✅
pow_lax ✅
pow_lax_f64 ✅ |
0.8.2 |
| lax.qr |
MatMul Mul ScatterND Sqrt Sub |
qr_reduced_tall ✅
qr_reduced_tall_f64 ✅
qr_reduced_wide ✅
qr_reduced_wide_f64 ✅
qr_full_tall ✅
qr_full_tall_f64 ✅
qr_full_wide ✅
qr_full_wide_f64 ✅ |
0.12.1 |
| lax.real |
Identity |
real_complex64_input ✅
real_complex64_input_f64 ✅ |
0.10.2 |
| lax.reduce |
ReduceMax ReduceMin |
reduce_max_lambda ✅
reduce_max_lambda_f64 ✅
reduce_min_lambda ✅
reduce_min_lambda_f64 ✅ |
0.12.1 |
| lax.reduce_and |
ReduceMin |
reduce_and_all_true ✅
reduce_and_all_true_f64 ✅
reduce_and_one_false ✅
reduce_and_one_false_f64 ✅
reduce_and_keepdims ✅
reduce_and_keepdims_f64 ✅ |
0.6.1 |
| lax.reduce_max |
ReduceMax |
reduce_max ✅
reduce_max_f64 ✅
reduce_max_allaxes ✅
reduce_max_allaxes_f64 ✅
reduce_max_axes_input ✅
reduce_max_axes_input_f64 ✅
reduce_max_keepdims ✅
reduce_max_keepdims_f64 ✅ |
0.2.0 |
| lax.reduce_min |
ReduceMin |
reduce_min ✅
reduce_min_f64 ✅
reduce_min_allaxes ✅
reduce_min_allaxes_f64 ✅
reduce_min_keepdims ✅
reduce_min_keepdims_f64 ✅ |
0.2.0 |
| lax.reduce_or |
ReduceMax |
reduce_or_all_false ✅
reduce_or_all_false_f64 ✅
reduce_or_one_true ✅
reduce_or_one_true_f64 ✅
reduce_or_keepdims ✅
reduce_or_keepdims_f64 ✅ |
0.6.1 |
| lax.reduce_precision |
Floor Log Pow Round Where |
reduce_precision_mantissa10 ✅
reduce_precision_underflow_overflow ✅ |
0.12.1 |
| lax.reduce_prod |
ReduceProd |
reduce_prod ✅
reduce_prod_f64 ✅
reduce_prod_allaxes ✅
reduce_prod_allaxes_f64 ✅
reduce_prod_dtype ✅
reduce_prod_dtype_f64 ✅
reduce_prod_keepdims ✅
reduce_prod_keepdims_f64 ✅ |
0.6.1 |
| lax.reduce_sum |
ReduceL1 ReduceSumSquare ReduceSum |
reduce_sum ✅
reduce_sum_f64 ✅
reduce_sum_allaxes ✅
reduce_sum_allaxes_f64 ✅
reduce_sum_dtype ✅
reduce_sum_dtype_f64 ✅
reduce_sum_keepdims ✅
reduce_sum_keepdims_f64 ✅
reduce_sum_of_abs_axis1 ✅
reduce_sum_of_abs_axis1_f64 ✅
reduce_sum_of_square_axis1 ✅
reduce_sum_of_square_axis1_f64 ✅ |
0.2.0 |
| lax.reduce_window |
Conv MaxPool Neg |
reduce_window_add_lambda ✅
reduce_window_max_lambda ✅
reduce_window_min_lambda_stride ✅ |
0.12.1 |
| lax.reduce_window_max |
MaxPool |
reduce_window_max_primitive ✅ |
0.12.2 |
| lax.reduce_window_sum |
Conv LpPool |
reduce_window_sum_valid ✅
reduce_window_sum_same_padding ✅
reduce_window_sum_stride_dilate ✅
reduce_window_sum_int32 ✅
reduce_window_sum_base_dilation ✅
reduce_window_sum_abs_lppool_opset23 ✅
reduce_window_sum_abs_lppool_dilated_opset23 ✅ |
0.10.1 |
| lax.reduce_xor |
Mod ReduceSum |
reduce_xor_all_false ✅
reduce_xor_all_false_f64 ✅
reduce_xor_one_true ✅
reduce_xor_one_true_f64 ✅
reduce_xor_two_true ✅
reduce_xor_two_true_f64 ✅
reduce_xor_keepdims ✅
reduce_xor_keepdims_f64 ✅ |
0.6.1 |
| lax.rem |
Div Mod |
rem_int ✅
rem_int_f64 ✅
rem_float ✅
rem_float_f64 ✅
rem_int_neg ✅
rem_int_neg_f64 ✅
rem_float_neg ✅
rem_float_neg_f64 ✅ |
0.6.5 |
| lax.remat2 |
➖ |
remat2_scalar_sin_chain ✅
remat2_scalar_sin_chain_f64 ✅
remat2_tuple_passthrough ✅
remat2_tuple_passthrough_f64 ✅ |
0.6.5 |
| lax.reshape |
Reshape |
reshape_after_transpose_folds_const_shape ✅
reshape_after_transpose_folds_const_shape_f64 ✅
reshape_flatten_trailing_folds_const_shape ✅
reshape_flatten_trailing_folds_const_shape_f64 ✅
reshape ✅
reshape_f64 ✅
reshape_valid_squeeze_middle_dim_from_problematic_source ✅
reshape_valid_squeeze_middle_dim_from_problematic_source_f64 ✅
reshape_valid_flatten_trailing ✅
reshape_valid_flatten_trailing_f64 ✅
reshape_with_target_shape_from_symbolic_dim_computation ✅
reshape_with_target_shape_from_symbolic_dim_computation_f64 ✅
reshape_with_inferred_dimension_from_input_dynamic_dynamic ✅
reshape_with_inferred_dimension_from_input_dynamic_dynamic_f64 ✅
reshape_with_inferred_dimension_from_input_dynamic ✅
reshape_with_inferred_dimension_from_input_dynamic_f64 ✅
reshape_with_inferred_dimension_from_input ✅
reshape_with_inferred_dimension_from_input_f64 ✅
reshape_merge_symbolic_with_static_and_check_name_dynamic ✅
reshape_merge_symbolic_with_static_and_check_name ✅ |
0.2.0 |
| lax.rev |
Gather Range |
rev_vector ✅
rev_vector_f64 ✅
rev_matrix_axes01 ✅
rev_matrix_axes01_f64 ✅ |
0.7.5 |
| lax.rng_bit_generator |
Cast Identity RandomUniform |
rng_bit_generator_u32 ✅ |
0.12.1 |
| lax.rng_uniform |
Add Mul RandomUniform Sub |
rng_uniform_f32 ✅ |
0.12.1 |
| lax.round |
Round |
round ✅
round_f64 ✅ |
0.12.1 |
| lax.rsqrt |
Div Sqrt |
rsqrt ✅
rsqrt_f64 ✅ |
0.10.2 |
| lax.scan |
Loop |
scan_identity_slice_helper ✅
scan_identity_slice_helper_f64 ✅
scan_cumsum ✅
scan_cumsum_f64 ✅
scan_carry_only ✅
scan_carry_only_f64 ✅
scan_multiple_sequences ✅
scan_multiple_sequences_f64 ✅
scan_multiple_carry ✅
scan_multiple_carry_f64 ✅
scan_matrix_carry_multidim_xs ✅
scan_matrix_carry_multidim_xs_f64 ✅
scan_no_xs ✅
scan_no_xs_f64 ✅
scan_fn ✅
scan_fn_f64 ✅
scan_jit_no_xs ✅
scan_jit_no_xs_f64 ✅
scan_captured_scalar ✅
scan_captured_scalar_f64 ✅
scan_rank0_sequence_vectorized ✅
scan_rank0_sequence_vectorized_f64 ✅
scan_two_diff_lengths ✅
scan_two_diff_lengths_f64 ✅
scan_two_diff_lengths_broadcast ✅
scan_two_diff_lengths_broadcast_f64 ✅
scan_two_diff_lengths_with_broadcast ✅
scan_nested_len_mismatch ✅
scan_nested_len_mismatch_f64 ✅
scan_captured_scalar_with_xs ✅
scan_captured_vector_with_xs_f64 ✅ |
0.5.1 |
| lax.scatter |
NonZero ScatterElements ScatterND Scatter |
scatter_set_axis0 ✅
scatter_set_axis0_f64 ✅
scatter_set_middle ✅
scatter_set_middle_f64 ✅
scatter_set_single ✅
scatter_set_single_f64 ✅
scatter_set_vector ✅
scatter_set_vector_f64 ✅
scatter_elements_set_vector_promise_in_bounds ✅
scatter_elements_set_vector_promise_in_bounds_f64 ✅
scatter_correct_axis_determination ✅
scatter_correct_axis_determination_f64 ✅
scatter_updates_slice_needed_axis0 ✅
scatter_updates_slice_needed_axis0_f64 ✅
scatter_from_user_warning_shapes_valid_jax ✅
scatter_from_user_warning_shapes_valid_jax_f64 ✅
scatter_user_error_scenario_precise ✅
scatter_user_error_scenario_precise_f64 ✅
scatter_window_update_f64 ✅
scatter_window_update_depth3_shapes_ok ✅
scatter_static_slice_set_f64 ✅
scatter_depth2_fp64_type_mismatch ✅
scatter_clip_2d_window_at_edge ✅
scatter_simple_2d_window_out_of_bounds ✅
scatter_depth2_mixed_dtypes_fp_mismatch_f64 ✅
scatter_depth2_mixed_dtypes_fp_mismatch ✅ |
0.4.4 |
| lax.scatter_add |
ScatterND(reduction='add') |
scatter_add_vector ✅
scatter_add_vector_f64 ✅
scatter_add_scalar ✅
scatter_add_scalar_f64 ✅
scatter_add_simple_1d ✅
scatter_add_simple_1d_f64 ✅
scatter_add_batch_updates_1d_operand ✅
scatter_add_batch_updates_1d_operand_f64 ✅
scatter_add_window_2d_operand_1d_indices ✅
scatter_add_window_2d_operand_1d_indices_f64 ✅
scatter_add_mismatched_window_dims_from_user_report ✅
scatter_add_mismatched_window_dims_from_user_report2 ✅
scatter_add_mismatched_window_dims_from_user_report3 ✅
scatter_add_fluids_pattern_updates_5_4_1_1 ✅
scatter_add_in_cond_float64 ✅
scatter_add_fp64_dtype_mismatch ✅
scatter_add_depth2_depth2_helper_regression ✅
scatter_depth2_fp64_type_mismatch ✅ |
0.5.3 |
| lax.scatter_max |
ScatterND(reduction='max') |
scatter_max_simple_1d ✅
scatter_max_simple_1d_f64 ✅
scatter_max_batch_updates_1d_operand ✅
scatter_max_batch_updates_1d_operand_f64 ✅
scatter_max_window_2d_operand_1d_indices ✅
scatter_max_window_2d_operand_1d_indices_f64 ✅
scatter_max_fp64_dtype_path_check ✅
scatter_max_depth2_helper_regression_fp64 ✅ |
0.7.5 |
| lax.scatter_min |
ScatterND(reduction='min') |
scatter_min_simple_1d ✅
scatter_min_simple_1d_f64 ✅
scatter_min_batch_updates_1d_operand ✅
scatter_min_batch_updates_1d_operand_f64 ✅
scatter_min_window_2d_operand_1d_indices ✅
scatter_min_window_2d_operand_1d_indices_f64 ✅
scatter_min_fp64_dtype_path_check ✅
scatter_min_depth2_helper_regression_fp64 ✅ |
0.7.5 |
| lax.scatter_mul |
ScatterND(reduction='mul') |
scatter_mul_simple_1d ✅
scatter_mul_simple_1d_f64 ✅
scatter_mul_batch_updates_1d_operand ✅
scatter_mul_batch_updates_1d_operand_f64 ✅
scatter_mul_window_2d_operand_1d_indices ✅
scatter_mul_window_2d_operand_1d_indices_f64 ✅
scatter_mul_mismatched_window_dims_from_user_report ✅
scatter_mul_mismatched_window_dims_from_user_report2 ✅
scatter_mul_mismatched_window_dims_from_user_report3 ✅
scatter_mul_fluids_pattern_updates_5_4_1_1 ✅
scatter_mul_in_cond_float64 ✅ |
0.6.4 |
| lax.scatter_sub |
Neg ScatterND(reduction='add') |
scatter_sub_simple_1d ✅
scatter_sub_simple_1d_f64 ✅
scatter_sub_window_2d_operand_1d_indices ✅
scatter_sub_window_2d_operand_1d_indices_f64 ✅ |
0.12.1 |
| lax.schur |
Identity |
schur_1x1_default ✅
schur_1x1_default_f64 ✅
schur_1x1_no_vectors ✅
schur_1x1_no_vectors_f64 ✅ |
0.12.1 |
| lax.select |
Where |
select_simple ✅
select_simple_f64 ✅
select_basic ✅
select_basic_f64 ✅
select_mask_scores_tensor_else_dynamic ✅
select_mask_scores_tensor_else_dynamic_f64 ✅
select_mask_scores_tensor_else ✅
select_mask_scores_tensor_else_f64 ✅ |
0.7.1 |
| lax.select_n |
Equal Where |
select_n_bool_predicate_two_cases_float ✅
select_n_bool_predicate_two_cases_float_f64 ✅
select_n_bool_predicate_two_cases_int ✅
select_n_bool_predicate_two_cases_int_f64 ✅
select_n_bool_predicate_scalar_broadcast ✅
select_n_bool_predicate_scalar_broadcast_f64 ✅
select_n_int_indices_three_cases ✅
select_n_int_indices_three_cases_f64 ✅
select_n_int_indices_four_cases ✅
select_n_int_indices_four_cases_f64 ✅ |
0.2.0 |
| lax.shard_map |
➖ |
shard_map_inline_add ✅
shard_map_inline_add_f64 ✅ |
0.10.2 |
| lax.shift_left |
BitShift |
shift_left_vec ✅
shift_left_vec_f64 ✅
shift_left_scalar ✅
shift_left_scalar_f64 ✅ |
0.12.1 |
| lax.shift_right_arithmetic |
BitShift |
shift_right_arithmetic_vec ✅
shift_right_arithmetic_vec_f64 ✅
shift_right_arithmetic_scalar ✅
shift_right_arithmetic_scalar_f64 ✅ |
0.12.1 |
| lax.shift_right_logical |
BitShift |
shift_right_logical_vec ✅
shift_right_logical_vec_f64 ✅
shift_right_logical_scalar ✅
shift_right_logical_scalar_f64 ✅ |
0.7.2 |
| lax.sign |
Sign |
sign ✅
sign_f64 ✅ |
0.5.0 |
| lax.sin |
Sin |
sin ✅
sin_f64 ✅ |
0.4.4 |
| lax.sinh |
Sinh |
sinh ✅
sinh_f64 ✅ |
0.4.4 |
| lax.slice |
Slice |
slice_test1 ✅
slice_test1_f64 ✅
slice_3d_none_strides ✅
slice_3d_none_strides_f64 ✅
slice_scan_axis_drop ✅
slice_scan_axis_drop_f64 ✅ |
0.1.0 |
| lax.sort |
GatherElements TopK |
sort_1d ✅
sort_1d_f64 ✅
sort_2d ✅
sort_2d_f64 ✅ |
0.2.0 |
| lax.split |
Split |
lax_split_equal_parts ✅
lax_split_equal_parts_f64 ✅
lax_split_unequal_parts ✅
lax_split_unequal_parts_f64 ✅ |
0.7.2 |
| lax.sqrt |
ReduceL2 Sqrt |
sqrt ✅
sqrt_f64 ✅
sqrt_reduce_sum_square_axis1 ✅
sqrt_reduce_sum_square_axis1_f64 ✅ |
0.2.0 |
| lax.square |
Mul |
square ✅
square_f64 ✅ |
0.2.0 |
| lax.squeeze |
Squeeze |
squeeze_single_axis ✅
squeeze_single_axis_f64 ✅
squeeze_all_unit_dims_default ✅
squeeze_all_unit_dims_default_f64 ✅
lax_squeeze_specific_axis_0 ✅
lax_squeeze_specific_axis_0_f64 ✅
lax_squeeze_multiple_axes ✅
lax_squeeze_multiple_axes_f64 ✅
lax_squeeze_no_op_empty_dims ✅
lax_squeeze_no_op_empty_dims_f64 ✅
lax_squeeze_problem_case_input_squeeze_only_axis_0 ✅
lax_squeeze_problem_case_input_squeeze_only_axis_0_f64 ✅
lax_squeeze_problem_case_input_squeeze_axes_0_2 ✅
lax_squeeze_problem_case_input_squeeze_axes_0_2_f64 ✅
lax_squeeze_problem_case_input_squeeze_all_dims_explicitly ✅
lax_squeeze_problem_case_input_squeeze_all_dims_explicitly_f64 ✅ |
0.2.0 |
| lax.stop_gradient |
Identity |
stop_gradient ✅
stop_gradient_f64 ✅
stop_gradient_basic ✅
stop_gradient_basic_f64 ✅ |
0.2.0 |
| lax.sub |
Sub |
sub_test1 ✅
sub_test1_f64 ✅
sub_test2 ✅
sub_test2_f64 ✅
sub_const ✅
sub_const_f64 ✅
sub_complex64 ✅
sub_complex64_f64 ✅ |
0.1.0 |
| lax.svd |
Abs Add Div Gather Identity Max Mul Reshape Sign Sub Where |
svd_1x1_default ✅
svd_1x1_default_f64 ✅
svd_1x1_values_only ✅
svd_1x1_values_only_f64 ✅
svd_2x2_values_only ✅
svd_2x2_values_only_f64 ✅
svd_3x2_values_only ✅
svd_3x2_values_only_f64 ✅
svd_2x4_values_only ✅
svd_2x4_values_only_f64 ✅ |
0.12.1 |
| lax.symmetric_product |
Add MatMul Mul Transpose |
symmetric_product_default ✅
symmetric_product_default_f64 ✅
symmetric_product_alpha_beta ✅
symmetric_product_alpha_beta_f64 ✅ |
0.12.1 |
| lax.tan |
Tan |
tan ✅ |
0.12.1 |
| lax.tanh |
Tanh |
tanh ✅
tanh_f64 ✅ |
0.2.0 |
| lax.top_k |
TopK |
top_k_last_axis ✅
top_k_last_axis_f64 ✅
top_k_matrix ✅
top_k_matrix_f64 ✅ |
0.10.2 |
| lax.transpose |
Transpose |
transpose_basic ✅
transpose_basic_f64 ✅
transpose_square_matrix ✅
transpose_square_matrix_f64 ✅
transpose_3d ✅
transpose_3d_f64 ✅
transpose_4d ✅
transpose_4d_f64 ✅
transpose_reverse ✅
transpose_reverse_f64 ✅
transpose_no_axes ✅
transpose_no_axes_f64 ✅
transpose_nhwc_to_nchw ✅
transpose_nhwc_to_nchw_f64 ✅ |
0.2.0 |
| lax.triangular_solve |
Div Gather |
triangular_solve_left_basic ✅
triangular_solve_left_basic_f64 ✅
triangular_solve_batched_unit_diag ✅ |
0.12.1 |
| lax.tridiagonal |
Gather Identity ScatterND |
tridiagonal_2x2_lower_true ✅
tridiagonal_2x2_lower_true_f64 ✅
tridiagonal_2x2_lower_false ✅
tridiagonal_2x2_lower_false_f64 ✅ |
0.12.1 |
| lax.tridiagonal_solve |
Concat Div Gather Mul Sub |
tridiagonal_solve_single_rhs ✅
tridiagonal_solve_single_rhs_f64 ✅
tridiagonal_solve_multi_rhs ✅
tridiagonal_solve_multi_rhs_f64 ✅ |
0.12.1 |
| lax.while_loop |
Loop |
while_scalar_counter ✅
while_scalar_counter_f64 ✅
while_tuple_state ✅
while_tuple_state_f64 ✅
while_loop_counter ✅
while_loop_counter_f64 ✅
while_loop_vector ✅
while_loop_vector_f64 ✅
while_loop_f64 ✅
while_loop_multi_state_f32 ✅
while_loop_multi_state_f64 ✅
while_loop_with_closure ✅
while_loop_with_closure_f64 ✅
while_loop_basic ✅
while_loop_two_state ✅
while_loop_captured_tracer ✅
while_loop_with_scalar_state ✅
while_loop_renamed_passthrough ✅
while_loop_closure_topo ✅
while_loop_mixed_rank ✅
while_loop_tracer_passthrough ✅
while_loop_no_loop_output_reused_as_input ✅
while_loop_4d_and_scalar_state ✅
while_loop_4d_and_scalar_state_f64 ✅
while_loop_cnn_scalar_state_bug ✅
while_loop_cnn_scalar_state_bug_f64 ✅
while_loop_nnx_repro ✅
while_loop_nnx_repro_f64 ✅ |
0.5.1 |
| lax.xor |
BitwiseXor Xor |
xor_bool_vec ✅
xor_bool_vec_f64 ✅
xor_int_vec ✅
xor_int_vec_f64 ✅ |
0.12.1 |
| lax.zeta |
Add Div Pow Where |
zeta_positive ✅
zeta_broadcast ✅ |
0.12.1 |
| linen.activation |
➖ |
activation_glu_basic ✅
activation_glu_basic_f64 ✅
activation_hard_sigmoid_basic ✅
activation_hard_silu_basic ✅
activation_hard_silu_basic_f64 ✅
activation_hard_swish_basic ✅
activation_hard_tanh_basic ✅
activation_hard_tanh_basic_f64 ✅
activation_log_sigmoid_basic ✅
activation_log_sigmoid_basic_f64 ✅
activation_log_softmax_basic ✅
activation_log_softmax_basic_f64 ✅
activation_relu6_basic ✅
activation_relu6_basic_f64 ✅
activation_silu_basic ✅
activation_silu_basic_f64 ✅
activation_swish_basic ✅
activation_swish_basic_f64 ✅
activation_tanh_basic ✅
activation_tanh_basic_f64 ✅
activation_normalize_basic ✅
activation_normalize_basic_f64 ✅
activation_one_hot_basic ✅ |
0.11.0 |
| linen.avg_pool |
AveragePool GlobalAveragePool Transpose |
avg_pool_dynamic ✅
avg_pool ✅
avg_pool_same_padding_dynamic ✅
avg_pool_same_padding ✅
avg_pool_default_padding_dynamic ✅
avg_pool_default_padding ✅
avg_pool_stride1_dynamic ✅
avg_pool_stride1 ✅
avg_pool_win3x3_stride2_dynamic ✅
avg_pool_win3x3_stride2 ✅
avg_pool_stride_none_dynamic ✅
avg_pool_stride_none ✅
avg_pool_count_include_pad_false_dynamic ✅
avg_pool_count_include_pad_false ✅
avg_pool_global_window_dynamic ✅
avg_pool_global_window ✅ |
0.11.0 |
| linen.batch_norm |
BatchNormalization |
batch_norm_no_bias_no_scale_dynamic ✅
batch_norm_no_bias_no_scale ✅
batch_norm_bias_no_scale_dynamic ✅
batch_norm_bias_no_scale ✅
batch_norm_no_bias_scale_dynamic ✅
batch_norm_no_bias_scale ✅
batch_norm_bias_scale_dynamic ✅
batch_norm_bias_scale ✅
batch_norm_3d_dynamic ✅
batch_norm_3d ✅
batch_norm_4d_dynamic ✅
batch_norm_4d ✅
batch_norm_4d_no_bias_no_scale_dynamic ✅
batch_norm_4d_no_bias_no_scale ✅ |
0.11.0 |
| linen.bidirectional |
Concat Loop |
bidirectional_basic_dynamic ✅
bidirectional_basic_dynamic_f64 ✅
bidirectional_basic ✅
bidirectional_basic_f64 ✅ |
0.11.0 |
| linen.conv |
CastLike Conv Reshape Transpose |
conv_basic_dynamic ✅
conv_basic ✅
conv_no_bias ✅
conv_stride ✅ |
0.11.0 |
| linen.conv_local |
Conv Gemm Reshape Transpose |
conv_local_valid ✅
conv_local_same ✅ |
0.11.0 |
| linen.conv_lstm_cell |
Add Conv Gemm Mul Sigmoid Tanh |
conv_lstm_cell_basic_dynamic ✅
conv_lstm_cell_basic ✅ |
0.11.0 |
| linen.conv_transpose |
ConvTranspose |
conv_transpose_basic_dynamic ✅
conv_transpose_basic ✅
conv_transpose_valid_stride ✅ |
0.11.0 |
| linen.dense |
Gemm |
dense_basic_dynamic ✅
dense_basic_dynamic_f64 ✅
dense_basic ✅
dense_basic_f64 ✅
dense_high_rank_dynamic_dynamic ✅
dense_high_rank_dynamic_dynamic_f64 ✅
dense_high_rank_static ✅
dense_high_rank_static_f64 ✅
dense_high_rank_no_bias ✅
dense_high_rank_no_bias_f64 ✅
dense_no_bias_dynamic ✅
dense_no_bias_dynamic_f64 ✅
dense_no_bias ✅
dense_no_bias_f64 ✅ |
0.11.0 |
| linen.dense_general |
CastLike Concat Gemm Reshape Shape Slice |
dense_general_basic_dynamic ✅
dense_general_basic_dynamic_f64 ✅
dense_general_basic ✅
dense_general_basic_f64 ✅
dense_general_multi_out ✅
dense_general_multi_out_f64 ✅
dense_general_contract_last_two ✅
dense_general_contract_last_two_f64 ✅
dense_general_dynamic_batch_dynamic ✅
dense_general_dynamic_batch_dynamic_f64 ✅
dense_general_no_bias ✅
dense_general_no_bias_f64 ✅ |
0.11.0 |
| linen.dot_product_attention |
MatMul Mul Softmax Transpose Where |
dot_product_attention_basic ✅
dot_product_attention_basic_f64 ✅ |
0.11.0 |
| linen.dot_product_attention_weights |
Add Div MatMul Mul ReduceSum Softmax Where |
dot_product_attention_weights_basic ✅
dot_product_attention_weights_basic_f64 ✅ |
0.11.0 |
| linen.dropout |
Dropout |
dropout_init_params_dynamic ✅
dropout_init_params_dynamic_f64 ✅
dropout_init_params ✅
dropout_init_params_f64 ✅
dropout_call_params_dynamic ✅
dropout_call_params_dynamic_f64 ✅
dropout_call_params ✅
dropout_call_params_f64 ✅ |
0.11.0 |
| linen.einsum |
Add Einsum Reshape |
einsum_with_bias ✅
einsum_with_bias_f64 ✅
einsum_no_bias ✅
einsum_no_bias_f64 ✅ |
0.11.0 |
| linen.embed |
Gather |
token_embedding_dynamic ✅
token_embedding_dynamic_f64 ✅
token_embedding ✅
token_embedding_f64 ✅
positional_embedding_dynamic ✅
positional_embedding_dynamic_f64 ✅
positional_embedding ✅
positional_embedding_f64 ✅ |
0.11.0 |
| linen.group_norm |
GroupNormalization |
group_norm_rank4 ✅
group_norm_rank2_dynamic ✅
group_norm_rank2 ✅
group_norm_no_bias_no_scale_dynamic ✅
group_norm_no_bias_no_scale ✅ |
0.11.0 |
| linen.gru_cell |
Add Gemm Mul Sigmoid Tanh |
gru_cell_basic_dynamic ✅
gru_cell_basic_dynamic_f64 ✅
gru_cell_basic ✅
gru_cell_basic_f64 ✅ |
0.11.0 |
| linen.instance_norm |
GroupNormalization InstanceNormalization |
instance_norm_rank4_dynamic ✅
instance_norm_rank4 ✅
instance_norm_rank2_dynamic ✅
instance_norm_rank2 ✅ |
0.11.0 |
| linen.layer_norm |
LayerNormalization |
layer_norm_dynamic ✅
layer_norm ✅
layer_norm_no_bias_no_scale_dynamic ✅
layer_norm_no_bias_no_scale ✅
layer_norm_multiaxis_dynamic ✅
layer_norm_multiaxis ✅
layer_norm_default_epsilon_dynamic ✅
layer_norm_default_epsilon ✅ |
0.11.0 |
| linen.lstm_cell |
Add Gemm Mul Sigmoid Tanh |
lstm_cell_basic_dynamic ✅
lstm_cell_basic_dynamic_f64 ✅
lstm_cell_basic ✅
lstm_cell_basic_f64 ✅ |
0.11.0 |
| linen.make_attention_mask |
Cast Mul |
make_attention_mask_basic ✅
make_attention_mask_basic_f64 ✅ |
0.11.0 |
| linen.make_causal_mask |
Cast Less |
make_causal_mask_basic ✅
make_causal_mask_basic_f64 ✅ |
0.11.0 |
| linen.max_pool |
GlobalMaxPool MaxPool |
max_pool ✅
max_pool_same_padding ✅
max_pool_basic ✅
max_pool_same_dynamic ✅
max_pool_same ✅
max_pool_global_window_dynamic ✅
max_pool_global_window ✅ |
0.11.0 |
| linen.mgu_cell |
Add Gemm Mul Sigmoid Tanh |
mgu_cell_basic_dynamic ✅
mgu_cell_basic_dynamic_f64 ✅
mgu_cell_basic ✅
mgu_cell_basic_f64 ✅ |
0.11.0 |
| linen.min_pool |
MaxPool Neg |
min_pool ✅
min_pool_same_padding ✅
min_pool_basic ✅
min_pool_same_dynamic ✅
min_pool_same ✅ |
0.11.0 |
| linen.multi_head_attention |
Add Gemm MatMul Mul Reshape Softmax Transpose |
multi_head_attention_basic_dynamic ✅
multi_head_attention_basic_dynamic_f64 ✅
multi_head_attention_basic ✅
multi_head_attention_basic_f64 ✅
multi_head_attention_no_bias_dynamic ✅
multi_head_attention_no_bias_dynamic_f64 ✅
multi_head_attention_no_bias ✅
multi_head_attention_no_bias_f64 ✅ |
0.11.0 |
| linen.multi_head_dot_product_attention |
Add Gemm MatMul Mul Reshape Softmax Transpose |
multi_head_dot_product_attention_basic_dynamic ✅
multi_head_dot_product_attention_basic_dynamic_f64 ✅
multi_head_dot_product_attention_basic ✅
multi_head_dot_product_attention_basic_f64 ✅
multi_head_dot_product_attention_no_bias_dynamic ✅
multi_head_dot_product_attention_no_bias_dynamic_f64 ✅
multi_head_dot_product_attention_no_bias ✅
multi_head_dot_product_attention_no_bias_f64 ✅ |
0.11.0 |
| linen.optimized_lstm_cell |
Add Gemm Mul Sigmoid Tanh |
optimized_lstm_cell_basic_dynamic ✅
optimized_lstm_cell_basic_dynamic_f64 ✅
optimized_lstm_cell_basic ✅
optimized_lstm_cell_basic_f64 ✅ |
0.11.0 |
| linen.pool |
AveragePool MaxPool Mul Neg Transpose |
pool_max_basic_dynamic ✅
pool_max_basic ✅
pool_min_basic_dynamic ✅
pool_min_basic ✅
pool_sum_basic_dynamic ✅
pool_sum_basic ✅ |
0.11.0 |
| linen.prelu |
PRelu |
linen_prelu_default_dynamic ✅
linen_prelu_default ✅
linen_prelu_custom_slope ✅ |
0.12.1 |
| linen.rms_norm |
RMSNormalization |
rms_norm_basic ✅
rms_norm_use_scale_false ✅
rms_norm_4d_dynamic_dynamic ✅
rms_norm_4d_dynamic ✅ |
0.11.0 |
| linen.rnn |
Loop |
rnn_basic_dynamic ✅
rnn_basic_dynamic_f64 ✅
rnn_basic ✅
rnn_basic_f64 ✅ |
0.11.0 |
| linen.self_attention |
Add Gemm MatMul Mul Reshape Softmax Transpose |
self_attention_basic_dynamic ✅
self_attention_basic_dynamic_f64 ✅
self_attention_basic ✅
self_attention_basic_f64 ✅
self_attention_no_bias_dynamic ✅
self_attention_no_bias_dynamic_f64 ✅
self_attention_no_bias ✅
self_attention_no_bias_f64 ✅ |
0.11.0 |
| linen.simple_cell |
Add Gemm Mul Sigmoid Tanh |
simple_cell_basic_dynamic ✅
simple_cell_basic_dynamic_f64 ✅
simple_cell_basic ✅
simple_cell_basic_f64 ✅ |
0.11.0 |
| linen.spectral_norm |
Div MatMul |
spectral_norm_dense_dynamic ✅
spectral_norm_dense ✅ |
0.11.0 |
| linen.weight_norm |
Div Mul |
weight_norm_dense_dynamic ✅
weight_norm_dense ✅ |
0.11.0 |
| nn.celu |
Celu |
jaxnn_celu ✅
jaxnn_celu_1 ✅
jaxnn_celu_alpha_default ✅
jaxnn_celu_alpha_custom_dynamic ✅
jaxnn_celu_alpha_custom ✅
celu_grad_issue_batch_diff_rules ✅ |
0.7.1 |
| nn.dot_product_attention |
Add MatMul Mul Softmax Transpose Where |
dpa_basic ✅
dpa_basic_f64 ✅
dpa_positional_bias_mask ✅
dpa_positional_bias_mask_f64 ✅
dpa_diff_heads_embed ✅
dpa_diff_heads_embed_f64 ✅
dpa_batch4_seq16 ✅
dpa_batch4_seq16_f64 ✅
dpa_float64 ✅
dpa_heads1_embed4 ✅
dpa_heads1_embed4_f64 ✅
dpa_heads8_embed8 ✅
dpa_heads8_embed8_f64 ✅
dpa_batch1_seq2 ✅
dpa_batch1_seq2_f64 ✅
dpa_batch8_seq4 ✅
dpa_batch8_seq4_f64 ✅
dpa_axis1 ✅
dpa_axis1_f64 ✅
dpa_with_tensor_mask ✅
dpa_with_tensor_mask_f64 ✅
dpa_tiny_mask_all_valid ✅
dpa_tiny_mask_all_valid_f64 ✅
dpa_tiny_mask_mixed ✅
dpa_tiny_mask_mixed_f64 ✅
dpa_one_false ✅
dpa_one_false_f64 ✅
dpa_mostly_false ✅
dpa_mostly_false_f64 ✅
dpa_with_causal_mask ✅
dpa_with_causal_mask_f64 ✅
dpa_with_padding_mask ✅
dpa_with_padding_mask_f64 ✅
dpa_with_local_window_mask ✅
dpa_with_local_window_mask_f64 ✅
dpa_vmap_tnh_issue190 ✅
dpa_mask_none ✅ |
0.8.0 |
| nn.elu |
Elu |
jaxnn_elu ✅
jaxnn_elu_1 ✅
jaxnn_elu_default ✅
jaxnn_elu_custom_alpha_dynamic ✅
jaxnn_elu_custom_alpha ✅
elu_grad_issue_batch_diff_rules ✅ |
0.7.1 |
| nn.gelu |
Gelu |
jaxnn_gelu ✅
jaxnn_gelu_1 ✅
jaxnn_gelu_approx ✅
jaxnn_gelu_exact ✅
jaxnn_gelu_tanh_dynamic ✅
jaxnn_gelu_tanh ✅
gelu_grad_issue_batch_diff_rules ✅ |
0.7.1 |
| nn.glu |
Mul Sigmoid Split |
jaxnn_glu_axis_last ✅
jaxnn_glu_axis_last_f64 ✅
jaxnn_glu_axis1_dynamic_dynamic ✅
jaxnn_glu_axis1_dynamic_dynamic_f64 ✅
jaxnn_glu_axis1_dynamic ✅
jaxnn_glu_axis1_dynamic_f64 ✅ |
0.12.1 |
| nn.hard_sigmoid |
HardSigmoid |
jaxnn_hard_sigmoid ✅
jaxnn_hard_sigmoid_dynamic_dynamic ✅
jaxnn_hard_sigmoid_dynamic ✅ |
0.12.1 |
| nn.hard_swish |
HardSwish |
jaxnn_hard_swish ✅
jaxnn_hard_swish_dynamic_dynamic ✅
jaxnn_hard_swish_dynamic ✅ |
0.12.1 |
| nn.hard_tanh |
Max Min |
jaxnn_hard_tanh ✅
jaxnn_hard_tanh_f64 ✅
jaxnn_hard_tanh_dynamic_dynamic ✅
jaxnn_hard_tanh_dynamic_dynamic_f64 ✅
jaxnn_hard_tanh_dynamic ✅
jaxnn_hard_tanh_dynamic_f64 ✅ |
0.12.1 |
| nn.hardmax |
Hardmax |
jaxnn_hardmax_default_axis ✅
jaxnn_hardmax_axis0 ✅
jaxnn_hardmax_dynamic_dynamic ✅
jaxnn_hardmax_dynamic ✅
jaxnn_hardmax_vmap_batching ✅ |
0.12.1 |
| nn.identity |
Identity |
jaxnn_identity ✅
jaxnn_identity_f64 ✅
jaxnn_identity_1 ✅
jaxnn_identity_1_f64 ✅
jaxnn_identity_basic ✅
jaxnn_identity_basic_f64 ✅
jaxnn_identity_dynamic_dynamic ✅
jaxnn_identity_dynamic_dynamic_f64 ✅
jaxnn_identity_dynamic ✅
jaxnn_identity_dynamic_f64 ✅ |
0.7.1 |
| nn.leaky_relu |
LeakyRelu |
jaxnn_leaky_relu ✅
jaxnn_leaky_relu_1 ✅
jaxnn_leaky_relu_default_dynamic ✅
jaxnn_leaky_relu_default ✅
jaxnn_leaky_relu_custom ✅
leaky_relu_grad_issue_batch_diff_rules ✅ |
0.7.1 |
| nn.log1mexp |
Exp Less Log Neg Sub Where |
jaxnn_log1mexp_basic ✅
jaxnn_log1mexp_basic_f64 ✅
jaxnn_log1mexp_dynamic_dynamic ✅
jaxnn_log1mexp_dynamic_dynamic_f64 ✅
jaxnn_log1mexp_dynamic ✅
jaxnn_log1mexp_dynamic_f64 ✅ |
0.12.1 |
| nn.log_sigmoid |
Log Sigmoid |
jaxnn_log_sigmoid ✅
jaxnn_log_sigmoid_f64 ✅
jaxnn_log_sigmoid_dynamic_dynamic ✅
jaxnn_log_sigmoid_dynamic_dynamic_f64 ✅
jaxnn_log_sigmoid_dynamic ✅
jaxnn_log_sigmoid_dynamic_f64 ✅ |
0.12.1 |
| nn.log_softmax |
LogSoftmax |
jaxnn_log_softmax_default_axis_dynamic ✅
jaxnn_log_softmax_default_axis_dynamic_f64 ✅
jaxnn_log_softmax_default_axis ✅
jaxnn_log_softmax_default_axis_f64 ✅
jaxnn_log_softmax_axis0 ✅
jaxnn_log_softmax_axis0_f64 ✅
jaxnn_log_softmax_axis_last_3d ✅
jaxnn_log_softmax_axis_last_3d_f64 ✅
log_softmax_grad_issue_batch_diff_rules ✅ |
0.12.1 |
| nn.logmeanexp |
ReduceLogSumExp Sub |
jaxnn_logmeanexp_axis1 ✅
jaxnn_logmeanexp_axis1_f64 ✅
jaxnn_logmeanexp_axis12_keepdims ✅
jaxnn_logmeanexp_axis12_keepdims_f64 ✅
jaxnn_logmeanexp_dynamic_static_reduction_axis_dynamic ✅
jaxnn_logmeanexp_dynamic_static_reduction_axis_dynamic_f64 ✅
jaxnn_logmeanexp_dynamic_static_reduction_axis ✅
jaxnn_logmeanexp_dynamic_static_reduction_axis_f64 ✅ |
0.12.1 |
| nn.logsumexp |
ReduceLogSumExp |
jaxnn_logsumexp_axis1 ✅
jaxnn_logsumexp_axis1_f64 ✅
jaxnn_logsumexp_axis12_keepdims ✅
jaxnn_logsumexp_axis12_keepdims_f64 ✅
jaxscipy_logsumexp_axis_last_dynamic ✅
jaxscipy_logsumexp_axis_last_dynamic_f64 ✅
jaxscipy_logsumexp_axis_last ✅
jaxscipy_logsumexp_axis_last_f64 ✅
logsumexp_grad_issue_batch_diff_rules ✅ |
0.12.1 |
| nn.mish |
Mish |
jaxnn_mish ✅
jaxnn_mish_1 ✅
jaxnn_mish_basic ✅
mish_grad_issue_batch_diff_rules ✅ |
0.7.1 |
| nn.one_hot |
OneHot |
jaxnn_one_hot_default ✅
jaxnn_one_hot_axis0 ✅ |
0.12.1 |
| nn.relu |
Relu |
jaxnn_relu ✅
jaxnn_relu_f64 ✅
jaxnn_relu_1 ✅
jaxnn_relu_1_f64 ✅
jaxnn_relu_basic ✅
jaxnn_relu_basic_f64 ✅
jaxnn_relu_dynamic_dynamic ✅
jaxnn_relu_dynamic_dynamic_f64 ✅
jaxnn_relu_dynamic ✅
jaxnn_relu_dynamic_f64 ✅
relu_grad_issue_batch_diff_rules ✅ |
0.7.1 |
| nn.relu6 |
Max Min |
jaxnn_relu6 ✅
jaxnn_relu6_f64 ✅
jaxnn_relu6_dynamic_dynamic ✅
jaxnn_relu6_dynamic_dynamic_f64 ✅
jaxnn_relu6_dynamic ✅
jaxnn_relu6_dynamic_f64 ✅ |
0.12.1 |
| nn.scaled_dot_general |
Einsum Gemm MatMul |
jaxnn_scaled_dot_general_basic ✅
jaxnn_scaled_dot_general_batched ✅
jaxnn_scaled_dot_general_batched_f64 ✅ |
0.12.1 |
| nn.scaled_matmul |
MatMul Mul Reshape Tile Transpose Unsqueeze |
jaxnn_scaled_matmul_basic ✅ |
0.12.1 |
| nn.selu |
Selu |
jaxnn_selu ✅
jaxnn_selu_1 ✅
jaxnn_selu_basic_dynamic ✅
jaxnn_selu_basic ✅
selu_grad_issue_batch_diff_rules ✅ |
0.7.1 |
| nn.sigmoid |
Sigmoid |
jaxnn_sigmoid ✅
jaxnn_sigmoid_f64 ✅
jaxnn_sigmoid_1 ✅
jaxnn_sigmoid_1_f64 ✅
sigmoid_grad_issue_batch_diff_rules ✅ |
0.7.1 |
| nn.silu |
Mul Sigmoid Swish |
jaxnn_silu ✅
jaxnn_silu_f64 ✅
jaxnn_silu_dynamic_dynamic ✅
jaxnn_silu_dynamic_dynamic_f64 ✅
jaxnn_silu_dynamic ✅
jaxnn_silu_dynamic_f64 ✅
jaxnn_swish_alias ✅
jaxnn_swish_alias_f64 ✅
silu_grad_issue_batch_diff_rules ✅ |
0.12.1 |
| nn.soft_sign |
Softsign |
jaxnn_soft_sign ✅
jaxnn_soft_sign_1 ✅
jaxnn_softsign_basic ✅
softsign_grad_issue_batch_diff_rules ✅ |
0.7.1 |
| nn.softmax |
Softmax |
softmax ✅
softmax_f64 ✅
softmax_2d ✅
softmax_2d_f64 ✅
softmax_3d ✅
softmax_3d_f64 ✅
softmax_mask_where ✅
softmax_mask_where_f64 ✅
softmax_grad_issue_batch_diff_rules ✅ |
0.7.1 |
| nn.softplus |
Softplus |
jaxnn_softplus ✅
jaxnn_softplus_1 ✅
jaxnn_softplus_basic ✅
softplus_grad_issue_batch_diff_rules ✅ |
0.7.1 |
| nn.sparse_plus |
Add Greater Less Mul Where |
jaxnn_sparse_plus ✅
jaxnn_sparse_plus_f64 ✅
jaxnn_sparse_plus_dynamic_dynamic ✅
jaxnn_sparse_plus_dynamic_dynamic_f64 ✅
jaxnn_sparse_plus_dynamic ✅
jaxnn_sparse_plus_dynamic_f64 ✅ |
0.12.1 |
| nn.sparse_sigmoid |
Add Greater Less Mul Where |
jaxnn_sparse_sigmoid ✅
jaxnn_sparse_sigmoid_f64 ✅
jaxnn_sparse_sigmoid_dynamic_dynamic ✅
jaxnn_sparse_sigmoid_dynamic_dynamic_f64 ✅
jaxnn_sparse_sigmoid_dynamic ✅
jaxnn_sparse_sigmoid_dynamic_f64 ✅ |
0.12.1 |
| nn.squareplus |
Add Mul Sqrt |
jaxnn_squareplus_default_b ✅
jaxnn_squareplus_default_b_f64 ✅
jaxnn_squareplus_broadcast_b ✅
jaxnn_squareplus_broadcast_b_f64 ✅ |
0.12.1 |
| nn.standardize |
MeanVarianceNormalization |
jaxnn_standardize_axis_last_eps0 ✅
jaxnn_standardize_axis_tuple_eps0_dynamic ✅
jaxnn_standardize_axis_tuple_eps0 ✅
standardize_grad_issue_batch_diff_rules ✅ |
0.12.1 |
| nn.tanh |
Tanh |
jaxnn_tanh ✅
jaxnn_tanh_f64 ✅
jaxnn_tanh_dynamic_dynamic ✅
jaxnn_tanh_dynamic_dynamic_f64 ✅
jaxnn_tanh_dynamic ✅
jaxnn_tanh_dynamic_f64 ✅ |
0.12.1 |
| nn.thresholded_relu |
ThresholdedRelu |
jaxnn_thresholded_relu_default ✅
jaxnn_thresholded_relu_custom_dynamic ✅
jaxnn_thresholded_relu_custom ✅ |
0.12.1 |
| nn.truncated_normal |
➖ |
initializer ✅
random_truncated_normal_positional ✅
flax_dense_like_init ✅ |
0.7.1 |
| nnx.avg_pool |
AveragePool GlobalAveragePool Transpose |
avg_pool_dynamic ✅
avg_pool ✅
avg_pool_same_padding_dynamic ✅
avg_pool_same_padding ✅
avg_pool_default_padding_dynamic ✅
avg_pool_default_padding ✅
avg_pool_stride1_dynamic ✅
avg_pool_stride1 ✅
avg_pool_win3x3_stride2_dynamic ✅
avg_pool_win3x3_stride2 ✅
avg_pool_stride_none_dynamic ✅
avg_pool_stride_none ✅
avg_pool_count_include_pad_false_dynamic ✅
avg_pool_count_include_pad_false ✅
avg_pool_global_window_dynamic ✅
avg_pool_global_window ✅ |
0.1.0 |
| nnx.batch_norm |
BatchNormalization |
batch_norm_no_bias_no_scale_dynamic ✅
batch_norm_no_bias_no_scale ✅
batch_norm_bias_no_scale_dynamic ✅
batch_norm_bias_no_scale ✅
batch_norm_no_bias_scale_dynamic ✅
batch_norm_no_bias_scale ✅
batch_norm_bias_scale_dynamic ✅
batch_norm_bias_scale ✅
batch_norm_3d_dynamic ✅
batch_norm_3d ✅
batch_norm_4d_dynamic ✅
batch_norm_4d ✅
batch_norm_4d_no_bias_no_scale_dynamic ✅
batch_norm_4d_no_bias_no_scale ✅ |
0.1.0 |
| nnx.combine_masks |
And Cast |
combine_masks_two ✅
combine_masks_with_none ✅ |
0.12.2 |
| nnx.conv |
CastLike Conv Reshape Transpose |
conv_basic_bias_dynamic ✅
conv_basic_bias ✅
conv_basic_bias_2 ✅
conv_basic_bias_3 ✅
conv_stride2_bias ✅
conv_no_bias_dynamic ✅
conv_no_bias ✅
conv_valid_padding ✅
conv_stride1 ✅
conv_stride2 ✅
conv_2d_reflect_padding ✅
conv_2d_circular_padding ✅
conv_different_kernel ✅
conv_float64 ✅
conv_single_batch ✅
conv_large_batch ✅
conv_1d_causal_padding ✅
conv_1d ✅
conv_1d_more_1d_inputs ✅
conv_1d_more_2d_inputs ✅
conv_1d_large_kernel ✅
conv_1d_dilation ✅
conv_1d_stride_dilation ✅
conv_2d_asymmetric_kernel ✅
conv_2d_asymmetric_stride ✅
conv_2d_asymmetric_dilation ✅
conv_2d_large_dilation ✅
conv_2d_large_stride ✅
conv_2d_mixed_params ✅
conv_2d_same_padding_mixed_dilation ✅
conv_3d_basic ✅
conv_3d_stride ✅
conv_3d_asymmetric ✅
conv_3d_dilation ✅
conv_2d_small_input ✅
conv_2d_many_channels ✅
conv_1d_wide_input ✅
conv_2d_kernel_1x1 ✅
conv_1d_kernel_1 ✅
conv_2d_group_conv ✅
conv_1d_group_conv_more_dims ✅
conv_2d_depthwise ✅
conv_1d_complex_on_4d ✅
conv_2d_complex_on_5d ✅
conv_2d_asymmetric_on_5d ✅
conv_1d_high_dilation_on_3d ✅
conv_1d_large_kernel_on_4d ✅
conv_2d_group_stride_dilation ✅
conv_1d_group_on_higher_dim ✅
conv_1d_same_padding_on_3d ✅
conv_3d_group_complex ✅
conv_1d_unit_group_on_multi_dim ✅ |
0.1.0 |
| nnx.dot_product_attention |
Add Attention MatMul Mul Softmax Transpose Where |
dpa_basic ✅
dpa_basic_f64 ✅
dpa_with_tensor_mask ✅
dpa_with_bias ✅
dpa_with_causal_mask ✅
dpa_with_causal_mask_f64 ✅
dpa_with_mask_and_bias ✅
dpa_gqa_basic ✅
dpa_gqa_basic_opset23 ✅ |
0.1.0 |
| nnx.dropout |
Constant Dropout |
dropout_init_params_dynamic ✅
dropout_init_params_dynamic_f64 ✅
dropout_init_params ✅
dropout_init_params_f64 ✅
dropout_call_params_dynamic ✅
dropout_call_params_dynamic_f64 ✅
dropout_call_params ✅
dropout_call_params_f64 ✅ |
0.1.0 |
| nnx.einsum |
Add Einsum |
einsum_module_with_bias ✅
einsum_module_with_bias_f64 ✅
einsum_module_no_bias ✅
einsum_module_no_bias_f64 ✅ |
0.4.2 |
| nnx.elu |
Elu |
elu ✅
elu_default_dynamic ✅
elu_default ✅
elu_alpha ✅ |
0.2.0 |
| nnx.embed |
Gather |
token_embedding_dynamic ✅
token_embedding_dynamic_f64 ✅
token_embedding ✅
token_embedding_f64 ✅
positional_embedding_dynamic ✅
positional_embedding_dynamic_f64 ✅
positional_embedding ✅
positional_embedding_f64 ✅ |
0.7.0 |
| nnx.flip_sequences |
GatherElements Slice |
flip_sequences_batch_major_with_lengths ✅
flip_sequences_batch_major_with_lengths_f64 ✅
flip_sequences_batch_major_no_lengths_dynamic ✅
flip_sequences_batch_major_no_lengths_dynamic_f64 ✅
flip_sequences_batch_major_no_lengths ✅
flip_sequences_batch_major_no_lengths_f64 ✅
flip_sequences_time_major_with_lengths ✅
flip_sequences_time_major_with_lengths_f64 ✅ |
0.12.2 |
| nnx.gelu |
Gelu |
gelu ✅
gelu_1 ✅
gelu_2 ✅
gelu_2_f64 ✅
gelu_3_dynamic ✅
gelu_3_dynamic_f64 ✅
gelu_3 ✅
gelu_3_f64 ✅
gelu_4 ✅
gelu_4_f64 ✅
gelu_5_dynamic ✅
gelu_5_dynamic_f64 ✅
gelu_5 ✅
gelu_5_f64 ✅ |
0.1.0 |
| nnx.glu |
Mul Sigmoid Split |
glu_last_axis ✅
glu_last_axis_f64 ✅
glu_axis_1_dynamic ✅
glu_axis_1_dynamic_f64 ✅
glu_axis_1 ✅
glu_axis_1_f64 ✅ |
0.12.2 |
| nnx.group_norm |
GroupNormalization |
group_norm ✅
group_norm_rank2_dynamic ✅
group_norm_rank2 ✅
group_norm_rank4 ✅
group_norm_no_bias_dynamic ✅
group_norm_no_bias ✅
group_norm_no_bias_no_scale_dynamic ✅
group_norm_no_bias_no_scale ✅
group_norm_bias_no_scale_dynamic ✅
group_norm_bias_no_scale ✅
group_norm_no_scale_dynamic ✅
group_norm_no_scale ✅
group_norm_no_bias_scale_dynamic ✅
group_norm_no_bias_scale ✅
group_norm_bias_scale_dynamic ✅
group_norm_bias_scale ✅ |
0.2.0 |
| nnx.hard_tanh |
Clip |
hard_tanh_basic_dynamic ✅
hard_tanh_basic_dynamic_f64 ✅
hard_tanh_basic ✅
hard_tanh_basic_f64 ✅ |
0.12.2 |
| nnx.layer_norm |
LayerNormalization |
layer_norm_dynamic ✅
layer_norm ✅
layer_norm_no_bias_no_scale_dynamic ✅
layer_norm_no_bias_no_scale ✅
layer_norm_bias_no_scale_dynamic ✅
layer_norm_bias_no_scale ✅
layer_norm_no_bias_scale_dynamic ✅
layer_norm_no_bias_scale ✅
layer_norm_bias_scale_dynamic ✅
layer_norm_bias_scale ✅
layer_norm_multiaxis_dynamic ✅
layer_norm_multiaxis ✅
layer_norm_symbolic_batch_dynamic ✅
layer_norm_symbolic_batch ✅
layer_norm_symbolic_batch_seq10_feat3_dynamic ✅
layer_norm_symbolic_batch_seq10_feat3 ✅
layer_norm_symbolic_batch_seq10_feat3_2_dynamic ✅
layer_norm_symbolic_batch_seq10_feat3_2 ✅
layer_norm_negative_axis_no_div_dynamic ✅
layer_norm_negative_axis_no_div ✅ |
0.1.0 |
| nnx.leaky_relu |
LeakyRelu |
leaky_relu ✅
leaky_relu_default_dynamic ✅
leaky_relu_default ✅
leaky_relu_custom ✅ |
0.2.0 |
| nnx.linear |
CastLike Concat Gemm Reshape Shape Slice |
linear_symbolic_batch_dynamic ✅
linear_symbolic_batch_dynamic_f64 ✅
linear_symbolic_batch ✅
linear_symbolic_batch_f64 ✅
linear_high_rank_dynamic ✅
linear_high_rank_dynamic_f64 ✅
linear_high_rank_static ✅
linear_high_rank_static_f64 ✅
linear_no_bias_dynamic ✅
linear_no_bias_dynamic_f64 ✅
linear_no_bias ✅
linear_no_bias_f64 ✅
linear_high_rank_no_bias_dynamic ✅
linear_high_rank_no_bias_dynamic_f64 ✅
linear_high_rank_no_bias ✅
linear_high_rank_no_bias_f64 ✅
linear_merge_symbolic_dim_dynamic ✅ |
0.1.0 |
| nnx.linear_general |
CastLike Concat Gemm Reshape Shape Slice |
linear_general_merge_symbolic_dim_dynamic ✅
linear_general_dynamic ✅
linear_general ✅
linear_general_2 ✅
linear_general_3 ✅
linear_general_4 ✅
linear_general_abstract_eval_axes ✅
linear_general_abstract_eval_axes_pair ✅
dynamic_batch_and_feature_dims_dynamic ✅ |
0.1.0 |
| nnx.log_sigmoid |
Neg Softplus |
log_sigmoid_basic_dynamic ✅
log_sigmoid_basic_dynamic_f64 ✅
log_sigmoid_basic ✅
log_sigmoid_basic_f64 ✅ |
0.12.2 |
| nnx.log_softmax |
LogSoftmax |
log_softmax ✅
log_softmax_f64 ✅
log_softmax_default_axis_dynamic ✅
log_softmax_default_axis_dynamic_f64 ✅
log_softmax_default_axis ✅
log_softmax_default_axis_f64 ✅
log_softmax_axis0 ✅
log_softmax_axis0_f64 ✅ |
0.2.0 |
| nnx.lora |
Add MatMul |
lora_basic_dynamic ✅
lora_basic_dynamic_f64 ✅
lora_basic ✅
lora_basic_f64 ✅
lora_static ✅
lora_static_f64 ✅ |
0.12.2 |
| nnx.lora_linear |
Add MatMul |
lora_linear_basic_dynamic ✅
lora_linear_basic_dynamic_f64 ✅
lora_linear_basic ✅
lora_linear_basic_f64 ✅
lora_linear_high_rank_dynamic ✅
lora_linear_high_rank_dynamic_f64 ✅
lora_linear_high_rank ✅
lora_linear_high_rank_f64 ✅ |
0.12.2 |
| nnx.max_pool |
GlobalMaxPool MaxPool |
max_pool ✅
max_pool_same_padding ✅
max_pool_basic ✅
max_pool_same_dynamic ✅
max_pool_same ✅
max_pool_global_window_dynamic ✅
max_pool_global_window ✅ |
0.2.0 |
| nnx.prelu |
PRelu |
prelu_default ✅
prelu_custom_slope_dynamic ✅
prelu_custom_slope ✅ |
0.12.1 |
| nnx.relu |
Relu |
relu_1d ✅
relu_1d_f64 ✅
relu_4d_dynamic ✅
relu_4d_dynamic_f64 ✅
relu_4d ✅
relu_4d_f64 ✅ |
0.2.0 |
| nnx.rms_norm |
RMSNormalization |
rms_norm_basic ✅
rms_norm_use_scale_false ✅
rms_norm_4d_dynamic_dynamic ✅
rms_norm_4d_dynamic ✅
rms_norm_4d_dynamic_no_scale_dynamic ✅
rms_norm_4d_dynamic_no_scale ✅ |
0.2.0 |
| nnx.sigmoid |
Sigmoid |
sigmoid_dynamic ✅
sigmoid_dynamic_f64 ✅
sigmoid ✅
sigmoid_f64 ✅ |
0.2.0 |
| nnx.softmax |
Softmax |
softmax_dynamic ✅
softmax_dynamic_f64 ✅
softmax ✅
softmax_f64 ✅ |
0.1.0 |
| nnx.softplus |
Softplus |
softplus ✅ |
0.1.0 |
| nnx.tanh |
Tanh |
tanh ✅
tanh_f64 ✅ |
0.1.0 |
| random.bernoulli |
Bernoulli |
bernoulli_scalar_prob_shape ✅
bernoulli_tensor_prob ✅ |
0.12.1 |
| random.random_bits |
Cast Floor RandomUniformLike RandomUniform |
random_bits_uint32 ✅
random_bits_uint32_f64 ✅ |
0.7.2 |
| random.random_categorical |
Multinomial |
random_categorical_logits_batch ✅
random_categorical_logits_batch_opset23 ✅
random_categorical_logits_rank3 ✅
random_categorical_logits_rank3_opset23 ✅
random_categorical_logits_symbolic_batch_dynamic ✅
random_categorical_logits_symbolic_batch_opset23_dynamic ✅ |
0.12.1 |
| random.random_fold_in |
Identity |
random_fold_in_passthrough ✅
random_fold_in_passthrough_f64 ✅ |
0.2.0 |
| random.random_normal |
RandomNormalLike RandomNormal |
random_normal_f32_2x3 ✅ |
0.12.1 |
| random.random_seed |
Cast Concat |
random_seed_basic ✅
random_seed_basic_f64 ✅ |
0.2.0 |