GPT-OSS Weights Workflow

The GPT-OSS examples in this repo come in two flavors:

  • Flax/NNX - the recommended path for new exports.
  • Equinox - retained for historical comparison.

Unless you specifically need the Equinox version, follow the Flax/NNX workflow below. The Equinox workflow is listed separately at the end.

All commands assume you are at the project root with the Poetry environment available. The workflow targets CPU-only tools; feel free to switch the device flags to cuda:X if you have GPU support.

For detailed component exports (like GPT-OSS Attention, MLP, RMSNorm), see the GPT-OSS entries in the Examples table.

1. Download a GPT-OSS checkpoint

OpenAI publishes GPT-OSS on Hugging Face under both 20B and 120B variants. The commands below fetch the original shard layout expected by gpt_oss.torch.model.Transformer.from_checkpoint.

mkdir -p ~/.cache/gpt_oss/gpt-oss-20b
poetry run huggingface-cli download openai/gpt-oss-20b --include "original/*" \
  --repo-type model \
  --local-dir ~/.cache/gpt_oss/gpt-oss-20b \
  --local-dir-use-symlinks False

huggingface_hub ≥1.0 no longer installs the huggingface-cli entry point by default, and recent releases ignore --local-dir-use-symlinks. If the command above fails with “command not found” or you prefer to stay within Python, download the checkpoint folder directly via snapshot_download:

poetry run python - <<'PY'
from pathlib import Path

from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="openai/gpt-oss-20b",
    repo_type="model",
    allow_patterns=["original/*"],
    local_dir=Path("~/.cache/gpt_oss/gpt-oss-20b").expanduser(),
)
PY

This grabs the original/ shard set expected by Transformer.from_checkpoint. Omit allow_patterns if you want the full repo contents (tokenizer, chat template, etc.).

After the download finishes, ~/.cache/gpt_oss/gpt-oss-20b/original should contain config.json and a set of *.safetensors shards.
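A quick check of the download can save a failed staging run later. The sketch below assumes only what the text above states (a config.json plus one or more *.safetensors files in the original/ folder). The helper name check_checkpoint_dir is hypothetical and not part of this repo.

```python
from pathlib import Path


def check_checkpoint_dir(path):
    """Return a list of problems found in a downloaded checkpoint folder."""
    root = Path(path).expanduser()
    if not root.is_dir():
        return [f"{root} is not a directory"]
    problems = []
    if not (root / "config.json").is_file():
        problems.append("config.json is missing")
    if not any(root.glob("*.safetensors")):
        problems.append("no *.safetensors shards found")
    return problems


if __name__ == "__main__":
    # Path taken from the download commands above.
    issues = check_checkpoint_dir("~/.cache/gpt_oss/gpt-oss-20b/original")
    print("checkpoint looks complete" if not issues else "\n".join(issues))
```

An empty list means both expected pieces are present; it does not verify that the shards themselves are intact.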

2. Stage GPT-OSS weights for Flax/NNX

Run the staging helper to materialize a Flax .msgpack bundle plus a matching config.json. The exporter consumes this format directly.

poetry run python scripts/export_flax_gpt_oss_params.py \
  --checkpoint ~/.cache/gpt_oss/gpt-oss-20b/original \
  --output ~/.cache/gpt_oss/gpt-oss-20b/flax_params.msgpack

Use --gpt-oss-path if the helper repo lives somewhere other than the default tmp/gpt-oss-jax-vs-torch-numerical-comparison. The script automatically detects Orbax vs. SafeTensors checkpoints and writes the matching config JSON beside the serialized parameters (flax_params.config.json, the name used by the commands later in this guide).
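To see what the staging step produced, it can help to list the parameter names and shapes in the bundle. The following is a minimal sketch, assuming the bundle restores to a nested dict of arrays. summarize_tree and load_bundle are hypothetical helper names, and load_bundle relies on flax.serialization.msgpack_restore, the same call used by the trimming tip in step 3.

```python
from pathlib import Path


def summarize_tree(tree, prefix=""):
    """Flatten a nested dict of arrays into (dotted_name, shape) pairs."""
    rows = []
    for key in sorted(tree):
        value = tree[key]
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            rows.extend(summarize_tree(value, prefix=name + "."))
        else:
            # Leaves are expected to be arrays; anything else reports None.
            rows.append((name, getattr(value, "shape", None)))
    return rows


def load_bundle(path):
    """Restore a msgpack bundle into nested dicts (requires flax)."""
    import flax.serialization as serialization  # imported lazily on purpose

    return serialization.msgpack_restore(Path(path).expanduser().read_bytes())
```

Calling summarize_tree(load_bundle("~/.cache/gpt_oss/gpt-oss-20b/flax_params.msgpack")) and printing each pair gives a flat inventory of the staged weights. Note that restoring the full bundle loads every parameter into memory.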

3. Export the Flax/NNX transformer to ONNX

With staged params in place, call the ONNX exporter. It instantiates the examples.nnx_gpt_oss.FlaxTransformer module, loads the staged parameters via nnx.Param assignments, and traces the full embedding → blocks → norm → head pipeline.

poetry run python scripts/export_flax_gpt_oss_to_onnx.py \
  --params ~/.cache/gpt_oss/gpt-oss-20b/flax_params.msgpack \
  --output artifacts/gpt_oss_flax.onnx \
  --sequence-length 256

Notes:

  • --sequence-length controls both the tracing inputs and the rotary/mask tables. Start small (e.g. 128) while verifying the workflow, then bump the length to your deployment target.
  • Pass --config /path/to/config.json if the staging script’s JSON lives elsewhere.
  • The exporter validates the exported model by default. Pass --skip-validation only when you want to inspect the exported graph without running that validation step.
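The advice to start with a small --sequence-length follows from how these tables scale. As a rough illustration only, the sketch below assumes a dense causal mask (one entry per query/key pair) and plain cos/sin rotary tables with head_dim // 2 columns each. The exporter's actual table layout may differ, and both table_entries and the head_dim value of 64 are illustrative assumptions.

```python
def table_entries(sequence_length, head_dim):
    """Count entries in a dense causal mask and in rotary cos/sin tables."""
    # One mask entry per (query position, key position) pair: quadratic growth.
    mask = sequence_length * sequence_length
    # One cos table plus one sin table, each sequence_length x (head_dim // 2).
    rotary = 2 * sequence_length * (head_dim // 2)
    return {"mask": mask, "rotary": rotary}


for length in (128, 256, 2048):
    print(length, table_entries(length, head_dim=64))
```

Under these assumptions, doubling the sequence length doubles the rotary tables but quadruples the mask, which is why a short length keeps the first verification runs cheap.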

Tip: When iterating on the exporter it can be helpful to trim the staged checkpoint down to a couple of layers. The snippet below keeps only the first two Transformer blocks while preserving the original hidden size/head layout, producing a much smaller bundle that exports quickly:

poetry run python - <<'PY'
from pathlib import Path
import json
import flax.serialization as serialization

root = Path("~/.cache/gpt_oss/gpt-oss-20b").expanduser()
params = serialization.msgpack_restore((root / "flax_params.msgpack").read_bytes())
keep = {k: params[k] for k in ["embedding", "norm", "unembedding"]}
for idx in range(2):
    keep[f"block_{idx}"] = params[f"block_{idx}"]
(root / "flax_params_2layers.msgpack").write_bytes(serialization.to_bytes(keep))

config = json.loads((root / "flax_params.config.json").read_text())
config["num_hidden_layers"] = 2
(root / "flax_params_2layers.config.json").write_text(json.dumps(config, indent=2))
print("Wrote 2-layer bundle under", root)
PY

Export with --params .../flax_params_2layers.msgpack --config .../flax_params_2layers.config.json to keep traces under a few minutes.
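Before exporting a trimmed bundle, a small consistency check can catch a config that was not updated to match. This sketch assumes the key naming used in the tip above (top-level block_<index> entries and a num_hidden_layers field in the config). The function names are hypothetical.

```python
def count_blocks(params):
    """Count top-level keys named block_<index> in a staged parameter dict."""
    return sum(
        1
        for key in params
        if key.startswith("block_") and key[len("block_"):].isdigit()
    )


def check_layer_count(params, config):
    """Raise ValueError if the bundle and config disagree on the layer count."""
    found = count_blocks(params)
    expected = config["num_hidden_layers"]
    if found != expected:
        raise ValueError(f"bundle has {found} blocks, config expects {expected}")
    return found
```

Running check_layer_count on the restored 2-layer bundle and its JSON config should return 2; a mismatch raises before any time is spent on tracing.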

4. ONNX-only smoke test

For a short-window ONNX-only run (no JAX/Torch in memory), export a small model and run tokenizer-backed generation:

JAX_PLATFORM_NAME=cpu \
poetry run python scripts/export_flax_gpt_oss_to_onnx.py \
  --params ~/.cache/gpt_oss/gpt-oss-20b/flax_params.msgpack \
  --config ~/.cache/gpt_oss/gpt-oss-20b/flax_params.config.json \
  --output /tmp/gpt_oss_transformer_flax_seq16.onnx \
  --sequence-length 16 \
  --skip-validation

mkdir -p artifacts/gpt_oss_full_seq16
mv /tmp/gpt_oss_transformer_flax_seq16.onnx artifacts/gpt_oss_full_seq16/
mv /tmp/gpt_oss_transformer_flax_seq16.onnx.data artifacts/gpt_oss_full_seq16/

LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libjemalloc.so.2 \
poetry run python scripts/run_onnx_only.py \
  --onnx artifacts/gpt_oss_full_seq16/gpt_oss_transformer_flax_seq16.onnx \
  --config ~/.cache/gpt_oss/gpt-oss-20b/flax_params.config.json \
  --prompt "France capital? Answer:" \
  --sequence-length 16 \
  --generate-steps 8 \
  --expand-functions \
  --runtime ort

For longer prompts/responses, re-export with a larger --sequence-length or add KV cache support to avoid re-running the full window.
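The cost of running without a KV cache is easiest to see in code. The sketch below shows greedy decoding over a fixed-size window, where every generated token triggers a full forward pass. It is not taken from scripts/run_onnx_only.py: run_model is a hypothetical callable that accepts exactly sequence_length token ids and returns one list of logits per position, and pad_id=0 is an arbitrary choice. Right-padding is assumed to be harmless because a causal mask keeps later positions from influencing earlier ones.

```python
def greedy_generate(run_model, prompt_ids, sequence_length, steps, pad_id=0):
    """Greedy decoding that re-runs the whole fixed-size window each step."""
    tokens = list(prompt_ids)
    if not tokens:
        raise ValueError("prompt must contain at least one token")
    for _ in range(steps):
        if len(tokens) >= sequence_length:
            break  # window is full; a longer export or a KV cache is needed
        # Pad on the right so the input always matches the exported length.
        window = tokens + [pad_id] * (sequence_length - len(tokens))
        logits = run_model(window)
        # The logits at the last real position predict the next token.
        last = logits[len(tokens) - 1]
        tokens.append(max(range(len(last)), key=last.__getitem__))
    return tokens
```

Each step recomputes attention for the whole window, so the total work grows with steps times the window cost. Generation also stops once the prompt plus output fills the window, which is the limit that a larger --sequence-length raises.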

5. (Legacy) Export the Equinox example to ONNX

Use the helper script to load the checkpoint, mirror it into the IR-only Equinox modules, and emit an ONNX graph. On memory-constrained systems it helps to run the export in two stages:

  1. Stage the Equinox weights (reads SafeTensors → writes .eqx, no ONNX yet):
poetry run python scripts/export_eqx_gpt_oss_example_with_mapped_weights.py \
  --checkpoint ~/.cache/gpt_oss/gpt-oss-20b/original \
  --save-eqx ~/.cache/gpt_oss/gpt-oss-20b/eqx_gpt_oss_transformer.eqx \
  --seq-len 256 \
  --dynamic-b \
  --skip-onnx
  2. Convert the cached Equinox model to ONNX (no PyTorch in memory):
poetry run python scripts/export_eqx_gpt_oss_example_with_mapped_weights.py \
  --eqx ~/.cache/gpt_oss/gpt-oss-20b/eqx_gpt_oss_transformer.eqx \
  --output ~/.cache/gpt_oss/gpt-oss-20b/eqx_gpt_oss_transformer.onnx \
  --seq-len 256 \
  --dynamic-b
  • --dynamic-b emits a symbolic batch axis (B).
  • Omit --dynamic-b and/or add --dynamic-seq to tailor the exported shapes.
  • --save-eqx keeps the mapped Equinox parameters around for future exports.
  • Pass a higher --seq-len (e.g. 512) once the 256-token run succeeds; longer sequences raise memory pressure while tracing the attention blocks.
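To confirm which axes ended up symbolic after an export, the graph inputs can be inspected directly. This is a minimal sketch using the onnx Python package; dim_label and input_shapes are hypothetical helper names, and the external weight data is deliberately not loaded since only the graph signature is needed.

```python
import os


def dim_label(dim_value, dim_param):
    """Render one ONNX dimension: symbolic name, fixed size, or '?'."""
    if dim_param:
        return dim_param
    return "?" if dim_value is None else str(dim_value)


def input_shapes(onnx_path):
    """Map each graph input name to its dimension labels (requires onnx)."""
    import onnx  # imported lazily so dim_label stays dependency-free

    # Skip external weight files; only the graph signature is read here.
    model = onnx.load(os.path.expanduser(onnx_path), load_external_data=False)
    shapes = {}
    for graph_input in model.graph.input:
        dims = graph_input.type.tensor_type.shape.dim
        shapes[graph_input.name] = [
            dim_label(d.dim_value if d.HasField("dim_value") else None, d.dim_param)
            for d in dims
        ]
    return shapes
```

For a model exported with --dynamic-b, the batch position of the token input should show a symbolic name rather than a number, while a fixed --seq-len appears as a plain integer.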