Getting Started

Quickstart

Install and export your first model in minutes:

pip install jax2onnx

Convert your JAX callable to ONNX in just a few lines:

from flax import nnx
from jax2onnx import to_onnx

# Define a simple inference MLP
class MLP(nnx.Module):
    def __init__(self, din, dmid, dout, *, rngs):
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, deterministic=True, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, use_running_average=True, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x):
        x = self.bn(self.linear1(x))
        x = self.dropout(x, deterministic=True)
        x = nnx.gelu(x)
        return self.linear2(x)

# Instantiate model
my_callable = MLP(din=30, dmid=20, dout=10, rngs=nnx.Rngs(0))

# Export straight to disk without keeping the proto in memory
to_onnx(
    my_callable,
    [("B", 30)],
    return_mode="file",
    output_path="my_callable.onnx",
)

For a basic structural and numerical validation workflow, see Validation & Deployment Readiness.
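
As a rough illustration of what such a check might involve, the sketch below runs a structural check and then compares the exported model against the original callable. It is not the workflow from the linked guide: the helper names `outputs_match` and `check_exported_model`, the tolerances, and the single-input, single-output assumption are all hypothetical choices made for this example.

```python
import numpy as np


def outputs_match(reference, candidate, rtol=1e-3, atol=1e-5):
    """Return True when two outputs agree in shape and (approximately) in value."""
    reference = np.asarray(reference)
    candidate = np.asarray(candidate)
    return reference.shape == candidate.shape and bool(
        np.allclose(reference, candidate, rtol=rtol, atol=atol)
    )


def check_exported_model(onnx_path, reference_fn, sample):
    """Structural check, then numerical comparison against reference_fn.

    Assumes (for this sketch) a model with exactly one input and one output.
    """
    # Imported lazily so outputs_match stays usable without these packages.
    import onnx
    import onnxruntime as ort

    # Structural validation of the serialized graph.
    onnx.checker.check_model(onnx.load(onnx_path))

    # Numerical validation on CPU with the same sample input.
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    (onnx_out,) = session.run(None, {input_name: sample})
    return outputs_match(reference_fn(sample), onnx_out)
```

With the quickstart model, a call along the lines of `check_exported_model("my_callable.onnx", my_callable, np.random.rand(4, 30).astype(np.float32))` should exercise both steps. The tolerances are arbitrary starting points and may need adjusting per model.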

For modules with dropout, batch normalization, mutable state, or RNG-dependent behavior, make the intended inference behavior explicit before export.

🔎 See it visualized: my_callable.onnx

Browser/WASM Export

For browser deployment with onnxruntime-web/wasm, export a self-contained ONNX file with export_mode="web":

from jax2onnx import to_onnx
from jax2onnx.quickstart import build_quickstart_web_model

model = build_quickstart_web_model()
to_onnx(
    model,
    [("B", 8)],
    return_mode="file",
    output_path="web_mlp.onnx",
    export_mode="web",
)

Generated test runs can validate the same model with both Python ONNX Runtime (CPU) and onnxruntime-web/wasm:

npm install
JAX2ONNX_VALIDATE_ONNXRUNTIME_WEB=1 poetry run pytest -q tests/extra_tests/test_quickstart.py

For a broader but still lightweight smoke run, use the explicit smoke scripts:

scripts/run_onnxruntime_web_smoke.sh
scripts/run_onnxruntime_web_chrome_smoke.sh

The central repository check runner performs full pytest Web runtime validation when either Web runtime flag is enabled:

JAX2ONNX_RUN_ONNXRUNTIME_WEB=1 ./scripts/run_all_checks.sh
JAX2ONNX_RUN_ONNXRUNTIME_WEB_CHROME=1 ./scripts/run_all_checks.sh

See Browser/WASM Deployment for browser loading code, Node.js/Chrome validation, CI usage, and troubleshooting.

ONNX Functions — Minimal Example

ONNX functions encapsulate reusable subgraphs. Apply the @onnx_function decorator to a callable to export it as an ONNX function.

from flax import nnx
from jax2onnx import onnx_function, to_onnx

# The @onnx_function decorator marks this module as a reusable ONNX function
@onnx_function
class MLPBlock(nnx.Module):
    def __init__(self, dim, *, rngs):
        self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
        self.linear2 = nnx.Linear(dim, dim, rngs=rngs)
        self.batchnorm = nnx.BatchNorm(dim, rngs=rngs)

    def __call__(self, x):
        return nnx.gelu(self.linear2(self.batchnorm(nnx.gelu(self.linear1(x)))))

# Use it inside another module
class MyModel(nnx.Module):
    def __init__(self, dim, *, rngs):
        self.block1 = MLPBlock(dim, rngs=rngs)
        self.block2 = MLPBlock(dim, rngs=rngs)

    def __call__(self, x):
        return self.block2(self.block1(x))

model = MyModel(256, rngs=nnx.Rngs(0))
to_onnx(
    model,
    [(100, 256)],
    return_mode="file",
    output_path="model_with_function.onnx",
)

🔎 See it visualized: model_with_function.onnx
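
One way to confirm that an exported file contains ONNX functions is to inspect the `functions` field of the loaded model. The sketch below lists each function's domain, name, and node count. The helper names `summarize_functions` and `summarize_file` are hypothetical, and the names and domains the exporter assigns are not assumed here.

```python
def summarize_functions(model):
    """Return (domain, name, node_count) for every function stored in a ModelProto."""
    return [(fn.domain, fn.name, len(fn.node)) for fn in model.functions]


def summarize_file(onnx_path):
    """Load an ONNX file from disk and summarize the functions it contains."""
    import onnx  # imported lazily; requires the `onnx` package

    return summarize_functions(onnx.load(onnx_path))
```

Calling `summarize_file("model_with_function.onnx")` after the export above should return one entry per function stored in the file. An empty list would suggest the decorator was not applied to the traced module.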

See ONNX Functions for naming, namespaces, and reuse options.

Troubleshooting

If conversion doesn't work out of the box, it could be due to:

  • Non-dynamic function references:
    JAXPR-based conversion requires function references to be resolved dynamically at call time.
    Solution: Wrap the function call in a lambda so the reference is resolved on each call:

    my_dynamic_callable_function = lambda x: original_function(x)
    

  • Unsupported primitives:
    The callable may use a primitive that jax2onnx does not yet support, or supports only partially.
    Solution: Write a plugin to handle the unsupported primitive (this is straightforward!).
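
The lambda workaround in the first item relies on ordinary Python name lookup, which can be shown without jax2onnx. The sketch below assumes that a conversion tool temporarily rebinds a function name while tracing. That is an assumption about the mechanism, not something this page states, and `replacement` is a hypothetical stand-in for whatever would be installed.

```python
def original_function(x):
    return x + 1


# Direct reference: captures the function object that exists right now.
direct_ref = original_function

# Lambda wrapper: looks up the name `original_function` on every call.
dynamic_ref = lambda x: original_function(x)


def replacement(x):
    # Hypothetical stand-in for an implementation installed during tracing.
    return x * 10


# Rebind the name, as a patching mechanism would.
original_function = replacement

direct_result = direct_ref(2)    # still calls the old function
dynamic_result = dynamic_ref(2)  # follows the rebinding
```

Here `direct_result` is 3 and `dynamic_result` is 20: only the lambda sees the rebinding, because it resolves the name when it is called and not when it is defined.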

For broader support boundaries, see Known Limitations.
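
When an unsupported primitive is suspected, it can help to see which primitives a callable lowers to. The sketch below uses `jax.make_jaxpr` to collect the primitive names at the top level of the traced jaxpr. `list_primitives` and `toy_fn` are hypothetical names for this example, and the helper does not descend into nested jaxprs (for example those produced by `jax.jit` or control-flow operators).

```python
import jax
import jax.numpy as jnp
from jax import lax


def list_primitives(fn, *example_args):
    """Return the sorted names of the top-level primitives in fn's jaxpr."""
    closed_jaxpr = jax.make_jaxpr(fn)(*example_args)
    return sorted({eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns})


def toy_fn(x):
    # Hypothetical callable used only for illustration.
    return lax.mul(lax.sin(x), x)


primitive_names = list_primitives(toy_fn, jnp.ones((3,)))
```

Comparing the resulting names against the primitives jax2onnx supports is one way to narrow down which plugin, if any, is missing.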

Looking for provenance details while debugging? Check out the new Stacktrace Metadata guide.