MaxText Support 🚀

MaxText is a high-performance, arbitrary-scale, open-source LLM framework written in pure Python/JAX. jax2onnx provides a self-contained example stack to export these models to ONNX.

  • MaxText (DeepSeek, Gemma, GPT-3, Kimi, Llama, Mistral, Qwen) - https://github.com/AI-Hypercomputer/maxtext

All supported MaxText model families (DeepSeek, Gemma, Llama, Mistral, Qwen, etc.) are listed with their test status in the Examples table.

Supported Families

We support exporting the following model families from the MaxText model zoo:

  • DeepSeek (v2 / v3)
  • Gemma (2 / 3)
  • GPT-3
  • Kimi (K2)
  • Llama (2 / 3 / 3.1 / 4)
  • Mistral
  • Qwen (3 / 3-Next / Omni)

Usage

Dependencies

To run the MaxText examples, you need to install the following additional dependencies:

poetry install --with maxtext

Note: This installs omegaconf, transformers, sentencepiece, tensorflow-cpu, and tensorboardX. tensorflow-cpu is required because MaxText uses tensorboard and some TF utilities. It does not install the MaxText source tree itself; use JAX2ONNX_MAXTEXT_SRC (recommended) or install a MaxText package separately.

Environment Configuration

  • JAX2ONNX_MAXTEXT_SRC (Optional): Path to a local clone of the MaxText repository. If not set, the system attempts to resolve it from an installed MaxText package.
  • JAX2ONNX_MAXTEXT_MODELS (Optional): A comma-separated list of model config names to test (e.g., llama2-7b.yml). If unset, it defaults to a standard set of representative models.
  • JAX2ONNX_MAXTEXT_REF (Optional): Git branch/tag/commit used by run_all_checks.sh. Defaults to commit 17d805e3488104b5de96bd19be09491ff73c57c1 (17d805e).

If auto-discovery finds an incompatible MaxText package, jax2onnx skips examples.maxtext registration instead of generating placeholder failing tests. Set JAX2ONNX_MAXTEXT_SRC to a known-compatible checkout to force registration and validate the MaxText integration.
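The resolution order described above can be sketched in Python. The function names and the fallback default list here are hypothetical illustrations, not jax2onnx's actual internals:

```python
import os
from pathlib import Path

def resolve_maxtext_src():
    """Prefer JAX2ONNX_MAXTEXT_SRC; otherwise try an installed package."""
    src = os.environ.get("JAX2ONNX_MAXTEXT_SRC")
    if src:
        return Path(src)
    try:
        import MaxText  # only present if a MaxText package is installed
        return Path(MaxText.__file__).parent
    except ImportError:
        return None  # auto-discovery failed; examples.maxtext is skipped

def resolve_model_list(default=("gemma-2b", "llama2-7b")):  # placeholder default
    """Parse JAX2ONNX_MAXTEXT_MODELS as a comma-separated list of names."""
    raw = os.environ.get("JAX2ONNX_MAXTEXT_MODELS", "")
    names = [n.strip() for n in raw.split(",") if n.strip()]
    return names or list(default)
```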

Testing

To run the pinned MaxText examples (use poetry run to stay in the project venv):

mkdir -p tmp
cd tmp
git clone https://github.com/AI-Hypercomputer/maxtext.git
cd maxtext
git checkout 17d805e
cd ../..
export JAX2ONNX_MAXTEXT_SRC=tmp/maxtext
export JAX2ONNX_MAXTEXT_MODELS=all  # or "gemma-2b,llama2-7b"
export JAX2ONNX_MAXTEXT_REF=17d805e3488104b5de96bd19be09491ff73c57c1
poetry install --with maxtext
poetry run python scripts/generate_tests.py
poetry run pytest -q tests/examples/test_maxtext.py

ONNX outputs land in docs/onnx/examples/maxtext.
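A quick way to inspect what the test run produced is to glob that directory. This helper is a hypothetical convenience, not part of jax2onnx:

```python
from pathlib import Path

def list_onnx_exports(out_dir="docs/onnx/examples/maxtext"):
    """Return the .onnx files produced by the example tests, sorted by name."""
    return sorted(p.name for p in Path(out_dir).glob("*.onnx"))
```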

You can also include the same MaxText SotA checks in the standard repository runner:

JAX2ONNX_RUN_MAXTEXT=1 ./scripts/run_all_checks.sh

By default, run_all_checks.sh does not run MaxText checks. With JAX2ONNX_RUN_MAXTEXT=1, it prepares JAX2ONNX_MAXTEXT_SRC (default: tmp/maxtext) and JAX2ONNX_MAXTEXT_REF (default: 17d805e3488104b5de96bd19be09491ff73c57c1), runs poetry install --with maxtext, regenerates the tests, runs tests/examples/test_maxtext.py, and then executes the full pytest suite. On Python 3.11, the script automatically skips the MaxText block and continues with the regular checks.
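The Python-version gate can be sketched as follows (a hypothetical helper mirroring the script's behavior, not its actual code):

```python
import sys

def maxtext_checks_enabled(version=None):
    """MaxText checks are skipped on Python 3.11 and run on other versions."""
    major, minor = version or sys.version_info[:2]
    return (major, minor) != (3, 11)
```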

The MaxText example flow will:

  1. Dynamically discover MaxText configs.
  2. Instantiate the models with minimal inference settings (batch_size=1, seq_len=32).
  3. Export them to ONNX and verify the graph structure.
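Step 3 can be illustrated with a miniature structural check. This is illustrative only: real verification inspects the exported onnx.ModelProto, whereas here each node is simply a (name, inputs, outputs) triple and we check topological well-formedness:

```python
def check_graph_structure(nodes, graph_inputs):
    """Every node's inputs must be graph inputs or outputs of earlier nodes."""
    available = set(graph_inputs)
    for name, inputs, outputs in nodes:
        missing = [i for i in inputs if i not in available]
        if missing:
            raise ValueError(f"node {name!r} consumes undefined tensors: {missing}")
        available.update(outputs)
    return True
```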