# Examples

Each row below lists a component, a short description, the testcases that exercise it, and the release ("Since") in which it first appeared.

| Component | Description | Testcases | Since |
|---|---|---|---|
| MlpExample | A simple Equinox MLP (converter pipeline). | mlp_training_mode<br>mlp_training_mode_f64<br>mlp_inference_mode<br>mlp_inference_mode_f64<br>mlp_batched_training_mode<br>mlp_batched_training_mode_f64 | 0.8.0 |
| SimpleLinearExample | A simple linear layer example using Equinox (converter). | simple_linear<br>simple_linear_f64<br>nn_linear<br>nn_linear_f64 | 0.7.1 |
| Attention | Multi-Head Self-Attention using Equinox modules. | attention_dynamic<br>attention | 0.10.0 |
| AttentionCore | Multi-Head Self-Attention without rotary processing. | attention_core_dynamic<br>attention_core | 0.10.0 |
| Block | Transformer block. | transformer_block_dynamic<br>transformer_block | 0.10.0 |
| DINOv3VisionTransformer | DINOv3 Vision Transformer. | eqx_dinov3_vit_Ti14_dynamic<br>eqx_dinov3_vit_Ti14<br>eqx_dinov3_vit_S14_dynamic<br>eqx_dinov3_vit_S14<br>eqx_dinov3_vit_B14_dynamic<br>eqx_dinov3_vit_B14<br>eqx_dinov3_vit_S16_dynamic<br>eqx_dinov3_vit_S16 | 0.10.0 |
| PatchEmbed | Image to patch embedding. | patch_embed | 0.10.0 |

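The Equinox rows above all follow the same converter pipeline: build an `eqx.Module`, then hand it to the converter. Below is a minimal sketch in that spirit; the layer widths are illustrative, and the commented-out export call assumes jax2onnx's `to_onnx` entry point, so check the project documentation for the exact signature.

```python
import equinox as eqx
import jax
import jax.numpy as jnp

class Mlp(eqx.Module):
    """A tiny MLP along the lines of MlpExample above (widths are illustrative)."""
    layers: list

    def __init__(self, key):
        k1, k2 = jax.random.split(key)
        self.layers = [
            eqx.nn.Linear(30, 64, key=k1),
            eqx.nn.Linear(64, 10, key=k2),
        ]

    def __call__(self, x):
        x = jax.nn.gelu(self.layers[0](x))
        return self.layers[1](x)

model = Mlp(jax.random.PRNGKey(0))
print(model(jnp.ones(30)).shape)  # (10,)

# Hypothetical export call; the exact jax2onnx entry point and
# dynamic-batch syntax are assumptions, see the project docs.
# from jax2onnx import to_onnx
# onnx_model = to_onnx(model, [("B", 30)])
```
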
| Component | Description | Testcases | Since |
|---|---|---|---|
| AttentionBlock | Self-attention block with rotary embeddings and sinks. | gpt_oss_attention_block | 0.10.2 |
| MLPBlock | Mixture-of-experts SwiGLU feed-forward block. | gpt_oss_mlp_block | 0.10.2 |
| RMSNorm | Root mean square normalisation used by GPT-OSS. | gpt_oss_rmsnorm_dynamic<br>gpt_oss_rmsnorm | 0.10.2 |
| Transformer | Full GPT-OSS Transformer stack. | gpt_oss_transformer | 0.10.2 |
| TransformerBlock | GPT-OSS Transformer layer (attention + MoE). | gpt_oss_transformer_block_dynamic<br>gpt_oss_transformer_block | 0.10.2 |

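The RMSNorm testcases cover root-mean-square normalisation, which scales each feature vector by the reciprocal of its RMS. A generic sketch of the operation; the GPT-OSS port's exact epsilon placement and dtype handling are assumptions:

```python
import jax.numpy as jnp

def rms_norm(x, weight, eps=1e-6):
    """x / rms(x), scaled by a learned per-feature weight."""
    rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    return weight * (x / rms)

x = jnp.ones((2, 8))
w = jnp.ones((8,))
print(rms_norm(x, w).shape)  # (2, 8)
```
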
| Component | Description | Testcases | Since |
|---|---|---|---|
| GPT | A simple GPT model that reuses nnx.MultiHeadAttention. | gpt_dynamic<br>gpt | 0.7.0 |
| GPT_Attention | A multi-head attention layer. | gpt_attention | 0.7.1 |
| GPT_CausalSelfAttention | A causal self-attention module. | gpt_causal_self_attention_dynamic<br>gpt_causal_self_attention | 0.7.0 |
| GPT_Embeddings | Combines token and position embeddings with dropout. | gpt_embeddings_dynamic<br>gpt_embeddings | 0.7.0 |
| GPT_Head | The head of the GPT model. | gpt_head_dynamic<br>gpt_head | 0.7.0 |
| GPT_MLP | An MLP block with GELU activation from nanoGPT. | gpt_mlp_dynamic<br>gpt_mlp | 0.7.0 |
| GPT_PositionEmbedding | A positional embedding layer using nnx.Embed. | gpt_position_embedding | 0.7.0 |
| GPT_TokenEmbedding | A token embedding layer using nnx.Embed. | gpt_token_embedding_dynamic<br>gpt_token_embedding | 0.7.0 |
| GPT_TransformerBlock | A transformer block combining attention and MLP. | gpt_block_dynamic<br>gpt_block | 0.7.0 |
| GPT_TransformerStack | A stack of transformer blocks. | gpt_transformer_stack_dynamic<br>gpt_transformer_stack | 0.7.0 |
| GPT_broadcast_add | Simple dynamic broadcast + add. | gpt_broadcast_add_dynamic_dynamic<br>gpt_broadcast_add_dynamic_dynamic_f64<br>gpt_broadcast_add_dynamic<br>gpt_broadcast_add_dynamic_f64 | 0.7.0 |

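Several of the GPT components above lean on `flax.nnx` building blocks such as `nnx.MultiHeadAttention`. A minimal self-attention sketch, with illustrative hyperparameters rather than the example's actual configuration:

```python
import jax.numpy as jnp
from flax import nnx

# Self-attention via flax.nnx, as the GPT example reuses it.
attn = nnx.MultiHeadAttention(
    num_heads=4, in_features=64, qkv_features=64,
    decode=False, rngs=nnx.Rngs(0),
)
x = jnp.ones((2, 16, 64))  # (batch, sequence, features)
y = attn(x)                # self-attention: q = k = v = x
print(y.shape)             # (2, 16, 64)
```
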
| Component | Description | Testcases | Since |
|---|---|---|---|
| cfl_timestep | Tests the CFL condition timestep calculation. | cfl_timestep_f64 | 0.6.5 |
| weno_reconstruction | Tests the complex arithmetic pattern found in WENO schemes. | weno_reconstruction_f64 | 0.6.5 |
| fori_loop_test | Demonstrates jax.lax.fori_loop with a simple loop. | fori_loop_test<br>fori_loop_test_f64 | 0.6.3 |
| issue18_abs | Tests jnp.abs from issue #18. | abs_fn<br>abs_fn_f64 | 0.6.3 |
| issue18_arange | Tests jnp.arange from issue #18. | arange_fn | 0.6.3 |
| issue18_fori_loop | Tests jax.lax.fori_loop from issue #18. | fori_loop_fn<br>fori_loop_fn_f64 | 0.6.3 |
| issue18_linspace | Tests jnp.linspace from issue #18. | linspace_fn | 0.6.3 |
| issue18_scan | Tests jax.lax.scan from issue #18 (no xs). | scan_fn | 0.6.3 |
| issue18_sign | Tests jnp.sign from issue #18. | sign_fn<br>sign_fn_f64 | 0.6.3 |
| issue18_where | Tests jnp.where from issue #18. | where_fn<br>where_fn_f64 | 0.6.3 |
| issue18_while_loop | Tests jax.lax.while_loop from issue #18. | while_loop_fn | 0.9.0 |
| select_test | Demonstrates jnp.select with scalar and tensor predicates. | select_test_all_options<br>select_test_scalar_select_option_0<br>select_test_scalar_select_option_1<br>select_test_scalar_select_option_2<br>select_test_default_case | 0.9.0 |
| sort_test | Demonstrates jnp.sort on slices of an input array. | sort_test_basic | 0.9.0 |
| cond_scatter_add_mul | Scatter add/mul inside conditional branches (converter). | cond_scatter_add_mul_f64_a<br>cond_scatter_add_mul_f64_b | 0.8.0 |
| cond_scatter_repro | Reproduces a bug where lax.cond subgraphs do not inherit parent initializers. | cond_scatter_repro_f64 | 0.6.4 |
| remat2 | Tests a simple case of jax.checkpoint (also known as jax.remat). | checkpoint_scalar_f32<br>checkpoint_scalar_f32_f64 | 0.6.5 |
| scatter_window | Window scatter (H×W patch) with implicit batch (depth-3 path). Exercises GatherScatterMode.FILL_OR_DROP and double precision. Regression of a prior conversion failure. | scatter_window_update_f64_example | 0.7.4 |
| two_times_silu | Regression for calling jax.nn.silu twice (issue #139). | two_times_silu_scalar<br>two_times_silu_scalar_f64 | 0.10.2 |

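Most rows in this group exercise a single JAX primitive. The sketch below shows plain JAX versions of two of the covered patterns, `jax.lax.fori_loop` and `jnp.select`; the testcases themselves additionally convert such functions to ONNX:

```python
import jax
import jax.numpy as jnp

def cumulative_sum(n):
    # jax.lax.fori_loop: carry an accumulator through n iterations.
    return jax.lax.fori_loop(0, n, lambda i, acc: acc + i, 0)

def piecewise(x):
    # jnp.select with tensor predicates, plus a default case.
    return jnp.select([x < 0, x > 1], [-x, x - 1], default=0.0)

print(cumulative_sum(5))                        # 10
print(piecewise(jnp.array([-2.0, 0.5, 3.0])))   # [2. 0. 2.]
```
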
| Component | Description | Testcases | Since |
|---|---|---|---|
| LinenCNN | A simple convolutional neural network (CNN). | simple_cnn_static<br>simple_cnn_dynamic | 0.11.0 |
| LinenMLP | A simple Linen MLP with BatchNorm, Dropout, and GELU activation. | simple_linen_mlp_static<br>simple_linen_mlp_static_f64<br>simple_linen_mlp_dynamic<br>simple_linen_mlp_dynamic_f64<br>simple_linen_mlp_with_call_params_dynamic<br>simple_linen_mlp_with_call_params_dynamic_f64<br>simple_linen_mlp_with_call_params<br>simple_linen_mlp_with_call_params_f64 | 0.11.0 |
| LinenMLPSequential | A Linen MLP built from flax.linen.Sequential. | simple_linen_mlp_sequential_static<br>simple_linen_mlp_sequential_static_f64<br>simple_linen_mlp_sequential_dynamic<br>simple_linen_mlp_sequential_dynamic_f64 | 0.11.0 |

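A minimal Linen module in the spirit of LinenMLP above; the widths, dropout rate, and layer order here are illustrative assumptions, not the example's actual configuration:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class SimpleLinenMlp(nn.Module):
    """Dense -> BatchNorm -> GELU -> Dropout -> Dense (a sketch)."""
    features: int = 64

    @nn.compact
    def __call__(self, x, *, train: bool):
        x = nn.Dense(self.features)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.gelu(x)
        x = nn.Dropout(rate=0.1, deterministic=not train)(x)
        return nn.Dense(10)(x)

model = SimpleLinenMlp()
x = jnp.ones((4, 30))
variables = model.init(jax.random.PRNGKey(0), x, train=False)
y = model.apply(variables, x, train=False)
print(y.shape)  # (4, 10)
```
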
| Component | Description | Testcases | Since |
|---|---|---|---|
| MaxText_deepseek2_16b | MaxText model: deepseek2-16b. | maxtext_deepseek2-16b | 0.11.1 |
| MaxText_deepseek2_236b | MaxText model: deepseek2-236b. | maxtext_deepseek2-236b | 0.11.1 |
| MaxText_deepseek3_671b | MaxText model: deepseek3-671b. | maxtext_deepseek3-671b | 0.11.1 |
| MaxText_deepseek3_671b_2dfsdp | MaxText model: deepseek3-671b-2dfsdp. | maxtext_deepseek3-671b-2dfsdp | 0.11.1 |
| MaxText_deepseek3_test | MaxText model: deepseek3-test. | maxtext_deepseek3-test | 0.11.1 |
| MaxText_deepseek3_tiny | MaxText model: deepseek3-tiny. | maxtext_deepseek3-tiny | 0.11.1 |
| MaxText_gemma2_27b | MaxText model: gemma2-27b. | maxtext_gemma2-27b | 0.11.1 |
| MaxText_gemma2_2b | MaxText model: gemma2-2b. | maxtext_gemma2-2b | 0.11.1 |
| MaxText_gemma2_9b | MaxText model: gemma2-9b. | maxtext_gemma2-9b | 0.11.1 |
| MaxText_gemma3_12b | MaxText model: gemma3-12b. | maxtext_gemma3-12b | 0.11.1 |
| MaxText_gemma3_27b | MaxText model: gemma3-27b. | maxtext_gemma3-27b | 0.11.1 |
| MaxText_gemma3_4b | MaxText model: gemma3-4b. | maxtext_gemma3-4b | 0.11.1 |
| MaxText_gemma_2b | MaxText model: gemma-2b. | maxtext_gemma-2b | 0.11.1 |
| MaxText_gemma_7b | MaxText model: gemma-7b. | maxtext_gemma-7b | 0.11.1 |
| MaxText_gpt3_175b | MaxText model: gpt3-175b. | maxtext_gpt3-175b | 0.11.1 |
| MaxText_gpt3_22b | MaxText model: gpt3-22b. | maxtext_gpt3-22b | 0.11.1 |
| MaxText_gpt3_52k | MaxText model: gpt3-52k. | maxtext_gpt3-52k | 0.11.1 |
| MaxText_gpt3_6b | MaxText model: gpt3-6b. | maxtext_gpt3-6b | 0.11.1 |
| MaxText_kimi_k2_1t | MaxText model: kimi-k2-1t. | maxtext_kimi-k2-1t | 0.11.1 |
| MaxText_llama2_13b | MaxText model: llama2-13b. | maxtext_llama2-13b | 0.11.1 |
| MaxText_llama2_70b | MaxText model: llama2-70b. | maxtext_llama2-70b | 0.11.1 |
| MaxText_llama2_7b | MaxText model: llama2-7b. | maxtext_llama2-7b | 0.11.1 |
| MaxText_llama3_1_405b | MaxText model: llama3.1-405b. | maxtext_llama3.1-405b | 0.11.1 |
| MaxText_llama3_1_70b | MaxText model: llama3.1-70b. | maxtext_llama3.1-70b | 0.11.1 |
| MaxText_llama3_1_8b | MaxText model: llama3.1-8b. | maxtext_llama3.1-8b | 0.11.1 |
| MaxText_llama3_3_70b | MaxText model: llama3.3-70b. | maxtext_llama3.3-70b | 0.11.1 |
| MaxText_llama3_405b | MaxText model: llama3-405b. | maxtext_llama3-405b | 0.11.1 |
| MaxText_llama3_70b | MaxText model: llama3-70b. | maxtext_llama3-70b | 0.11.1 |
| MaxText_llama3_8b | MaxText model: llama3-8b. | maxtext_llama3-8b | 0.11.1 |
| MaxText_mistral_7b | MaxText model: mistral-7b. | maxtext_mistral-7b | 0.11.1 |
| MaxText_qwen3_0_6b | MaxText model: qwen3-0.6b. | maxtext_qwen3-0.6b | 0.11.1 |
| MaxText_qwen3_14b | MaxText model: qwen3-14b. | maxtext_qwen3-14b | 0.11.1 |
| MaxText_qwen3_235b_a22b | MaxText model: qwen3-235b-a22b. | maxtext_qwen3-235b-a22b | 0.11.1 |
| MaxText_qwen3_30b_a3b | MaxText model: qwen3-30b-a3b. | maxtext_qwen3-30b-a3b | 0.11.1 |
| MaxText_qwen3_32b | MaxText model: qwen3-32b. | maxtext_qwen3-32b | 0.11.1 |
| MaxText_qwen3_480b_a35b | MaxText model: qwen3-480b-a35b. | maxtext_qwen3-480b-a35b | 0.11.1 |
| MaxText_qwen3_4b | MaxText model: qwen3-4b. | maxtext_qwen3-4b | 0.11.1 |
| MaxText_qwen3_4b_thinking_2507 | MaxText model: qwen3-4b-thinking-2507. | maxtext_qwen3-4b-thinking-2507 | 0.11.1 |
| MaxText_qwen3_8b | MaxText model: qwen3-8b. | maxtext_qwen3-8b | 0.11.1 |
| MaxText_qwen3_next_80b_a3b | MaxText model: qwen3-next-80b-a3b. | maxtext_qwen3-next-80b-a3b | 0.11.1 |
| MaxText_qwen3_omni_30b_a3b | MaxText model: qwen3-omni-30b-a3b. | maxtext_qwen3-omni-30b-a3b | 0.11.1 |

| Component | Description | Testcases | Since |
|---|---|---|---|
| AutoEncoder | A simple autoencoder example (converter pipeline). | simple_autoencoder<br>simple_autoencoder_f64 | 0.2.0 |
| CNN | A simple convolutional neural network (CNN). | simple_cnn_static<br>simple_cnn_dynamic | 0.2.0 |
| DepthToSpaceResNet | Residual conv stack followed by dm_pix.depth_to_space upsampling. | depth_to_space_resnet_static<br>depth_to_space_resnet_inputs_outputs_as_nchw<br>depth_to_space_resnet_inputs_outputs_as_nchw_dynamic_hw<br>depth_to_space_resnet_scaled_inputs_outputs_as_nchw | 0.11.2 |
| ForiLoop | fori_loop example using nnx-compatible primitives (converter). | fori_loop_counter<br>fori_loop_counter_f64 | 0.5.1 |
| GRUCell | Flax/nnx GRUCell lowered through converter primitives. | gru_cell_basic | 0.7.2 |
| MLP | A simple Multi-Layer Perceptron (MLP) with BatchNorm, Dropout, and GELU activation. | simple_mlp_static<br>simple_mlp_static_f64<br>simple_mlp_dynamic<br>simple_mlp_dynamic_f64<br>simple_mlp_with_call_params_dynamic<br>simple_mlp_with_call_params_dynamic_f64<br>simple_mlp_with_call_params<br>simple_mlp_with_call_params_f64 | 0.1.0 |
| MultiHeadAttention | nnx.MultiHeadAttention exercised in several configurations, including custom attention_fn and symbolic batch variants. | multihead_attention_nn_dynamic<br>multihead_attention_nn<br>multihead_attention_nnx_dynamic<br>multihead_attention_nnx<br>multihead_attention_2_nnx_dynamic<br>multihead_attention_2_nnx | 0.2.0 |
| NestedResidualGroup | Nested residual blocks inside a residual group; regression harness for issue #173. | nested_residual_group_static<br>nested_residual_group_static_nchw<br>nested_residual_group_with_lead_conv_static<br>nested_residual_stack_static<br>nested_residual_stack_static_no_extra_transpose<br>nested_residual_stack_with_lead_conv_static | 0.12.0 |
| ResBlock | Residual block with squeeze-and-excite channel attention (from issue #168). | resblock_channel_attention_static<br>resblock_channel_attention_static_nchw<br>resblock_channel_attention_dynamic_hw<br>resblock_channel_attention_dynamic_hw_nchw | 0.12.0 |
| SequentialReLU | Two stateless nnx.relu activations chained via nnx.Sequential. | sequential_double_relu<br>sequential_double_relu_f64 | 0.7.1 |
| SequentialWithResidual | nnx.Sequential nested within a residual block to regress earlier bugs. | sequential_nested_with_residual | 0.7.1 |
| SimpleModel | Minimal NNX model that applies jnp.clip. | simple_model_clip_nhwc<br>simple_model_clip_nchw_io | 0.12.0 |
| TransformerDecoderWithSequential | Tiny nnx Transformer decoder using nnx.Sequential in the FFN block. | tiny_decoder_with_sequential<br>tiny_decoder_with_sequential_and_full_dynamic_shapes_dynamic | 0.7.1 |
| TransformerDecoderWithoutSequential | Tiny nnx Transformer decoder with explicit FFN layers (no Sequential). | tiny_decoder_without_sequential | 0.7.1 |

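As an illustration of the `nnx.Sequential` rows above, the sketch below chains a linear layer with two stateless `nnx.relu` activations; the real SequentialReLU example may differ in what it wraps:

```python
import jax.numpy as jnp
from flax import nnx

# nnx.Sequential accepts arbitrary callables, including plain activations.
block = nnx.Sequential(
    nnx.Linear(8, 8, rngs=nnx.Rngs(0)),
    nnx.relu,
    nnx.relu,
)
x = jnp.ones((2, 8))
print(block(x).shape)  # (2, 8)
```
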
| Component | Description | Testcases | Since |
|---|---|---|---|
| FlaxDINOv3VisionTransformer | DINOv3 Vision Transformer. | nnx_dinov3_vit_Ti14_dynamic<br>nnx_dinov3_vit_Ti14<br>nnx_dinov3_vit_S14_dynamic<br>nnx_dinov3_vit_S14<br>nnx_dinov3_vit_B14_dynamic<br>nnx_dinov3_vit_B14<br>nnx_dinov3_vit_S16_dynamic<br>nnx_dinov3_vit_S16 | 0.10.3 |
| NnxDinoAttention | Multi-Head Self-Attention using Flax/NNX modules. | nnx_attention_dynamic<br>nnx_attention | 0.10.3 |
| NnxDinoAttentionCore | Multi-Head Self-Attention without rotary processing. | nnx_attention_core_dynamic<br>nnx_attention_core | 0.10.3 |
| NnxDinoBlock | Transformer block. | nnx_transformer_block_dynamic<br>nnx_transformer_block | 0.10.3 |
| NnxDinoPatchEmbed | Image to patch embedding. | nnx_patch_embed | 0.10.3 |

| Component | Description | Testcases | Since |
|---|---|---|---|
| FlaxAttentionBlock | Attention block from the GPT-OSS Flax reference (no KV cache). | gpt_oss_attention_flax | 0.10.2 |
| FlaxMLPBlock | Mixture-of-experts MLP block from the GPT-OSS Flax port. | gpt_oss_mlp_flax | 0.10.2 |
| FlaxRMSNorm | Flax RMSNorm used in the GPT-OSS JAX port. | gpt_oss_rmsnorm_flax_dynamic<br>gpt_oss_rmsnorm_flax | 0.10.2 |
| FlaxRotaryEmbedding | Rotary position embedding helper from the GPT-OSS Flax port. | gpt_oss_rotary_flax | 0.10.2 |
| FlaxSDPA | JIT-compiled scaled dot-product attention (SDPA) helper from the GPT-OSS Flax port. | gpt_oss_sdpa_flax | 0.10.2 |
| FlaxTransformer | Full GPT-OSS Flax transformer (embedding, blocks, head). | gpt_oss_transformer_flax | 0.10.2 |
| FlaxTransformerBlock | Single GPT-OSS Flax transformer block (attention + MoE MLP). | gpt_oss_transformer_block_flax | 0.10.2 |

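The FlaxRotaryEmbedding row covers rotary position embeddings (RoPE). The sketch below implements the standard recipe of rotating feature pairs by position-dependent angles; the GPT-OSS helper's exact pair layout, scaling, and dtype handling are assumptions:

```python
import jax.numpy as jnp

def apply_rotary(x, positions, base=10000.0):
    """Rotate (x1, x2) feature pairs by angles that grow with position."""
    half = x.shape[-1] // 2
    freqs = base ** (-jnp.arange(half) / half)      # (half,)
    angles = positions[:, None] * freqs[None, :]    # (seq, half)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[..., :half], x[..., half:]
    return jnp.concatenate([x1 * cos - x2 * sin,
                            x1 * sin + x2 * cos], axis=-1)

x = jnp.ones((16, 64))                  # (sequence, head_dim)
out = apply_rotary(x, jnp.arange(16))
print(out.shape)                        # (16, 64)
```
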
| Component | Description | Testcases | Since |
|---|---|---|---|
| onnx_functions_000 | One function boundary on an outer NNX module (new-world). | 000_one_function_on_outer_layer_dynamic<br>000_one_function_on_outer_layer | 0.4.0 |
| onnx_functions_001 | One function on an inner layer. | 001_one_function_inner_dynamic<br>001_one_function_inner | 0.4.0 |
| onnx_functions_002 | Two nested functions. | 002_two_nested_functions_dynamic<br>002_two_nested_functions | 0.4.0 |
| onnx_functions_003 | Two simple nested functions. | 003_two_simple_nested_functions_dynamic<br>003_two_simple_nested_functions | 0.4.0 |
| onnx_functions_004 | Nested function plus a component. | 004_nested_function_plus_component_dynamic<br>004_nested_function_plus_component | 0.4.0 |
| onnx_functions_005 | Nested function plus more components. | 005_nested_function_plus_component_dynamic<br>005_nested_function_plus_component | 0.4.0 |
| onnx_functions_006 | One function on an outer layer. | 006_one_function_outer_dynamic<br>006_one_function_outer | 0.4.0 |
| onnx_functions_007 | Transformer block with a nested MLP block, with call parameter. | 007_transformer_block_dynamic<br>007_transformer_block | 0.4.0 |
| onnx_functions_008 | Transformer block with a nested MLP block, no call parameter. | 008_transformer_block_dynamic<br>008_transformer_block | 0.4.0 |
| onnx_functions_009 | Transformer block using the decorator on both a class and a function. | 009_transformer_block_dynamic<br>009_transformer_block | 0.4.0 |
| onnx_functions_010 | Transformer stack. | 010_transformer_stack_dynamic<br>010_transformer_stack | 0.4.0 |
| onnx_functions_012 | Vision Transformer (ViT) convolutional embedding. | 012_vit_conv_embedding_dynamic<br>012_vit_conv_embedding | 0.4.0 |
| onnx_functions_013 | Vision Transformer (ViT) convolutional embedding with call parameters. | 013_vit_conv_embedding_with_call_params_dynamic<br>013_vit_conv_embedding_with_call_params<br>013_vit_conv_embedding_with_internal_call_params_dynamic<br>013_vit_conv_embedding_with_internal_call_params | 0.4.0 |
| onnx_functions_014 | One function on an outer layer, with an input parameter that has a default value. | 014_one_function_with_input_param_with_default_value<br>014_one_function_without_input_param_with_default_value_dynamic<br>014_one_function_without_input_param_with_default_value | 0.4.0 |
| onnx_functions_015 | One function on an outer layer, with an input parameter without a default value. | 015_one_function_with_input_param_without_default_value_dynamic<br>015_one_function_with_input_param_without_default_value | 0.4.0 |
| onnx_functions_016 | Nested function with an input parameter that has a default value. | 016_internal_function_with_input_param_with_default_value_dynamic<br>016_internal_function_with_input_param_with_default_value | 0.4.0 |
| onnx_functions_017 | Demonstrates @onnx_function(unique=True) reuse across call sites. | 017_unique_function_reuse | 0.10.0 |

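These rows exercise `@onnx_function` boundaries, which map decorated modules or callables to reusable ONNX functions in the exported graph. A hypothetical usage sketch; the decorator's import path and exact semantics as shown are assumptions, so consult the jax2onnx documentation:

```python
import jax.numpy as jnp
from flax import nnx

# Assumed import path; see the jax2onnx docs for the real API.
from jax2onnx import onnx_function

@onnx_function
class MlpBlock(nnx.Module):
    """Becomes a function boundary (an ONNX function) in the exported model."""

    def __init__(self, dim: int, *, rngs: nnx.Rngs):
        self.fc = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x):
        return nnx.relu(self.fc(x))

block = MlpBlock(8, rngs=nnx.Rngs(0))
print(block(jnp.ones((2, 8))).shape)  # (2, 8)
```
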
| Component | Description | Testcases | Since |
|---|---|---|---|
| ClassificationHead | Classification head for Vision Transformer. | vit_classification_head_dynamic<br>vit_classification_head | 0.4.0 |
| ClassificationHeadFlatten | Classification head for Vision Transformer. | vit_classification_head_flat_dynamic<br>vit_classification_head_flat | 0.4.0 |
| ConcatClsToken | Concatenate the CLS token to the input embedding. | vit_concat_cls_token_dynamic<br>vit_concat_cls_token | 0.4.0 |
| ConcatClsTokenFlatten | Concatenate the CLS token to the input embedding. | vit_concat_cls_token_flat_dynamic<br>vit_concat_cls_token_flat | 0.4.0 |
| ConvEmbedding | Convolutional token embedding for MNIST with hierarchical downsampling. | vit_mnist_conv_embedding_dynamic<br>vit_mnist_conv_embedding | 0.1.0 |
| ConvEmbeddingFlatten | Convolutional token embedding for MNIST with hierarchical downsampling. | vit_mnist_conv_embedding_flat_dynamic<br>vit_mnist_conv_embedding_flat | 0.1.0 |
| FeedForward | MLP in Transformer. | vit_feed_forward_dynamic<br>vit_feed_forward | 0.1.0 |
| FeedForwardFlatten | MLP in Transformer. | vit_feed_forward_flat_dynamic<br>vit_feed_forward_flat | 0.1.0 |
| GetToken | Get the CLS token from the input embedding. | vit_get_token_dynamic<br>vit_get_token | 0.4.0 |
| GetTokenFlatten | Get the CLS token from the input embedding. | vit_get_token_flat_dynamic<br>vit_get_token_flat | 0.4.0 |
| PatchEmbedding | Cutting the image into patches and linearly embedding them. | vit_patch_embedding_dynamic<br>vit_patch_embedding | 0.1.0 |
| PatchEmbeddingFlatten | Cutting the image into patches and linearly embedding them. | vit_patch_embedding_flat_dynamic<br>vit_patch_embedding_flat | 0.1.0 |
| PositionalEmbedding | Add positional embedding to the input embedding. | vit_positional_embedding_dynamic<br>vit_positional_embedding | 0.4.0 |
| PositionalEmbeddingFlatten | Add positional embedding to the input embedding. | vit_positional_embedding_flat_dynamic<br>vit_positional_embedding_flat | 0.4.0 |
| TransformerBlock | Transformer block from 'Attention Is All You Need'. | vit_transformer_block_dynamic<br>vit_transformer_block | 0.1.0 |
| TransformerBlockFlatten | Transformer block from 'Attention Is All You Need'. | vit_transformer_block_flat_dynamic<br>vit_transformer_block_flat | 0.1.0 |
| TransformerStack | Stack of Transformer blocks. | vit_transformer_stack_dynamic<br>vit_transformer_stack | 0.1.0 |
| TransformerStackFlatten | Stack of Transformer blocks. | vit_transformer_stack_flat_dynamic<br>vit_transformer_stack_flat | 0.1.0 |
| VisionTransformer | A Vision Transformer (ViT) model for MNIST with configurable embedding type. | vit_model_conv_embedding_dynamic<br>vit_model_conv_embedding<br>vit_model_patch_embedding | 0.2.0 |
| VisionTransformerFlatten | A Vision Transformer (ViT) model for MNIST with configurable embedding type. | vit_model_conv_embedding_flat_dynamic<br>vit_model_conv_embedding_flat<br>vit_model_patch_embedding_flat_dynamic<br>vit_model_patch_embedding_flat | 0.2.0 |

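Many of the ViT rows above revolve around patch embedding: cutting an image into non-overlapping patches and linearly projecting each one. A generic sketch with an MNIST-sized input; the examples' patch size and embedding width are assumptions:

```python
import jax
import jax.numpy as jnp

def patch_embed(images, patch, weight, bias):
    """Cut NHWC images into patch x patch tiles and linearly embed each tile."""
    n, h, w, c = images.shape
    x = images.reshape(n, h // patch, patch, w // patch, patch, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)                  # (N, Hp, Wp, p, p, C)
    x = x.reshape(n, (h // patch) * (w // patch), patch * patch * c)
    return x @ weight + bias                           # (N, tokens, embed)

key = jax.random.PRNGKey(0)
images = jnp.ones((1, 28, 28, 1))                      # MNIST-sized input
weight = jax.random.normal(key, (4 * 4 * 1, 64))       # 4x4 patches -> 64 dims
out = patch_embed(images, 4, weight, jnp.zeros(64))
print(out.shape)                                       # (1, 49, 64)
```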