SotA Example Maintenance¶
This page is maintainer-facing. User-facing SotA pages should explain stable export and validation workflows; generated tests, pinned upstream refs, parity baselines, and artifact promotion details belong here.
Reference Table Refresh¶
Use scripts/generate_readme.sh when the public component/example tables need
to include MaxText or MaxDiffusion status. It installs both optional dependency
groups, prepares pinned upstream checkouts, preserves dirty primary checkouts by
using clean fallback directories, then runs scripts/generate_readme.py.
Run this workflow from Python 3.12 or newer, matching the optional SotA stack
used by scripts/run_all_checks.sh.
Override JAX2ONNX_MAXTEXT_MODELS or JAX2ONNX_MAXDIFFUSION_MODELS before
running the script when a review should cover only a subset.
MaxText¶
Use a compatible MaxText checkout when refreshing generated examples:
mkdir -p tmp
if [ ! -d tmp/maxtext/.git ]; then
git clone https://github.com/AI-Hypercomputer/maxtext.git tmp/maxtext
fi
git -C tmp/maxtext fetch --depth 1 origin 17d805e3488104b5de96bd19be09491ff73c57c1
git -C tmp/maxtext checkout --detach FETCH_HEAD
export JAX2ONNX_MAXTEXT_SRC="$PWD/tmp/maxtext"
export JAX2ONNX_MAXTEXT_MODELS=all # or "gemma-2b,llama2-7b"
export JAX2ONNX_MAXTEXT_REF=17d805e3488104b5de96bd19be09491ff73c57c1
poetry install --with maxtext
poetry run python scripts/generate_tests.py
poetry run pytest -q tests/examples/test_maxtext.py
The standard repository runner can include the same checks:
By default, run_all_checks.sh does not run MaxText checks. With
JAX2ONNX_RUN_MAXTEXT=1, it prepares JAX2ONNX_MAXTEXT_SRC, installs
--with maxtext, regenerates tests, runs the MaxText example tests, then
continues with the regular suite. On Python versions below 3.12, the MaxText
block is skipped.
MaxDiffusion¶
Use a compatible MaxDiffusion checkout when refreshing generated examples:
mkdir -p tmp
if [ ! -d tmp/maxdiffusion/.git ]; then
git clone https://github.com/google/maxdiffusion.git tmp/maxdiffusion
fi
git -C tmp/maxdiffusion fetch --depth 1 origin b4f95730bf4f00c4fd9e3dd2fdda1b50484afda2
git -C tmp/maxdiffusion checkout --detach FETCH_HEAD
export JAX2ONNX_MAXDIFFUSION_SRC="$PWD/tmp/maxdiffusion"
export JAX2ONNX_MAXDIFFUSION_MODELS=all
export JAX2ONNX_MAXDIFFUSION_REF=b4f95730bf4f00c4fd9e3dd2fdda1b50484afda2
poetry install --with maxdiffusion
poetry run python scripts/generate_tests.py
poetry run pytest -q tests/examples/test_maxdiffusion.py
The standard repository runner can include the same checks:
By default, run_all_checks.sh does not run MaxDiffusion checks. With
JAX2ONNX_RUN_MAXDIFFUSION=1, it prepares JAX2ONNX_MAXDIFFUSION_SRC, checks
out the pinned ref, installs --with maxdiffusion, regenerates tests, runs the
MaxDiffusion example tests, then continues with the regular suite. On Python
versions below 3.12, the MaxDiffusion block is skipped.
GPT-OSS Parity¶
Before promoting a GPT-OSS ONNX export, prove the staged Flax bundle still matches the reference checkpoint. Keep parity transcripts in review artifacts or durable maintainer notes when they support a public example update; do not put one-off run logs in user-facing docs.
Routing parity:
JAX_PLATFORM_NAME=cpu poetry run python scripts/gpt_oss_routing_parity.py \
--gpt-oss-path tmp/gpt-oss-jax-vs-torch-numerical-comparison \
--jax-checkpoint ~/.cache/gpt_oss/gpt-oss-20b/original \
--torch-checkpoint ~/.cache/gpt_oss/gpt-oss-20b/original \
--prompt "What is the capital of France?" \
--max-layers 24 \
--max-tokens 4 \
--torch-device cpu \
--output-dir artifacts/gpt_oss_routing/flax
Two-layer debug export and hidden-state comparison:
JAX_PLATFORM_NAME=cpu ORT_LOG_SEVERITY_LEVEL=4 poetry run python scripts/export_flax_gpt_oss_to_onnx.py \
--params ~/.cache/gpt_oss/gpt-oss-20b/flax_params_2layers.msgpack \
--config ~/.cache/gpt_oss/gpt-oss-20b/flax_params_2layers.config.json \
--output /tmp/gpt_oss_transformer_flax_debug.onnx \
--sequence-length 16 \
--emit-hidden-states \
--emit-block-debug \
--skip-validation
JAX_PLATFORM_NAME=cpu ORT_LOG_SEVERITY_LEVEL=4 poetry run python scripts/run_flax_gpt_oss_onnx.py \
--prompt "What is the capital of France?" \
--params ~/.cache/gpt_oss/gpt-oss-20b/flax_params_2layers.msgpack \
--config ~/.cache/gpt_oss/gpt-oss-20b/flax_params_2layers.config.json \
--onnx /tmp/gpt_oss_transformer_flax_debug.onnx \
--sequence-length 16 \
--compare-hidden-states \
--compare-block-debug
Torch-to-Flax parity:
JAX_PLATFORM_NAME=cpu \
poetry run python scripts/probe_flax_gpt_oss_parity.py \
--prompt "France capital? Answer:" \
--params ~/.cache/gpt_oss/gpt-oss-20b/flax_params_2layers.msgpack \
--config ~/.cache/gpt_oss/gpt-oss-20b/flax_params_2layers.config.json \
--torch-checkpoint ~/.cache/gpt_oss/gpt-oss-20b/original \
--sequence-length 16 \
--gpt-oss-path tmp/gpt-oss-jax-vs-torch-numerical-comparison \
--torch-device cpu \
--torch-max-layers 2
Do not commit generated .onnx or .onnx.data files; publish sample models
according to Artifact Publishing.
DINOv3 NNX¶
The NNX DINO example registry uses explicit component names such as
NnxDinoPatchEmbed, NnxDinoAttentionCore, NnxDinoAttention,
NnxDinoBlock, and FlaxDINOv3VisionTransformer so generated test files and
published model paths stay unambiguous.
To capture submodules alongside the full ViT during maintenance:
When the NNX stack or the Equinox sibling implementation changes, also run the forward-parity regression that copies Equinox weights into the NNX model:
Keep generated artifacts local while validating, then publish sample models via the artifact workflow when the public example table needs updated links.