GPT-OSS Weights Workflow
The GPT-OSS examples in this repo come in two flavors:
- Flax/NNX - the recommended path for new exports.
- Equinox - retained for historical comparison.
Unless you specifically need the Equinox version, follow the Flax/NNX workflow below. The Equinox workflow is listed separately at the end.
All commands assume you are at the project root with the Poetry environment
available. The workflow targets CPU-only tools; feel free to switch the device
flags to cuda:X if you have GPU support.
Related Examples
For detailed component exports (like GPT-OSS Attention, MLP, RMSNorm), see the GPT-OSS entries in the Examples table.
1. Download a GPT-OSS checkpoint
OpenAI publishes GPT-OSS on Hugging Face in two sizes, 20B and 120B. The
commands below fetch the original shard layout expected by
gpt_oss.torch.model.Transformer.from_checkpoint.
mkdir -p ~/.cache/gpt_oss/gpt-oss-20b
poetry run huggingface-cli download openai/gpt-oss-20b --include "original/*" \
--repo-type model \
--local-dir ~/.cache/gpt_oss/gpt-oss-20b \
--local-dir-use-symlinks False
huggingface_hub≥1.0 no longer installs the huggingface-cli entry point by default, and recent releases ignore --local-dir-use-symlinks. If the command above fails with “command not found”, or you prefer to stay within Python, download the checkpoint folder directly via snapshot_download:
poetry run python - <<'PY'
from pathlib import Path
from huggingface_hub import snapshot_download
snapshot_download(
repo_id="openai/gpt-oss-20b",
repo_type="model",
allow_patterns=["original/*"],
local_dir=Path("~/.cache/gpt_oss/gpt-oss-20b").expanduser(),
)
PY
This grabs the original/ shard set expected by
Transformer.from_checkpoint. Omit allow_patterns if you want the full repo
contents (tokenizer, chat template, etc.).
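To preview what a pattern such as "original/*" selects before downloading, the filter can be approximated locally. This is a minimal sketch that assumes huggingface_hub applies fnmatch-style globbing to repo-relative paths; the file list in the usage example is hypothetical, not the actual contents of openai/gpt-oss-20b.

```python
from fnmatch import fnmatch


def select_files(paths, allow_patterns=None):
    """Keep repo-relative paths that match any allow pattern.

    Assumes fnmatch-style matching; with no patterns, every path is kept,
    which mirrors omitting allow_patterns entirely.
    """
    if allow_patterns is None:
        return list(paths)
    return [p for p in paths if any(fnmatch(p, pat) for pat in allow_patterns)]


# Hypothetical repo listing, for illustration only.
files = ["config.json", "tokenizer.json", "original/config.json", "original/model.safetensors"]
print(select_files(files, ["original/*"]))
```

Only the entries under original/ survive the filter, which is why the top-level tokenizer and chat-template files are skipped unless allow_patterns is dropped.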
After the download finishes, ~/.cache/gpt_oss/gpt-oss-20b/original should contain
config.json and one or more *.safetensors shards.
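Before staging, it can save time to confirm that the folder has the expected layout. The helper below is a hypothetical sketch (check_checkpoint_dir is not part of this repo) that checks only the two things named above: a config.json file and at least one *.safetensors shard.

```python
from pathlib import Path


def check_checkpoint_dir(ckpt_dir):
    """Return a list of problems found in a downloaded checkpoint folder.

    An empty list means config.json and at least one *.safetensors shard
    are present; it does not validate the shard contents.
    """
    ckpt = Path(ckpt_dir).expanduser()
    if not ckpt.is_dir():
        return [f"{ckpt} is not a directory"]
    problems = []
    if not (ckpt / "config.json").is_file():
        problems.append("missing config.json")
    if not sorted(ckpt.glob("*.safetensors")):
        problems.append("no *.safetensors shards found")
    return problems
```

Calling check_checkpoint_dir("~/.cache/gpt_oss/gpt-oss-20b/original") should return an empty list when the download completed.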
2. Stage GPT-OSS weights for Flax/NNX
Run the staging helper to materialize a Flax .msgpack bundle plus a matching
config.json. The exporter consumes this format directly.
poetry run python scripts/export_flax_gpt_oss_params.py \
--checkpoint ~/.cache/gpt_oss/gpt-oss-20b/original \
--output ~/.cache/gpt_oss/gpt-oss-20b/flax_params.msgpack
Use --gpt-oss-path if the helper repo lives somewhere other than the default
tmp/gpt-oss-jax-vs-torch-numerical-comparison. The script automatically
detects Orbax vs. SafeTensors checkpoints and writes
flax_params.msgpack.config.json beside the serialized parameters.
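To sanity-check the staged bundle before exporting, you can walk the restored parameter tree and tally element counts and bytes. This is a minimal sketch assuming the restored tree is a nested dict whose leaves expose .shape and .nbytes, as NumPy arrays do; summarize_params and totals are hypothetical helpers, not part of the staging script.

```python
import math


def summarize_params(tree, prefix=""):
    """Yield (path, shape, nbytes) for every array-like leaf in a nested dict."""
    for key in sorted(tree):
        value = tree[key]
        path = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, dict):
            # Recurse into sub-modules such as block_0, block_1, ...
            yield from summarize_params(value, path)
        else:
            yield path, tuple(value.shape), int(value.nbytes)


def totals(tree):
    """Return (total element count, total bytes) across all leaves."""
    n_elems = 0
    n_bytes = 0
    for _, shape, nbytes in summarize_params(tree):
        n_elems += math.prod(shape)
        n_bytes += nbytes
    return n_elems, n_bytes
```

For example, load the bundle with flax.serialization.msgpack_restore(path.read_bytes()) and pass the result to totals to see how much memory the export will need just to hold the weights.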
3. Export the Flax/NNX transformer to ONNX
With staged params in place, call the ONNX exporter. It instantiates the
examples.nnx_gpt_oss.FlaxTransformer module, loads the staged parameters via
nnx.Param assignments, and traces the full embedding → blocks → norm → head
pipeline.
poetry run python scripts/export_flax_gpt_oss_to_onnx.py \
--params ~/.cache/gpt_oss/gpt-oss-20b/flax_params.msgpack \
--output artifacts/gpt_oss_flax.onnx \
--sequence-length 256
Notes:
- --sequence-length controls both the tracing inputs and the rotary/mask tables. Start small (e.g. 128) while verifying the workflow, then bump the length to your deployment target.
- Pass --config /path/to/config.json if the staging script’s JSON lives elsewhere.
- The exporter includes validation by default. Use --skip-validation only when you explicitly want to inspect an exported graph before runtime validation.
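To see why --sequence-length affects the baked-in tables, consider how both are sized by the window. The sketch below assumes plain rotary embeddings with base 10000; GPT-OSS uses its own rotary base and context scaling, so the actual values differ, and causal_mask and rope_tables are illustrative names rather than functions from this repo.

```python
import math


def causal_mask(seq_len):
    """mask[i][j] is True when query position i may attend to key position j (j <= i)."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def rope_tables(seq_len, head_dim, base=10000.0):
    """Unscaled rotary cos/sin tables with shape [seq_len][head_dim // 2]."""
    half = head_dim // 2
    # One inverse frequency per rotated pair of channels.
    inv_freq = [base ** (-2.0 * k / head_dim) for k in range(half)]
    cos = [[math.cos(pos * f) for f in inv_freq] for pos in range(seq_len)]
    sin = [[math.sin(pos * f) for f in inv_freq] for pos in range(seq_len)]
    return cos, sin
```

The mask grows quadratically with the window and the rotary tables linearly, which is one reason a small --sequence-length gives faster traces and smaller graphs while you iterate.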
Tip: When iterating on the exporter it can be helpful to trim the staged checkpoint down to a couple of layers. The snippet below keeps only the first two Transformer blocks while preserving the original hidden size/head layout, producing a much smaller bundle that exports quickly:
poetry run python - <<'PY'
from pathlib import Path
import json
import flax.serialization as serialization

root = Path("~/.cache/gpt_oss/gpt-oss-20b").expanduser()
params = serialization.msgpack_restore((root / "flax_params.msgpack").read_bytes())

# Keep the shared embedding/norm/head weights plus the first two blocks.
keep = {k: params[k] for k in ["embedding", "norm", "unembedding"]}
for idx in range(2):
    keep[f"block_{idx}"] = params[f"block_{idx}"]
(root / "flax_params_2layers.msgpack").write_bytes(serialization.to_bytes(keep))

# Mirror the reduced layer count in the config.
config = json.loads((root / "flax_params.config.json").read_text())
config["num_hidden_layers"] = 2
(root / "flax_params_2layers.config.json").write_text(json.dumps(config, indent=2))
print("Wrote 2-layer bundle under", root)
PY
Export with --params .../flax_params_2layers.msgpack --config .../flax_params_2layers.config.json to keep traces under a few minutes.
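If you trim to different depths often, the same logic can be factored into a reusable function. This is a sketch that assumes the key names used in the snippet above (embedding, norm, unembedding, block_{idx}) and a num_hidden_layers entry in the config; trim_layers itself is a hypothetical helper.

```python
import copy


def trim_layers(params, config, num_layers, block_prefix="block_",
                shared_keys=("embedding", "norm", "unembedding")):
    """Return (params, config) reduced to the first num_layers Transformer blocks."""
    available = int(config["num_hidden_layers"])
    if not 0 < num_layers <= available:
        raise ValueError(f"num_layers must be in 1..{available}, got {num_layers}")
    # Shared weights are kept unchanged so hidden size and head layout survive.
    trimmed = {k: params[k] for k in shared_keys}
    for idx in range(num_layers):
        key = f"{block_prefix}{idx}"
        trimmed[key] = params[key]
    # Copy the config so the caller's original dict is left untouched.
    new_config = copy.deepcopy(config)
    new_config["num_hidden_layers"] = num_layers
    return trimmed, new_config
```

The returned parameters can then be written with serialization.to_bytes and the config with json.dumps, exactly as in the two-layer snippet.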
4. ONNX-only smoke test
For a short-window ONNX-only run (no JAX/Torch in memory), export the model with a 16-token window and run tokenizer-backed generation:
JAX_PLATFORM_NAME=cpu \
poetry run python scripts/export_flax_gpt_oss_to_onnx.py \
--params ~/.cache/gpt_oss/gpt-oss-20b/flax_params.msgpack \
--config ~/.cache/gpt_oss/gpt-oss-20b/flax_params.config.json \
--output /tmp/gpt_oss_transformer_flax_seq16.onnx \
--sequence-length 16 \
--skip-validation
mkdir -p artifacts/gpt_oss_full_seq16
mv /tmp/gpt_oss_transformer_flax_seq16.onnx artifacts/gpt_oss_full_seq16/
mv /tmp/gpt_oss_transformer_flax_seq16.onnx.data artifacts/gpt_oss_full_seq16/
LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libjemalloc.so.2 \
poetry run python scripts/run_onnx_only.py \
--onnx artifacts/gpt_oss_full_seq16/gpt_oss_transformer_flax_seq16.onnx \
--config ~/.cache/gpt_oss/gpt-oss-20b/flax_params.config.json \
--prompt "France capital? Answer:" \
--sequence-length 16 \
--generate-steps 8 \
--expand-functions \
--runtime ort
For longer prompts/responses, re-export with a larger --sequence-length or add KV cache support to avoid re-running the full window.
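The cost of running without a KV cache is easiest to see in a window-based greedy loop. This is a minimal sketch, not the code in scripts/run_onnx_only.py: it assumes right padding with a causal mask, so padding after the last real token does not affect that token's logits, and logits_fn is a hypothetical stand-in for a call into the exported model (for example a small wrapper around onnxruntime's InferenceSession.run).

```python
def greedy_generate(logits_fn, prompt_ids, seq_len, steps, pad_id=0):
    """Greedy decoding that re-runs the whole fixed-size window every step.

    logits_fn(window) must return one row of vocabulary logits per position.
    """
    tokens = list(prompt_ids)
    if not tokens:
        raise ValueError("prompt_ids must contain at least one token")
    for _ in range(steps):
        if len(tokens) >= seq_len:
            break  # window is full; a larger --sequence-length is needed
        last = len(tokens) - 1
        # Right-pad to the exported window; causal masking hides the padding.
        window = tokens + [pad_id] * (seq_len - len(tokens))
        logits = logits_fn(window)  # full forward pass over all seq_len positions
        row = logits[last]
        next_id = max(range(len(row)), key=row.__getitem__)
        tokens.append(next_id)
    return tokens
```

Every step pays for a full seq_len forward pass and generation stops once the window fills, whereas a KV cache would let each step process only the newly generated token.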
5. (Legacy) Export the Equinox example to ONNX
Use the helper script to load the checkpoint, mirror it into the IR-only Equinox modules, and emit an ONNX graph. On memory-constrained systems it helps to run the export in two stages:
- Stage the Equinox weights (reads SafeTensors → writes
.eqx, no ONNX yet):
poetry run python scripts/export_eqx_gpt_oss_example_with_mapped_weights.py \
--checkpoint ~/.cache/gpt_oss/gpt-oss-20b/original \
--save-eqx ~/.cache/gpt_oss/gpt-oss-20b/eqx_gpt_oss_transformer.eqx \
--seq-len 256 \
--dynamic-b \
--skip-onnx
- Convert the cached Equinox model to ONNX (no PyTorch in memory):
poetry run python scripts/export_eqx_gpt_oss_example_with_mapped_weights.py \
--eqx ~/.cache/gpt_oss/gpt-oss-20b/eqx_gpt_oss_transformer.eqx \
--output ~/.cache/gpt_oss/gpt-oss-20b/eqx_gpt_oss_transformer.onnx \
--seq-len 256 \
--dynamic-b
- --dynamic-b emits a symbolic batch axis (B).
- Omit --dynamic-b and/or add --dynamic-seq to tailor the exported shapes.
- --save-eqx keeps the mapped Equinox parameters around for future exports.
- Pass a higher --seq-len (e.g. 512) once the 256-token run succeeds; longer sequences raise memory pressure while tracing the attention blocks.
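The flag combinations roughly correspond to the following input shape specs. This sketch assumes the common convention of writing symbolic axes as strings, which jax2onnx-style exporters use for dynamic dimensions; the sequence-axis name "T" and the default batch of 1 are hypothetical, and input_shape is an illustrative helper rather than part of the export script.

```python
def input_shape(seq_len, dynamic_b=False, dynamic_seq=False, batch=1):
    """Return a (batch, sequence) shape spec; string entries mark symbolic axes."""
    batch_dim = "B" if dynamic_b else batch
    seq_dim = "T" if dynamic_seq else seq_len
    return (batch_dim, seq_dim)


for flags in [(False, False), (True, False), (False, True), (True, True)]:
    print(flags, input_shape(256, *flags))
```

A symbolic batch axis lets one exported graph serve different batch sizes, while a fixed --seq-len keeps the sequence dimension concrete.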