| MlpExample |
A simple Equinox MLP (converter pipeline). |
mlp_training_mode ✅
mlp_training_mode_f64 ✅
mlp_inference_mode ✅
mlp_inference_mode_f64 ✅
mlp_batched_training_mode ✅
mlp_batched_training_mode_f64 ✅ |
0.8.0 |
| SimpleLinearExample |
A simple linear layer example using Equinox (converter). |
simple_linear ✅
simple_linear_f64 ✅
nn_linear ✅
nn_linear_f64 ✅ |
0.7.1 |
| Attention |
Multi-Head Self-Attention using Equinox modules. |
attention_dynamic ✅
attention ✅ |
0.10.0 |
| AttentionCore |
Multi-Head Self-Attention without rotary processing. |
attention_core_dynamic ✅
attention_core ✅ |
0.10.0 |
| Block |
Transformer Block. |
transformer_block_dynamic ✅
transformer_block ✅ |
0.10.0 |
| DINOv3VisionTransformer |
DINOv3 Vision Transformer. |
eqx_dinov3_vit_Ti14_dynamic ✅
eqx_dinov3_vit_Ti14 ✅
eqx_dinov3_vit_S14_dynamic ✅
eqx_dinov3_vit_S14 ✅
eqx_dinov3_vit_B14_dynamic ✅
eqx_dinov3_vit_B14 ✅
eqx_dinov3_vit_S16_dynamic ✅
eqx_dinov3_vit_S16 ✅ |
0.10.0 |
| PatchEmbed |
Image to Patch Embedding. |
patch_embed ✅ |
0.10.0 |
| AttentionBlock |
Self-attention block with rotary embeddings and sinks. |
gpt_oss_attention_block ✅ |
0.10.2 |
| MLPBlock |
Mixture-of-experts SwiGLU feed-forward block. |
gpt_oss_mlp_block ✅ |
0.10.2 |
| RMSNorm |
Root mean square normalisation used by GPT-OSS. |
gpt_oss_rmsnorm_dynamic ✅
gpt_oss_rmsnorm ✅ |
0.10.2 |
| Transformer |
Full GPT-OSS Transformer stack. |
gpt_oss_transformer ✅ |
0.10.2 |
| TransformerBlock |
GPT-OSS Transformer layer (attention + MoE). |
gpt_oss_transformer_block_dynamic ✅
gpt_oss_transformer_block ✅ |
0.10.2 |
| GPT |
A simple GPT model that reuses nnx.MultiHeadAttention. |
gpt_dynamic ✅
gpt ✅ |
0.7.0 |
| GPT_Attention |
A multi-head attention layer. |
gpt_attention ✅ |
0.7.1 |
| GPT_CausalSelfAttention |
A causal self-attention module. |
gpt_causal_self_attention_dynamic ✅
gpt_causal_self_attention ✅ |
0.7.0 |
| GPT_Embeddings |
Combines token and position embeddings with dropout. |
gpt_embeddings_dynamic ✅
gpt_embeddings ✅ |
0.7.0 |
| GPT_Head |
The head of the GPT model. |
gpt_head_dynamic ✅
gpt_head ✅ |
0.7.0 |
| GPT_MLP |
An MLP block with GELU activation from nanoGPT. |
gpt_mlp_dynamic ✅
gpt_mlp ✅ |
0.7.0 |
| GPT_PositionEmbedding |
A positional embedding layer using nnx.Embed. |
gpt_position_embedding ✅ |
0.7.0 |
| GPT_TokenEmbedding |
A token embedding layer using nnx.Embed. |
gpt_token_embedding_dynamic ✅
gpt_token_embedding ✅ |
0.7.0 |
| GPT_TransformerBlock |
A transformer block combining attention and MLP. |
gpt_block_dynamic ✅
gpt_block ✅ |
0.7.0 |
| GPT_TransformerStack |
A stack of transformer blocks. |
gpt_transformer_stack_dynamic ✅
gpt_transformer_stack ✅ |
0.7.0 |
| GPT_broadcast_add |
Simple dynamic broadcast + add. |
gpt_broadcast_add_dynamic_dynamic ✅
gpt_broadcast_add_dynamic_dynamic_f64 ✅
gpt_broadcast_add_dynamic ✅
gpt_broadcast_add_dynamic_f64 ✅ |
0.7.0 |
| cfl_timestep |
Tests the CFL condition timestep calculation. |
cfl_timestep_f64 ✅ |
0.6.5 |
| weno_reconstruction |
Tests the complex arithmetic pattern found in WENO schemes. |
weno_reconstruction_f64 ✅ |
0.6.5 |
| fori_loop_test |
Demonstrates jax.lax.fori_loop with a simple loop (see the sketch after this table). |
fori_loop_test ✅
fori_loop_test_f64 ✅ |
0.6.3 |
| issue18_abs |
Tests jnp.abs from issue #18. |
abs_fn ✅
abs_fn_f64 ✅ |
0.6.3 |
| issue18_arange |
Tests jnp.arange from issue #18. |
arange_fn ✅ |
0.6.3 |
| issue18_fori_loop |
Tests jax.lax.fori_loop from issue #18. |
fori_loop_fn ✅
fori_loop_fn_f64 ✅ |
0.6.3 |
| issue18_linspace |
Tests jnp.linspace from issue #18. |
linspace_fn ✅ |
0.6.3 |
| issue18_scan |
Tests jax.lax.scan from issue #18 (no xs). |
scan_fn ✅ |
0.6.3 |
| issue18_sign |
Tests jnp.sign from issue #18. |
sign_fn ✅
sign_fn_f64 ✅ |
0.6.3 |
| issue18_where |
Tests jnp.where from issue #18. |
where_fn ✅
where_fn_f64 ✅ |
0.6.3 |
| issue18_while_loop |
Tests jax.lax.while_loop from issue #18. |
while_loop_fn ✅ |
0.9.0 |
| select_test |
Demonstrates jnp.select with scalar and tensor predicates (see the sketch after this table). |
select_test_all_options ✅
select_test_scalar_select_option_0 ✅
select_test_scalar_select_option_1 ✅
select_test_scalar_select_option_2 ✅
select_test_default_case ✅ |
0.9.0 |
| sort_test |
Demonstrates jnp.sort on slices of an input array. |
sort_test_basic ✅ |
0.9.0 |
| cond_scatter_add_mul |
Scatter add/mul inside conditional branches (converter). |
cond_scatter_add_mul_f64_a ✅
cond_scatter_add_mul_f64_b ✅ |
0.8.0 |
| cond_scatter_repro |
Reproduces a bug where lax.cond subgraphs do not inherit parent initializers. |
cond_scatter_repro_f64 ✅ |
0.6.4 |
| remat2 |
Tests a simple case of jax.checkpoint (also known as jax.remat), which is traced as the remat2 primitive; a minimal sketch follows the table. |
checkpoint_scalar_f32 ✅
checkpoint_scalar_f32_f64 ✅ |
0.6.5 |
| scatter_window |
Window-scatter (H×W patch) with implicit batch (depth-3 path). Exercises GatherScatterMode.FILL_OR_DROP and double precision. Regression test for a prior conversion failure. |
scatter_window_update_f64_example ✅ |
0.7.4 |
| two_times_silu |
Regression test for calling jax.nn.silu twice (issue #139). |
two_times_silu_scalar ✅
two_times_silu_scalar_f64 ✅ |
0.10.2 |
| LinenCNN |
A simple convolutional neural network (CNN). |
simple_cnn_static ✅
simple_cnn_dynamic ✅ |
0.11.0 |
| LinenMLP |
A simple Linen MLP with BatchNorm, Dropout, and GELU activation. |
simple_linen_mlp_static ✅
simple_linen_mlp_static_f64 ✅
simple_linen_mlp_dynamic ✅
simple_linen_mlp_dynamic_f64 ✅
simple_linen_mlp_with_call_params_dynamic ✅
simple_linen_mlp_with_call_params_dynamic_f64 ✅
simple_linen_mlp_with_call_params ✅
simple_linen_mlp_with_call_params_f64 ✅ |
0.11.0 |
| LinenMLPSequential |
A Linen MLP built from flax.linen.Sequential. |
simple_linen_mlp_sequential_static ✅
simple_linen_mlp_sequential_static_f64 ✅
simple_linen_mlp_sequential_dynamic ✅
simple_linen_mlp_sequential_dynamic_f64 ✅ |
0.11.0 |
| MaxText_deepseek2_16b |
MaxText model: deepseek2-16b |
maxtext_deepseek2-16b ✅ |
0.11.1 |
| MaxText_deepseek2_236b |
MaxText model: deepseek2-236b |
maxtext_deepseek2-236b ✅ |
0.11.1 |
| MaxText_deepseek3_671b |
MaxText model: deepseek3-671b |
maxtext_deepseek3-671b ✅ |
0.11.1 |
| MaxText_deepseek3_671b_2dfsdp |
MaxText model: deepseek3-671b-2dfsdp |
maxtext_deepseek3-671b-2dfsdp ✅ |
0.11.1 |
| MaxText_deepseek3_test |
MaxText model: deepseek3-test |
maxtext_deepseek3-test ✅ |
0.11.1 |
| MaxText_deepseek3_tiny |
MaxText model: deepseek3-tiny |
maxtext_deepseek3-tiny ✅ |
0.11.1 |
| MaxText_gemma2_27b |
MaxText model: gemma2-27b |
maxtext_gemma2-27b ✅ |
0.11.1 |
| MaxText_gemma2_2b |
MaxText model: gemma2-2b |
maxtext_gemma2-2b ✅ |
0.11.1 |
| MaxText_gemma2_9b |
MaxText model: gemma2-9b |
maxtext_gemma2-9b ✅ |
0.11.1 |
| MaxText_gemma3_12b |
MaxText model: gemma3-12b |
maxtext_gemma3-12b ✅ |
0.11.1 |
| MaxText_gemma3_27b |
MaxText model: gemma3-27b |
maxtext_gemma3-27b ✅ |
0.11.1 |
| MaxText_gemma3_4b |
MaxText model: gemma3-4b |
maxtext_gemma3-4b ✅ |
0.11.1 |
| MaxText_gemma_2b |
MaxText model: gemma-2b |
maxtext_gemma-2b ✅ |
0.11.1 |
| MaxText_gemma_7b |
MaxText model: gemma-7b |
maxtext_gemma-7b ✅ |
0.11.1 |
| MaxText_gpt3_175b |
MaxText model: gpt3-175b |
maxtext_gpt3-175b ✅ |
0.11.1 |
| MaxText_gpt3_22b |
MaxText model: gpt3-22b |
maxtext_gpt3-22b ✅ |
0.11.1 |
| MaxText_gpt3_52k |
MaxText model: gpt3-52k |
maxtext_gpt3-52k ✅ |
0.11.1 |
| MaxText_gpt3_6b |
MaxText model: gpt3-6b |
maxtext_gpt3-6b ✅ |
0.11.1 |
| MaxText_kimi_k2_1t |
MaxText model: kimi-k2-1t |
maxtext_kimi-k2-1t ✅ |
0.11.1 |
| MaxText_llama2_13b |
MaxText model: llama2-13b |
maxtext_llama2-13b ✅ |
0.11.1 |
| MaxText_llama2_70b |
MaxText model: llama2-70b |
maxtext_llama2-70b ✅ |
0.11.1 |
| MaxText_llama2_7b |
MaxText model: llama2-7b |
maxtext_llama2-7b ✅ |
0.11.1 |
| MaxText_llama3_1_405b |
MaxText model: llama3.1-405b |
maxtext_llama3.1-405b ✅ |
0.11.1 |
| MaxText_llama3_1_70b |
MaxText model: llama3.1-70b |
maxtext_llama3.1-70b ✅ |
0.11.1 |
| MaxText_llama3_1_8b |
MaxText model: llama3.1-8b |
maxtext_llama3.1-8b ✅ |
0.11.1 |
| MaxText_llama3_3_70b |
MaxText model: llama3.3-70b |
maxtext_llama3.3-70b ✅ |
0.11.1 |
| MaxText_llama3_405b |
MaxText model: llama3-405b |
maxtext_llama3-405b ✅ |
0.11.1 |
| MaxText_llama3_70b |
MaxText model: llama3-70b |
maxtext_llama3-70b ✅ |
0.11.1 |
| MaxText_llama3_8b |
MaxText model: llama3-8b |
maxtext_llama3-8b ✅ |
0.11.1 |
| MaxText_mistral_7b |
MaxText model: mistral-7b |
maxtext_mistral-7b ✅ |
0.11.1 |
| MaxText_qwen3_0_6b |
MaxText model: qwen3-0.6b |
maxtext_qwen3-0.6b ✅ |
0.11.1 |
| MaxText_qwen3_14b |
MaxText model: qwen3-14b |
maxtext_qwen3-14b ✅ |
0.11.1 |
| MaxText_qwen3_235b_a22b |
MaxText model: qwen3-235b-a22b |
maxtext_qwen3-235b-a22b ✅ |
0.11.1 |
| MaxText_qwen3_30b_a3b |
MaxText model: qwen3-30b-a3b |
maxtext_qwen3-30b-a3b ✅ |
0.11.1 |
| MaxText_qwen3_32b |
MaxText model: qwen3-32b |
maxtext_qwen3-32b ✅ |
0.11.1 |
| MaxText_qwen3_480b_a35b |
MaxText model: qwen3-480b-a35b |
maxtext_qwen3-480b-a35b ✅ |
0.11.1 |
| MaxText_qwen3_4b |
MaxText model: qwen3-4b |
maxtext_qwen3-4b ✅ |
0.11.1 |
| MaxText_qwen3_4b_thinking_2507 |
MaxText model: qwen3-4b-thinking-2507 |
maxtext_qwen3-4b-thinking-2507 ✅ |
0.11.1 |
| MaxText_qwen3_8b |
MaxText model: qwen3-8b |
maxtext_qwen3-8b ✅ |
0.11.1 |
| MaxText_qwen3_next_80b_a3b |
MaxText model: qwen3-next-80b-a3b |
maxtext_qwen3-next-80b-a3b ✅ |
0.11.1 |
| MaxText_qwen3_omni_30b_a3b |
MaxText model: qwen3-omni-30b-a3b |
maxtext_qwen3-omni-30b-a3b ✅ |
0.11.1 |
| AutoEncoder |
A simple autoencoder example (converter pipeline). |
simple_autoencoder ✅
simple_autoencoder_f64 ✅ |
0.2.0 |
| CNN |
A simple convolutional neural network (CNN). |
simple_cnn_static ✅
simple_cnn_dynamic ✅ |
0.2.0 |
| DepthToSpaceResNet |
Residual conv stack followed by dm_pix.depth_to_space upsampling. |
depth_to_space_resnet_static ✅
depth_to_space_resnet_inputs_outputs_as_nchw ✅
depth_to_space_resnet_inputs_outputs_as_nchw_dynamic_hw ✅
depth_to_space_resnet_scaled_inputs_outputs_as_nchw ✅ |
0.11.2 |
| ForiLoop |
fori_loop example using nnx-compatible primitives (converter). |
fori_loop_counter ✅
fori_loop_counter_f64 ✅ |
0.5.1 |
| GRUCell |
Flax/nnx GRUCell lowered through converter primitives. |
gru_cell_basic ✅ |
0.7.2 |
| MLP |
A simple Multi-Layer Perceptron (MLP) with BatchNorm, Dropout, and GELU activation. |
simple_mlp_static ✅
simple_mlp_static_f64 ✅
simple_mlp_dynamic ✅
simple_mlp_dynamic_f64 ✅
simple_mlp_with_call_params_dynamic ✅
simple_mlp_with_call_params_dynamic_f64 ✅
simple_mlp_with_call_params ✅
simple_mlp_with_call_params_f64 ✅ |
0.1.0 |
| MultiHeadAttention |
nnx.MultiHeadAttention exercised in several configurations, including custom attention_fn and symbolic batch variants. |
multihead_attention_nn_dynamic ✅
multihead_attention_nn ✅
multihead_attention_nnx_dynamic ✅
multihead_attention_nnx ✅
multihead_attention_2_nnx_dynamic ✅
multihead_attention_2_nnx ✅ |
0.2.0 |
| NestedResidualGroup |
Nested residual blocks inside a residual group; regression harness for issue #173. |
nested_residual_group_static ✅
nested_residual_group_static_nchw ✅
nested_residual_group_with_lead_conv_static ✅
nested_residual_stack_static ✅
nested_residual_stack_static_no_extra_transpose ✅
nested_residual_stack_with_lead_conv_static ✅ |
0.12.0 |
| ResBlock |
Residual block with squeeze-and-excite channel attention (from issue #168). |
resblock_channel_attention_static ✅
resblock_channel_attention_static_nchw ✅
resblock_channel_attention_dynamic_hw ✅
resblock_channel_attention_dynamic_hw_nchw ✅ |
0.12.0 |
| SequentialReLU |
Two stateless nnx.relu activations chained via nnx.Sequential. |
sequential_double_relu ✅
sequential_double_relu_f64 ✅ |
0.7.1 |
| SequentialWithResidual |
nnx.Sequential nested within a residual block; regression test for earlier bugs. |
sequential_nested_with_residual ✅ |
0.7.1 |
| SimpleModel |
Minimal NNX model that applies jnp.clip. |
simple_model_clip_nhwc ✅
simple_model_clip_nchw_io ✅ |
0.12.0 |
| TransformerDecoderWithSequential |
Tiny nnx Transformer decoder using nnx.Sequential in the FFN block. |
tiny_decoder_with_sequential ✅
tiny_decoder_with_sequential_and_full_dynamic_shapes_dynamic ✅ |
0.7.1 |
| TransformerDecoderWithoutSequential |
Tiny nnx Transformer decoder with explicit FFN layers (no Sequential). |
tiny_decoder_without_sequential ✅ |
0.7.1 |
| FlaxDINOv3VisionTransformer |
DINOv3 Vision Transformer. |
nnx_dinov3_vit_Ti14_dynamic ✅
nnx_dinov3_vit_Ti14 ✅
nnx_dinov3_vit_S14_dynamic ✅
nnx_dinov3_vit_S14 ✅
nnx_dinov3_vit_B14_dynamic ✅
nnx_dinov3_vit_B14 ✅
nnx_dinov3_vit_S16_dynamic ✅
nnx_dinov3_vit_S16 ✅ |
0.10.3 |
| NnxDinoAttention |
Multi-Head Self-Attention using Flax/NNX modules. |
nnx_attention_dynamic ✅
nnx_attention ✅ |
0.10.3 |
| NnxDinoAttentionCore |
Multi-Head Self-Attention without rotary processing. |
nnx_attention_core_dynamic ✅
nnx_attention_core ✅ |
0.10.3 |
| NnxDinoBlock |
Transformer Block. |
nnx_transformer_block_dynamic ✅
nnx_transformer_block ✅ |
0.10.3 |
| NnxDinoPatchEmbed |
Image to Patch Embedding. |
nnx_patch_embed ✅ |
0.10.3 |
| FlaxAttentionBlock |
Attention block from the GPT-OSS Flax reference (no KV cache). |
gpt_oss_attention_flax ✅ |
0.10.2 |
| FlaxMLPBlock |
Mixture-of-experts MLP block from the GPT-OSS Flax port. |
gpt_oss_mlp_flax ✅ |
0.10.2 |
| FlaxRMSNorm |
Flax RMSNorm used in the GPT-OSS JAX port. |
gpt_oss_rmsnorm_flax_dynamic ✅
gpt_oss_rmsnorm_flax ✅ |
0.10.2 |
| FlaxRotaryEmbedding |
Rotary position embedding helper from the GPT-OSS Flax port. |
gpt_oss_rotary_flax ✅ |
0.10.2 |
| FlaxSDPA |
JIT-compiled sdpa (scaled dot-product attention) helper from the GPT-OSS Flax port. |
gpt_oss_sdpa_flax ✅ |
0.10.2 |
| FlaxTransformer |
Full GPT-OSS Flax transformer (embedding, blocks, head). |
gpt_oss_transformer_flax ✅ |
0.10.2 |
| FlaxTransformerBlock |
Single GPT-OSS Flax transformer block (attention + MoE MLP). |
gpt_oss_transformer_block_flax ✅ |
0.10.2 |
| onnx_functions_000 |
One function boundary on an outer NNX module (new-world). |
000_one_function_on_outer_layer_dynamic ✅
000_one_function_on_outer_layer ✅ |
0.4.0 |
| onnx_functions_001 |
One function on an inner layer. |
001_one_function_inner_dynamic ✅
001_one_function_inner ✅ |
0.4.0 |
| onnx_functions_002 |
Two nested functions. |
002_two_nested_functions_dynamic ✅
002_two_nested_functions ✅ |
0.4.0 |
| onnx_functions_003 |
Two simple nested functions. |
003_two_simple_nested_functions_dynamic ✅
003_two_simple_nested_functions ✅ |
0.4.0 |
| onnx_functions_004 |
Nested function plus a component. |
004_nested_function_plus_component_dynamic ✅
004_nested_function_plus_component ✅ |
0.4.0 |
| onnx_functions_005 |
Nested function plus more components. |
005_nested_function_plus_component_dynamic ✅
005_nested_function_plus_component ✅ |
0.4.0 |
| onnx_functions_006 |
One function on an outer layer. |
006_one_function_outer_dynamic ✅
006_one_function_outer ✅ |
0.4.0 |
| onnx_functions_007 |
Transformer block with a nested MLP block and a call parameter. |
007_transformer_block_dynamic ✅
007_transformer_block ✅ |
0.4.0 |
| onnx_functions_008 |
Transformer block with a nested MLP block and no call parameter. |
008_transformer_block_dynamic ✅
008_transformer_block ✅ |
0.4.0 |
| onnx_functions_009 |
Transformer block with the @onnx_function decorator applied to both a class and a function. |
009_transformer_block_dynamic ✅
009_transformer_block ✅ |
0.4.0 |
| onnx_functions_010 |
Transformer stack. |
010_transformer_stack_dynamic ✅
010_transformer_stack ✅ |
0.4.0 |
| onnx_functions_012 |
Vision Transformer (ViT) convolutional embedding. |
012_vit_conv_embedding_dynamic ✅
012_vit_conv_embedding ✅ |
0.4.0 |
| onnx_functions_013 |
Vision Transformer (ViT) convolutional embedding with call parameters. |
013_vit_conv_embedding_with_call_params_dynamic ✅
013_vit_conv_embedding_with_call_params ✅
013_vit_conv_embedding_with_internal_call_params_dynamic ✅
013_vit_conv_embedding_with_internal_call_params ✅ |
0.4.0 |
| onnx_functions_014 |
One function on an outer layer. |
014_one_function_with_input_param_with_default_value ✅
014_one_function_without_input_param_with_default_value_dynamic ✅
014_one_function_without_input_param_with_default_value ✅ |
0.4.0 |
| onnx_functions_015 |
One function on an outer layer. |
015_one_function_with_input_param_without_default_value_dynamic ✅
015_one_function_with_input_param_without_default_value ✅ |
0.4.0 |
| onnx_functions_016 |
Nested function plus more components. |
016_internal_function_with_input_param_with_default_value_dynamic ✅
016_internal_function_with_input_param_with_default_value ✅ |
0.4.0 |
| onnx_functions_017 |
Demonstrates @onnx_function(unique=True) reuse across call sites. |
017_unique_function_reuse ✅ |
0.10.0 |
| ClassificationHead |
Classification head for the Vision Transformer. |
vit_classification_head_dynamic ✅
vit_classification_head ✅ |
0.4.0 |
| ClassificationHeadFlatten |
Classification head for the Vision Transformer. |
vit_classification_head_flat_dynamic ✅
vit_classification_head_flat ✅ |
0.4.0 |
| ConcatClsToken |
Concatenate the CLS token to the input embedding. |
vit_concat_cls_token_dynamic ✅
vit_concat_cls_token ✅ |
0.4.0 |
| ConcatClsTokenFlatten |
Concatenate the CLS token to the input embedding. |
vit_concat_cls_token_flat_dynamic ✅
vit_concat_cls_token_flat ✅ |
0.4.0 |
| ConvEmbedding |
Convolutional Token Embedding for MNIST with hierarchical downsampling. |
vit_mnist_conv_embedding_dynamic ✅
vit_mnist_conv_embedding ✅ |
0.1.0 |
| ConvEmbeddingFlatten |
Convolutional Token Embedding for MNIST with hierarchical downsampling. |
vit_mnist_conv_embedding_flat_dynamic ✅
vit_mnist_conv_embedding_flat ✅ |
0.1.0 |
| FeedForward |
MLP in the Transformer. |
vit_feed_forward_dynamic ✅
vit_feed_forward ✅ |
0.1.0 |
| FeedForwardFlatten |
MLP in the Transformer. |
vit_feed_forward_flat_dynamic ✅
vit_feed_forward_flat ✅ |
0.1.0 |
| GetToken |
Get the CLS token from the input embedding. |
vit_get_token_dynamic ✅
vit_get_token ✅ |
0.4.0 |
| GetTokenFlatten |
Get the CLS token from the input embedding. |
vit_get_token_flat_dynamic ✅
vit_get_token_flat ✅ |
0.4.0 |
| PatchEmbedding |
Cutting the image into patches and linearly embedding them. |
vit_patch_embedding_dynamic ✅
vit_patch_embedding ✅ |
0.1.0 |
| PatchEmbeddingFlatten |
Cutting the image into patches and linearly embedding them. |
vit_patch_embedding_flat_dynamic ✅
vit_patch_embedding_flat ✅ |
0.1.0 |
| PositionalEmbedding |
Add positional embedding to the input embedding. |
vit_positional_embedding_dynamic ✅
vit_positional_embedding ✅ |
0.4.0 |
| PositionalEmbeddingFlatten |
Add positional embedding to the input embedding. |
vit_positional_embedding_flat_dynamic ✅
vit_positional_embedding_flat ✅ |
0.4.0 |
| TransformerBlock |
Transformer block from 'Attention Is All You Need'. |
vit_transformer_block_dynamic ✅
vit_transformer_block ✅ |
0.1.0 |
| TransformerBlockFlatten |
Transformer block from 'Attention Is All You Need'. |
vit_transformer_block_flat_dynamic ✅
vit_transformer_block_flat ✅ |
0.1.0 |
| TransformerStack |
Stack of Transformer blocks. |
vit_transformer_stack_dynamic ✅
vit_transformer_stack ✅ |
0.1.0 |
| TransformerStackFlatten |
Stack of Transformer blocks. |
vit_transformer_stack_flat_dynamic ✅
vit_transformer_stack_flat ✅ |
0.1.0 |
| VisionTransformer |
A Vision Transformer (ViT) model for MNIST with configurable embedding type. |
vit_model_conv_embedding_dynamic ✅
vit_model_conv_embedding ✅
vit_model_patch_embedding ✅ |
0.2.0 |
| VisionTransformerFlatten |
A Vision Transformer (ViT) model for MNIST with configurable embedding type. |
vit_model_conv_embedding_flat_dynamic ✅
vit_model_conv_embedding_flat ✅
vit_model_patch_embedding_flat_dynamic ✅
vit_model_patch_embedding_flat ✅ |
0.2.0 |
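
The loop- and selection-oriented rows above (fori_loop_test, issue18_fori_loop, issue18_while_loop, select_test, issue18_where) all exercise plain JAX control-flow and selection primitives. The snippet below is a minimal, self-contained sketch of that style of function; the helper names, shapes, and constants are illustrative assumptions and are not the repository's actual test functions.

```python
import jax
import jax.numpy as jnp


def repeat_add(x):
    """Accumulate x ten times with jax.lax.fori_loop (illustrative helper)."""
    def body(i, acc):
        return acc + x
    return jax.lax.fori_loop(0, 10, body, jnp.zeros_like(x))


def pick_branch(flag, x):
    """Element-wise branch selection with jnp.select; `default` covers all remaining flags."""
    return jnp.select(
        [flag == 0, flag == 1, flag == 2],
        [x + 1.0, x * 2.0, -x],
        default=0.0,
    )


print(repeat_add(jnp.ones(3)))                 # [10. 10. 10.]
print(pick_branch(jnp.int32(1), jnp.ones(3)))  # [2. 2. 2.]
```

Both helpers are pure functions and trace cleanly under jax.jit.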
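
The remat2 row likewise exercises jax.checkpoint. A correspondingly small sketch, again with illustrative names:

```python
import jax
import jax.numpy as jnp


@jax.checkpoint  # alias of jax.remat: recompute activations in the backward pass
def expensive_scalar(x):
    return jnp.sin(x) * jnp.cos(x) + x ** 2


grad_fn = jax.grad(expensive_scalar)
print(expensive_scalar(1.5), grad_fn(1.5))
```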