API Reference

jax2onnx exposes four public entry points:

  • to_onnx(...) for export
  • @onnx_function for reusable subgraphs
  • allclose(...) for JAX-vs-ONNX validation
  • allclose_onnxruntime_web(...) for ONNX Runtime Web/WASM parity checks

Common Export Flow

from jax2onnx import to_onnx


model_path = to_onnx(
    fn,
    inputs=[("B", 128)],
    return_mode="file",
    output_path="model.onnx",
)

Use string dimensions such as "B" when you want symbolic dynamic axes.
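To illustrate how a symbolic export spec relates to the concrete shapes you later use for validation, the sketch below binds string dimensions to sizes. `bind_symbolic_dims` is a hypothetical helper written for this page, not part of jax2onnx.

```python
from typing import Mapping, Sequence, Tuple, Union

Dim = Union[int, str]


def bind_symbolic_dims(
    shape: Sequence[Dim], bindings: Mapping[str, int]
) -> Tuple[int, ...]:
    """Replace string dimensions with concrete sizes (hypothetical helper)."""
    concrete = []
    for dim in shape:
        if isinstance(dim, str):
            # A string marks a symbolic axis; it needs a size at run time.
            if dim not in bindings:
                raise KeyError(f"no size given for symbolic dimension {dim!r}")
            concrete.append(bindings[dim])
        else:
            concrete.append(int(dim))
    return tuple(concrete)


export_spec = ("B", 128)  # shape tuple as passed in to_onnx(inputs=[...])
validation_shape = bind_symbolic_dims(export_spec, {"B": 4})
print(validation_shape)  # (4, 128)
```

The same export spec can be bound to several batch sizes to confirm that the dynamic axis behaves as expected.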

For direct file export, set return_mode="file" and provide output_path.

For a full export validation checklist, see Validation & Deployment Readiness.

For browser deployment with onnxruntime-web, use the Web export profile:

model_path = to_onnx(
    fn,
    inputs=[("B", 128)],
    return_mode="file",
    output_path="model.web.onnx",
    export_mode="web",
)

export_mode="web" keeps the ONNX graph semantics unchanged, but serializes a single self-contained .onnx file instead of spilling large initializers into a .onnx.data sidecar. That is the easiest artifact shape to serve to onnxruntime-web/wasm.
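One way to confirm that an exported artifact is self-contained is to look for sidecar files next to it. The sketch below is a plain filesystem check and does not call jax2onnx. It assumes that sidecars live in the same directory and extend the model's file name (for example `model.onnx.data`); `external_data_sidecars` is a hypothetical helper.

```python
import os
import tempfile
from pathlib import Path
from typing import List, Union


def external_data_sidecars(model_path: Union[str, os.PathLike]) -> List[Path]:
    """Return files named '<model file name>.<something>' beside the model.

    Assumption: sidecars sit next to the model and extend its file name.
    """
    model = Path(model_path)
    prefix = model.name + "."
    return sorted(
        p
        for p in model.parent.iterdir()
        if p.is_file() and p.name.startswith(prefix)
    )


# Demo with empty placeholder files instead of real exports.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for file_name in ("model.onnx", "model.onnx.data", "model.web.onnx"):
        (root / file_name).write_bytes(b"")
    standard_sidecars = [p.name for p in external_data_sidecars(root / "model.onnx")]
    web_sidecars = [p.name for p in external_data_sidecars(root / "model.web.onnx")]

print(standard_sidecars)  # ['model.onnx.data']
print(web_sidecars)       # []
```

An empty result for the Web artifact means there is only one file to serve.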

Parameters To Reach For First

  • inputs: Positional input specs: concrete arrays, ShapeDtypeStruct values, or shape tuples like ("B", 128).
  • input_params: Runtime flags or keyword-like values that should stay model inputs instead of being baked into the export.
  • return_mode: "proto" for an onnx.ModelProto, "ir" for the intermediate onnx_ir.Model, or "file" to serialize directly to disk.
  • export_mode: "standard" for normal serialization, or "web" for single-file browser/WASM artifacts.
  • enable_double_precision: Temporarily enables x64 export and emits tensor(double) where appropriate.
  • inputs_as_nchw / outputs_as_nchw: Adapt the external ONNX interface to NCHW while keeping the traced JAX computation in its original layout.
  • input_names / output_names: Apply stable user-facing names after conversion.
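The layout options are easiest to reason about in terms of shapes. The sketch below applies the NHWC-to-NCHW axis order `[0, 3, 1, 2]` (the permutation used by the validation helper on this page) to a shape tuple. `permute_shape` and the constant names are illustrative only and are not jax2onnx APIs.

```python
from typing import Sequence, Tuple

NHWC_TO_NCHW = (0, 3, 1, 2)
NCHW_TO_NHWC = (0, 2, 3, 1)  # inverse permutation


def permute_shape(shape: Sequence[int], perm: Sequence[int]) -> Tuple[int, ...]:
    """Return the shape a transpose with axis order `perm` would produce."""
    if len(shape) != len(perm):
        raise ValueError(f"rank mismatch: shape={tuple(shape)} perm={tuple(perm)}")
    # Output axis i takes its size from input axis perm[i].
    return tuple(shape[axis] for axis in perm)


jax_shape = (8, 224, 224, 3)  # NHWC, as seen by the traced JAX function
onnx_shape = permute_shape(jax_shape, NHWC_TO_NCHW)
print(onnx_shape)  # (8, 3, 224, 224)
print(permute_shape(onnx_shape, NCHW_TO_NHWC))  # (8, 224, 224, 3)
```

With `inputs_as_nchw`, callers of the ONNX model would supply the second shape while the JAX function keeps seeing the first.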

Browser/WASM Validation

The generated test harness can optionally validate exported models with onnxruntime-web/wasm in Node.js or Chrome/Chromium:

npm install
JAX2ONNX_VALIDATE_ONNXRUNTIME_WEB=1 poetry run pytest -q tests/primitives/test_nn.py

For the browser runner, add:

npx playwright install chromium
JAX2ONNX_VALIDATE_ONNXRUNTIME_WEB=1 \
JAX2ONNX_ONNXRUNTIME_WEB_RUNNER=chrome \
poetry run pytest -q tests/primitives/test_nn.py

When JAX2ONNX_VALIDATE_ONNXRUNTIME_WEB=1 is set, generated tests export with export_mode="web", keep the existing JAX-vs-Python-ONNX-Runtime CPU check, and then compare the same ONNX model and inputs against onnxruntime-web/wasm.
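If you launch these test runs from Python rather than a shell, the same environment variables can be assembled programmatically. This is a minimal sketch; `web_validation_env` is a hypothetical helper, and only the two variable names come from the commands above.

```python
import os
from typing import Dict, Mapping, Optional


def web_validation_env(
    runner: Optional[str] = None,
    base: Optional[Mapping[str, str]] = None,
) -> Dict[str, str]:
    """Copy an environment and switch on onnxruntime-web validation."""
    env = dict(os.environ if base is None else base)
    env["JAX2ONNX_VALIDATE_ONNXRUNTIME_WEB"] = "1"
    if runner is not None:
        # For example "chrome" to select the browser runner.
        env["JAX2ONNX_ONNXRUNTIME_WEB_RUNNER"] = runner
    return env


# An empty base keeps the demo output small; omit it to inherit os.environ.
env = web_validation_env(runner="chrome", base={})
print(sorted(env))
```

The resulting mapping can be passed as `env=` to `subprocess.run([...])` for the pytest command shown above. In that case leave `base` unset so `PATH` and similar variables are inherited.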

For a smaller local smoke run that covers the Quickstart Web model plus representative generated LAX/JAX NumPy examples, run the explicit smoke scripts:

scripts/run_onnxruntime_web_smoke.sh
scripts/run_onnxruntime_web_chrome_smoke.sh

When Web runtime validation is requested through the central repository check runner, the runner executes the full pytest suite with export_mode="web" and the selected runtime:

JAX2ONNX_RUN_ONNXRUNTIME_WEB=1 ./scripts/run_all_checks.sh
JAX2ONNX_RUN_ONNXRUNTIME_WEB_CHROME=1 ./scripts/run_all_checks.sh

For browser loading code, validation helpers, CI usage, and troubleshooting, see Browser/WASM Deployment.

Use @onnx_function when repeated callables should become reusable ONNX functions, and allclose(...) when you want a quick numerical check against ONNX Runtime after export.

jax2onnx.user_interface

User-facing API for converting JAX functions and models to ONNX.

This module provides the primary interface for exporting JAX/Flax models to the ONNX format. It supports dynamic shapes, runtime parameters, file-oriented export profiles, and numerical validation against ONNX Runtime.

Key Functions:

  • to_onnx: Convert a JAX function or Flax module to an ONNX model.
  • onnx_function: Decorator to mark a function or class as an ONNX function node.
  • allclose: Validate numerical equivalence between JAX and ONNX Runtime outputs.
  • allclose_onnxruntime_web: Validate ONNX Runtime Web/WASM output against Python ONNX Runtime CPU output.
Example
import jax.numpy as jnp

from jax2onnx import to_onnx


def my_model(x):
    return jnp.sin(x)


to_onnx(
    my_model,
    inputs=[("B", 10)],
    return_mode="file",
    output_path="model.onnx",
)

to_onnx(
    my_model,
    inputs=[("B", 10)],
    return_mode="file",
    output_path="model.web.onnx",
    export_mode="web",
)

allclose(fn, onnx_model_path, inputs, input_params=None, rtol=0.001, atol=1e-05, *, enable_double_precision=False, inputs_as_nchw=None, outputs_as_nchw=None)

Checks if JAX and ONNX Runtime outputs remain numerically close.

Parameters:

Name Type Description Default
fn Callable[..., Any]

JAX callable to compare against the exported ONNX model.

required
onnx_model_path str

Path to a serialized model that ORT can execute.

required
inputs Sequence[Any]

Concrete input arrays (or shape tuples, which will be sampled).

required
input_params Optional[Dict[str, Any]]

Optional keyword arguments applied to both call sites.

None
rtol float

Relative tolerance for floating-point comparisons.

0.001
atol float

Absolute tolerance for floating-point comparisons.

1e-05
enable_double_precision bool

Temporarily enable jax_enable_x64 while running the comparison. Defaults to False.

False
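This page does not spell out how rtol and atol are combined. As a rough guide, the sketch below uses the NumPy-style convention `|actual - expected| <= atol + rtol * |expected|`. Treat that formula as an assumption about the comparison, not a statement of the library's exact implementation; `within_tolerance` is a hypothetical scalar helper.

```python
import math


def within_tolerance(
    actual: float,
    expected: float,
    rtol: float = 1e-3,
    atol: float = 1e-5,
) -> bool:
    """Scalar closeness test using the NumPy-style formula (assumed)."""
    if math.isnan(actual) or math.isnan(expected):
        return False
    return abs(actual - expected) <= atol + rtol * abs(expected)


print(within_tolerance(1.0005, 1.0))  # True: 5e-4 is below 1e-5 + 1e-3
print(within_tolerance(1.002, 1.0))   # False: 2e-3 exceeds the bound
print(within_tolerance(2e-5, 0.0))    # False: only atol applies near zero
```

Under this convention atol dominates for values near zero and rtol dominates for large magnitudes, which is why both are usually set together.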

Returns:

Type Description
Tuple[bool, str]

(is_match, message) where is_match indicates success and message provides context when a mismatch occurs.

Example
import jax.numpy as jnp

from jax2onnx import allclose, to_onnx


def my_func(x):
    return jnp.sin(x)


model_path = to_onnx(
    my_func,
    inputs=[("B", 10)],
    return_mode="file",
    output_path="my_model.onnx",
)

validation_inputs = [(5, 10)]
is_match, msg = allclose(
    my_func,
    model_path,
    inputs=validation_inputs,
    atol=1e-5,
)

assert is_match, f"Validation failed: {msg}"
Source code in jax2onnx/user_interface.py
def allclose(
    fn: Callable[..., Any],
    onnx_model_path: str,
    inputs: Sequence[Any],
    input_params: Optional[Dict[str, Any]] = None,
    rtol: float = 1e-3,
    atol: float = 1e-5,
    *,
    enable_double_precision: bool = False,
    inputs_as_nchw: Optional[Sequence[int]] = None,
    outputs_as_nchw: Optional[Sequence[int]] = None,
) -> Tuple[bool, str]:
    """
    Checks if JAX and ONNX Runtime outputs remain numerically close.

    Args:
        fn: JAX callable to compare against the exported ONNX model.
        onnx_model_path: Path to a serialized model that ORT can execute.
        inputs: Concrete input arrays (or shape tuples, which will be sampled).
        input_params: Optional keyword arguments applied to both call sites.
        rtol: Relative tolerance for floating-point comparisons.
        atol: Absolute tolerance for floating-point comparisons.
        enable_double_precision: Temporarily enable `jax_enable_x64` while running the
            comparison. Defaults to False.

    Returns:
        `(is_match, message)` where `is_match` indicates success and `message`
        provides context when a mismatch occurs.

    Example:
        ```python
        import jax.numpy as jnp

        from jax2onnx import allclose, to_onnx


        def my_func(x):
            return jnp.sin(x)


        model_path = to_onnx(
            my_func,
            inputs=[("B", 10)],
            return_mode="file",
            output_path="my_model.onnx",
        )

        validation_inputs = [(5, 10)]
        is_match, msg = allclose(
            my_func,
            model_path,
            inputs=validation_inputs,
            atol=1e-5,
        )

        assert is_match, f"Validation failed: {msg}"
        ```
    """

    logging.info(
        "Comparing JAX and ONNX outputs (path=%s, rtol=%s, atol=%s)",
        onnx_model_path,
        rtol,
        atol,
    )

    xs = _validation_inputs_to_arrays(inputs)

    params = dict(input_params or {})
    with _temporary_x64(enable_double_precision):
        with jax.default_matmul_precision("float32"):
            return _run_allclose(
                fn,
                onnx_model_path,
                xs,
                params,
                rtol=rtol,
                atol=atol,
                inputs_as_nchw=inputs_as_nchw,
                outputs_as_nchw=outputs_as_nchw,
            )

allclose_onnxruntime_web(onnx_model_path, inputs, input_params=None, rtol=0.001, atol=1e-05, *, inputs_as_nchw=None, node_command='node', runner=None)

Validate that onnxruntime-web/wasm matches Python ONNX Runtime CPU output.

This helper is intended for deployment-readiness checks. It executes the serialized model once with Python onnxruntime on CPUExecutionProvider, then runs the same model with the same input feeds through onnxruntime-web/wasm and compares the outputs. The default runner is Node.js; pass runner="chrome" or set JAX2ONNX_ONNXRUNTIME_WEB_RUNNER=chrome to validate in a real browser via Playwright/Chromium.
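The runner selection described above can be pictured as a small precedence rule: an explicit argument first, then the environment variable, then "node". The sketch below mirrors that order. `resolve_web_runner` is a hypothetical name, and it raises on unknown values, whereas the library helper reports the problem through its `(False, message)` return value.

```python
import os
from typing import Mapping, Optional


def resolve_web_runner(
    runner: Optional[str] = None,
    environ: Optional[Mapping[str, str]] = None,
) -> str:
    """Pick the runtime: explicit argument, then env var, then 'node'."""
    env = os.environ if environ is None else environ
    configured = (
        runner
        if runner is not None
        else env.get("JAX2ONNX_ONNXRUNTIME_WEB_RUNNER", "node")
    )
    name = configured.strip().lower()
    if name == "node":
        return "node"
    if name in {"browser", "chrome"}:
        # "browser" is treated as an alias for the Chromium runner.
        return "chrome"
    raise ValueError(f"Unsupported onnxruntime-web runner '{name}'.")


print(resolve_web_runner(environ={}))  # node
print(resolve_web_runner(environ={"JAX2ONNX_ONNXRUNTIME_WEB_RUNNER": "Chrome"}))  # chrome
print(resolve_web_runner("node", {"JAX2ONNX_ONNXRUNTIME_WEB_RUNNER": "chrome"}))  # node
```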

Parameters:

Name Type Description Default
onnx_model_path str

Path to a serialized ONNX model, typically produced with to_onnx(..., return_mode="file", export_mode="web").

required
inputs Sequence[Any]

Concrete input arrays (or shape tuples, which will be sampled).

required
input_params Optional[Dict[str, Any]]

Optional named ONNX inputs that correspond to runtime parameters materialized during export.

None
rtol float

Relative tolerance for output comparison.

0.001
atol float

Absolute tolerance for output comparison.

1e-05
inputs_as_nchw Optional[Sequence[int]]

Optional sequence of input indices that should be transposed from NHWC validation arrays to NCHW ONNX feeds.

None
node_command str

Node.js executable used to run the Web/WASM validator.

'node'
runner Optional[str]

Optional runtime selector. Use "node" for Node.js WASM validation or "chrome"/"browser" for Playwright Chromium.

None

Returns:

Type Description
Tuple[bool, str]

(is_match, message) where is_match indicates success and message contains the Web runner result or failure details.

Example
import numpy as np

from jax2onnx import allclose_onnxruntime_web


is_match, message = allclose_onnxruntime_web(
    "model.web.onnx",
    inputs=[np.zeros((2, 10), dtype=np.float32)],
    rtol=1e-5,
    atol=1e-5,
    runner="node",
)

assert is_match, message
Source code in jax2onnx/user_interface.py
def allclose_onnxruntime_web(
    onnx_model_path: str,
    inputs: Sequence[Any],
    input_params: Optional[Dict[str, Any]] = None,
    rtol: float = 1e-3,
    atol: float = 1e-5,
    *,
    inputs_as_nchw: Optional[Sequence[int]] = None,
    node_command: str = "node",
    runner: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Validate that `onnxruntime-web/wasm` matches Python ONNX Runtime CPU output.

    This helper is intended for deployment-readiness checks. It executes the
    serialized model once with Python `onnxruntime` on `CPUExecutionProvider`,
    then executes the same model and feeds through `onnxruntime-web/wasm`.
    The default runner is Node.js; pass `runner="chrome"` or set
    `JAX2ONNX_ONNXRUNTIME_WEB_RUNNER=chrome` to validate in a real browser via
    Playwright/Chromium.

    Args:
        onnx_model_path: Path to a serialized ONNX model, typically produced with
            `to_onnx(..., return_mode="file", export_mode="web")`.
        inputs: Concrete input arrays (or shape tuples, which will be sampled).
        input_params: Optional named ONNX inputs that correspond to runtime
            parameters materialized during export.
        rtol: Relative tolerance for output comparison.
        atol: Absolute tolerance for output comparison.
        inputs_as_nchw: Optional sequence of input indices that should be
            transposed from NHWC validation arrays to NCHW ONNX feeds.
        node_command: Node.js executable used to run the Web/WASM validator.
        runner: Optional runtime selector. Use `"node"` for Node.js WASM
            validation or `"chrome"`/`"browser"` for Playwright Chromium.

    Returns:
        `(is_match, message)` where `is_match` indicates success and `message`
        contains the Web runner result or failure details.

    Example:
        ```python
        import numpy as np

        from jax2onnx import allclose_onnxruntime_web


        is_match, message = allclose_onnxruntime_web(
            "model.web.onnx",
            inputs=[np.zeros((2, 10), dtype=np.float32)],
            rtol=1e-5,
            atol=1e-5,
            runner="node",
        )

        assert is_match, message
        ```
    """

    xs = _validation_inputs_to_arrays(inputs)
    params = dict(input_params or {})

    try:
        ort = cast(Any, importlib.import_module("onnxruntime"))
    except ImportError as exc:
        return False, f"onnxruntime is required for CPU reference execution: {exc}"

    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    sess_options.enable_mem_pattern = False
    sess_options.intra_op_num_threads = 1
    sess_options.inter_op_num_threads = 1

    try:
        session = ort.InferenceSession(
            onnx_model_path,
            sess_options=sess_options,
            providers=["CPUExecutionProvider"],
        )
    except Exception as exc:  # pragma: no cover - defensive aid
        return False, f"Failed to create CPU ONNX Runtime session: {exc}"

    ort_xs = list(xs)
    if inputs_as_nchw:
        nhwc_to_nchw = [0, 3, 1, 2]
        for idx in inputs_as_nchw:
            if 0 <= idx < len(ort_xs):
                val = ort_xs[idx]
                if hasattr(val, "ndim") and val.ndim == 4:
                    ort_xs[idx] = np.transpose(val, nhwc_to_nchw)

    try:
        ort_inputs = _build_ort_inputs(session, ort_xs, params)
        expected_outputs = session.run(None, ort_inputs)
    except Exception as exc:  # pragma: no cover - defensive aid
        return False, f"Failed to run CPU ONNX Runtime reference: {exc}"

    output_names = [output_meta.name for output_meta in session.get_outputs()]
    if len(output_names) != len(expected_outputs):
        return (
            False,
            f"Output metadata mismatch (names={len(output_names)} outputs={len(expected_outputs)})",
        )

    try:
        spec = {
            "modelPath": str(Path(onnx_model_path).resolve()),
            "rtol": float(rtol),
            "atol": float(atol),
            "inputs": [
                _array_to_onnxruntime_web_tensor(name, value)
                for name, value in ort_inputs.items()
            ],
            "outputs": [
                _array_to_onnxruntime_web_tensor(name, value)
                for name, value in zip(output_names, expected_outputs, strict=True)
            ],
        }
    except TypeError as exc:
        return False, str(exc)

    configured_runner = (
        runner
        if runner is not None
        else os.getenv("JAX2ONNX_ONNXRUNTIME_WEB_RUNNER", "node")
    )
    runner_name = configured_runner.strip().lower()
    if runner_name == "node":
        script_name = "validate_onnxruntime_web.mjs"
    elif runner_name in {"browser", "chrome"}:
        script_name = "validate_onnxruntime_web_chrome.mjs"
        runner_name = "chrome"
    else:
        return (
            False,
            "Unsupported onnxruntime-web runner "
            f"'{runner_name}'. Use 'node' or 'chrome'.",
        )

    script_path = Path(__file__).resolve().parent.parent / "scripts" / script_name
    if not script_path.is_file():
        return (
            False,
            f"onnxruntime-web/{runner_name} validation script not found: {script_path}",
        )

    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", encoding="utf-8", delete=False
    ) as tmp:
        spec_path = Path(tmp.name)
        json.dump(spec, tmp, allow_nan=False)

    try:
        completed = subprocess.run(
            [node_command, str(script_path), str(spec_path)],
            cwd=str(script_path.parent.parent),
            text=True,
            capture_output=True,
            check=False,
            env=os.environ.copy(),
        )
    except FileNotFoundError:
        return False, f"Node.js command not found: {node_command}"
    finally:
        try:
            spec_path.unlink()
        except OSError:
            pass

    if completed.returncode != 0:
        stderr = completed.stderr.strip()
        stdout = completed.stdout.strip()
        details = stderr or stdout or f"exit code {completed.returncode}"
        return False, f"onnxruntime-web/{runner_name} validation failed: {details}"

    summary = (
        completed.stdout.strip() or f"onnxruntime-web/{runner_name} outputs match."
    )
    return True, summary

onnx_function(target=None, *, unique=False, namespace=None, name=None, type=None)

onnx_function(
    target: None = ...,
    *,
    unique: bool = ...,
    namespace: Optional[str] = ...,
    name: Optional[str] = ...,
    type: Optional[str] = ...
) -> OnnxFunctionDecorator
onnx_function(
    target: OnnxFunctionTarget,
    *,
    unique: bool = ...,
    namespace: Optional[str] = ...,
    name: Optional[str] = ...,
    type: Optional[str] = ...
) -> OnnxFunctionTarget

Decorator to mark a function or class as an ONNX function.

This decorator is used to indicate that a function or class should be converted to an ONNX function node when included in a model. It allows the function to be traced and exported as a reusable component with its own namespace in the ONNX graph.

Parameters:

Name Type Description Default
target Optional[OnnxFunctionTarget]

The target function or class to decorate. When omitted, the decorator must be called with parentheses.

None
unique bool

If True, reuse a single ONNX Function definition for all call sites that share the same callable type and captured parameters.

False
namespace Optional[str]

Custom domain prefix for the emitted FunctionProto. Defaults to "custom" when omitted.

None
name Optional[str]

Optional human-readable base name for the ONNX function. When set, this overrides the callable's Python name for the function op_type and FunctionProto name; the domain still derives from namespace.

None
type Optional[str]

Alias for name; preferred keyword for setting the function op_type/display name in ONNX. If both name and type are supplied, type takes precedence.

None
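The interaction between namespace, name, and type can be summarized as a precedence rule. The sketch below restates the table in code. `resolve_function_labels` is a hypothetical helper, and how the library turns the prefix into the final ONNX domain string is not covered here.

```python
from typing import Optional, Tuple


def resolve_function_labels(
    python_name: str,
    *,
    namespace: Optional[str] = None,
    name: Optional[str] = None,
    type: Optional[str] = None,  # mirrors the decorator keyword; shadows the builtin
) -> Tuple[str, str]:
    """Return (domain prefix, op_type) following the documented precedence."""
    domain_prefix = namespace if namespace is not None else "custom"
    if type is not None:
        op_type = type  # `type` wins when both overrides are given
    elif name is not None:
        op_type = name
    else:
        op_type = python_name  # fall back to the callable's Python name
    return domain_prefix, op_type


print(resolve_function_labels("MLPBlock"))  # ('custom', 'MLPBlock')
print(resolve_function_labels("MLPBlock", name="Block", type="Mlp"))  # ('custom', 'Mlp')
print(resolve_function_labels("MLPBlock", namespace="my.ns", name="Block"))  # ('my.ns', 'Block')
```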

Raises:

Type Description
ValueError

If the same target is decorated again with a conflicting namespace or display-name override.

Returns:

Type Description
Union[OnnxFunctionTarget, OnnxFunctionDecorator]

The decorated function or class with ONNX function capabilities.

Example
import jax.numpy as jnp
from flax import nnx

from jax2onnx import onnx_function


@onnx_function
def my_custom_op(x, y):
    return jnp.sin(x) * y


@onnx_function
class MLPBlock(nnx.Module):
    def __init__(self, features, rngs):
        self.dense = nnx.Linear(features, features, rngs=rngs)
        self.activation = nnx.relu

    def __call__(self, x):
        return self.activation(self.dense(x))
Source code in jax2onnx/user_interface.py
def onnx_function(
    target: Optional[OnnxFunctionTarget] = None,
    *,
    unique: bool = False,
    namespace: Optional[str] = None,
    name: Optional[str] = None,
    type: Optional[str] = None,  # noqa: A002 - user-facing keyword
) -> Union[OnnxFunctionTarget, OnnxFunctionDecorator]:
    """
    Decorator to mark a function or class as an ONNX function.

    This decorator is used to indicate that a function or class should be converted to
    an ONNX function node when included in a model. It allows the function to be traced
    and exported as a reusable component with its own namespace in the ONNX graph.

    Args:
        target: The target function or class to decorate. When omitted, the decorator
            must be called with parentheses.
        unique: If True, reuse a single ONNX Function definition for all call sites
            that share the same callable type and captured parameters.
        namespace: Custom domain prefix for the emitted FunctionProto. Defaults to
            ``"custom"`` when omitted.
        name: Optional human-readable base name for the ONNX function. When set,
            this overrides the callable's Python name for the function `op_type`
            and FunctionProto name; the domain still derives from ``namespace``.
        type: Alias for ``name``; preferred keyword for setting the function
            `op_type`/display name in ONNX. If both ``name`` and ``type`` are
            supplied, ``type`` takes precedence.

    Raises:
        ValueError: If the same target is decorated again with a conflicting
            namespace or display-name override.

    Returns:
        The decorated function or class with ONNX function capabilities.

    Example:
        ```python
        import jax.numpy as jnp
        from flax import nnx

        from jax2onnx import onnx_function


        @onnx_function
        def my_custom_op(x, y):
            return jnp.sin(x) * y


        @onnx_function
        class MLPBlock(nnx.Module):
            def __init__(self, features, rngs):
                self.dense = nnx.Linear(features, features, rngs=rngs)
                self.activation = nnx.relu

            def __call__(self, x):
                return self.activation(self.dense(x))
        ```
    """

    return cast(
        Union[OnnxFunctionTarget, OnnxFunctionDecorator],
        onnx_function_impl(
            target, unique=unique, namespace=namespace, name=name, type=type
        ),
    )

to_onnx(fn, inputs, input_params=None, model_name='jax_model', opset=23, *, enable_double_precision=False, record_primitive_calls_file=None, return_mode='proto', output_path=None, inputs_as_nchw=None, outputs_as_nchw=None, input_names=None, output_names=None, export_mode='standard')

to_onnx(
    fn: Callable[..., Any],
    inputs: Sequence[UserInputSpec],
    input_params: Optional[Mapping[str, object]] = ...,
    model_name: str = ...,
    opset: int = ...,
    *,
    enable_double_precision: bool = ...,
    record_primitive_calls_file: Optional[str] = ...,
    return_mode: Literal["proto"] = ...,
    output_path: None = ...,
    inputs_as_nchw: Optional[Sequence[int]] = ...,
    outputs_as_nchw: Optional[Sequence[int]] = ...,
    input_names: Optional[Sequence[str]] = ...,
    output_names: Optional[Sequence[str]] = ...,
    export_mode: ExportMode = ...
) -> onnx.ModelProto
to_onnx(
    fn: Callable[..., Any],
    inputs: Sequence[UserInputSpec],
    input_params: Optional[Mapping[str, object]] = ...,
    model_name: str = ...,
    opset: int = ...,
    *,
    enable_double_precision: bool = ...,
    record_primitive_calls_file: Optional[str] = ...,
    return_mode: Literal["ir"],
    output_path: Optional[PathLikeStr] = ...,
    inputs_as_nchw: Optional[Sequence[int]] = ...,
    outputs_as_nchw: Optional[Sequence[int]] = ...,
    input_names: Optional[Sequence[str]] = ...,
    output_names: Optional[Sequence[str]] = ...,
    export_mode: ExportMode = ...
) -> ir.Model
to_onnx(
    fn: Callable[..., Any],
    inputs: Sequence[UserInputSpec],
    input_params: Optional[Mapping[str, object]] = ...,
    model_name: str = ...,
    opset: int = ...,
    *,
    enable_double_precision: bool = ...,
    record_primitive_calls_file: Optional[str] = ...,
    return_mode: Literal["file"],
    output_path: PathLikeStr,
    inputs_as_nchw: Optional[Sequence[int]] = ...,
    outputs_as_nchw: Optional[Sequence[int]] = ...,
    input_names: Optional[Sequence[str]] = ...,
    output_names: Optional[Sequence[str]] = ...,
    export_mode: ExportMode = ...
) -> str

Converts a JAX function or model into an ONNX model.

This function serves as the main entry point for converting JAX/Flax models to ONNX format. It supports dynamic shapes and additional runtime parameters.

Parameters:

Name Type Description Default
fn Callable[..., Any]

The JAX function or Flax module to convert.

required
inputs Sequence[UserInputSpec]

Sequence of input specifications. Each entry may be:

  • a jax.ShapeDtypeStruct (or jax.core.ShapedArray);
  • any array-like object exposing .shape and .dtype (e.g. jax.Array, np.ndarray);
  • a tuple/list of ints/strs describing the desired shape.

Example: [('B', 128), (1, 10)] implies two inputs, the first with a dynamic batch dimension 'B' and fixed size 128.

required
input_params Optional[Mapping[str, object]]

Optional mapping of string keys to runtime parameters that should be exposed as inputs in the ONNX model rather than baked into the export (e.g. "deterministic" flags).

None
model_name str

Name to give the ONNX model. Defaults to "jax_model".

'jax_model'
opset int

ONNX opset version to target. Defaults to 23.

23
enable_double_precision bool

If True, export tensors as tensor(double). Defaults to False (use tensor(float)).

False
record_primitive_calls_file Optional[str]

Optional path to a file. If provided, details of each JAX primitive encountered during conversion will be recorded to this file. This log can be used by developers to manually create new test cases. Defaults to None (disabled).

None
return_mode ReturnMode

Output mode. "proto" (default) returns an ONNX ModelProto, "ir" returns the intermediate onnx_ir.Model, and "file" serialises directly to disk.

'proto'
output_path Optional[PathLikeStr]

Destination path (str or PathLike) required when return_mode is "file". Ignored otherwise.

None
inputs_as_nchw Optional[Sequence[int]]

Optional sequence of input indices (0-based) that should be treated as NCHW layout. If specified for an input, jax2onnx assumes the external input is NCHW and will automatically transpose it to NHWC before feeding it to the JAX graph (which typically expects NHWC for images). This allows the exported ONNX model to accept NCHW inputs while preserving correct graph semantics.

None
outputs_as_nchw Optional[Sequence[int]]

Optional sequence of output indices (0-based) that should be treated as NCHW layout. If specified for an output, jax2onnx assumes the external output should be NCHW and will automatically transpose the NHWC output derived from JAX graph to NCHW before returning it.

None
input_names Optional[Sequence[str]]

Optional sequence of names for positional model inputs. When provided, names are assigned by positional index (0-based) after conversion. These names do not apply to entries supplied through input_params.

None
output_names Optional[Sequence[str]]

Optional sequence of names for model outputs (0-based). Names are assigned after conversion in output order.

None
export_mode ExportMode

Serialization profile. "standard" preserves the existing file behavior, spilling large initializers into .onnx.data sidecars when needed. "web" writes a single self-contained .onnx file for browser/WASM deployment via onnxruntime-web. This only affects return_mode="file"; "proto" and "ir" return values are unchanged.

'standard'

Returns:

Type Description
Union[ModelProto, Model, str]

  • If return_mode="proto" (default): Returns an onnx.ModelProto object.
  • If return_mode="ir": Returns an onnx_ir.Model object (intermediate representation).
  • If return_mode="file": Returns the string path to the saved file.

Raises:

Type Description
ValueError

If return_mode is "file" but output_path is not provided.

ValueError

If return_mode is invalid.

TypeError

If input_params keys are not strings.
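The return_mode and output_path rules can be checked up front, for instance in a wrapper that builds export arguments from a config file. The sketch below restates those rules. `check_export_target` is a hypothetical helper, and the library's own normalization of return_mode may accept more spellings than this strict version.

```python
import os
from pathlib import Path
from typing import Optional, Union

VALID_RETURN_MODES = ("proto", "ir", "file")


def check_export_target(
    return_mode: str,
    output_path: Optional[Union[str, os.PathLike]] = None,
) -> Optional[str]:
    """Validate the mode/path pair; return the file path for 'file' mode."""
    if return_mode not in VALID_RETURN_MODES:
        raise ValueError(f"Unsupported return_mode {return_mode!r}.")
    if return_mode == "file":
        if output_path is None:
            raise ValueError(
                "`output_path` must be provided when return_mode is 'file'."
            )
        return os.fspath(output_path)  # accepts str and PathLike alike
    return None  # output_path is ignored for "proto" and "ir"


print(check_export_target("file", Path("linear_model.onnx")))  # linear_model.onnx
print(check_export_target("proto"))  # None
```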

Example
import jax.numpy as jnp

from jax2onnx import to_onnx


def linear(x, w, b):
    return jnp.dot(x, w) + b


input_specs = [
    ("B", 32),  # x: [Batch, 32]
    (32, 10),  # w: [32, 10]
    (10,),  # b: [10]
]

to_onnx(
    linear,
    inputs=input_specs,
    model_name="linear_model",
    return_mode="file",
    output_path="linear_model.onnx",
)
Source code in jax2onnx/user_interface.py
def to_onnx(
    fn: Callable[..., Any],
    inputs: Sequence[UserInputSpec],
    input_params: Optional[Mapping[str, object]] = None,
    model_name: str = "jax_model",
    opset: int = 23,
    *,  # All arguments after this must be keyword-only
    enable_double_precision: bool = False,
    record_primitive_calls_file: Optional[str] = None,
    return_mode: ReturnMode = "proto",
    output_path: Optional[PathLikeStr] = None,
    inputs_as_nchw: Optional[Sequence[int]] = None,
    outputs_as_nchw: Optional[Sequence[int]] = None,
    input_names: Optional[Sequence[str]] = None,
    output_names: Optional[Sequence[str]] = None,
    export_mode: ExportMode = "standard",
) -> Union[onnx.ModelProto, ir.Model, str]:
    """
    Converts a JAX function or model into an ONNX model.

    This function serves as the main entry point for converting JAX/Flax models to ONNX format.
    It supports dynamic shapes and additional runtime parameters.

    Args:
        fn: The JAX function or Flax module to convert.
        inputs: Sequence of input specifications. Each entry may be:
            * a `jax.ShapeDtypeStruct` (or `jax.core.ShapedArray`);
            * any array-like object exposing `.shape` and `.dtype`
              (e.g. `jax.Array`, `np.ndarray`);
            * a tuple/list of ints/strs describing the desired shape.
              Example: `[('B', 128), (1, 10)]` implies two inputs, the first with a dynamic batch dimension 'B' and fixed size 128.
        input_params: Optional mapping of string keys to runtime parameters that
            should be exposed as inputs in the ONNX model rather than baked into
            the export (e.g. `"deterministic"` flags).
        model_name: Name to give the ONNX model. Defaults to "jax_model".
        opset: ONNX opset version to target. Defaults to 23.
        enable_double_precision: If True, export tensors as tensor(double). Defaults to False (use tensor(float)).
        record_primitive_calls_file: Optional path to a file. If provided,
            details of each JAX primitive encountered during conversion will be
            recorded to this file. This log can be used by developers to manually
            create new test cases. Defaults to None (disabled).
        return_mode: Output mode. `"proto"` (default) returns an ONNX ModelProto,
            `"ir"` returns the intermediate onnx_ir.Model, and `"file"`
            serialises directly to disk.
        output_path: Destination path (str or PathLike) required when `return_mode` is
            `"file"`. Ignored otherwise.
        inputs_as_nchw: Optional sequence of input indices (0-based) that should be treated as NCHW layout.
            If specified for an input, jax2onnx assumes the external input is NCHW and will automatically
            transpose it to NHWC before feeding it to the JAX graph (which typically expects NHWC for images).
            This allows the exported ONNX model to accept NCHW inputs while preserving correct graph semantics.
        outputs_as_nchw: Optional sequence of output indices (0-based) that should be treated as NCHW layout.
            If specified for an output, jax2onnx assumes the external output should be NCHW and will automatically
            transpose the NHWC output derived from JAX graph to NCHW before returning it.
        input_names: Optional sequence of names for positional model inputs.
            When provided, names are assigned by positional index (0-based) after conversion.
            These names do not apply to entries supplied through `input_params`.
        output_names: Optional sequence of names for model outputs (0-based).
            Names are assigned after conversion in output order.
        export_mode: Serialization profile. `"standard"` preserves the existing
            file behavior, spilling large initializers into `.onnx.data` sidecars
            when needed. `"web"` writes a single self-contained `.onnx` file for
            browser/WASM deployment via `onnxruntime-web`. This only affects
            `return_mode="file"`; `"proto"` and `"ir"` return values are unchanged.

    Returns:
        * If `return_mode="proto"` (default): Returns an `onnx.ModelProto` object.
        * If `return_mode="ir"`: Returns an `onnx_ir.Model` object (intermediate representation).
        * If `return_mode="file"`: Returns the string path to the saved file.

    Raises:
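        ValueError: If `return_mode` is `"file"` but `output_path` is not provided.
        ValueError: If `return_mode` is invalid.
        ValueError: If the length of `input_names` does not match the number of
            positional inputs.
        ValueError: If `input_names` or `output_names` collide with keys of
            `input_params`.
        TypeError: If `input_params` keys are not strings.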

    Example:
        ```python
        import jax.numpy as jnp

        from jax2onnx import to_onnx


        def linear(x, w, b):
            return jnp.dot(x, w) + b


        input_specs = [
            ("B", 32),  # x: [Batch, 32]
            (32, 10),  # w: [32, 10]
            (10,),  # b: [10]
        ]

        to_onnx(
            linear,
            inputs=input_specs,
            model_name="linear_model",
            return_mode="file",
            output_path="linear_model.onnx",
        )
        ```
    """

    logging.info(
        "Converting JAX function to ONNX model with parameters: "
        f"model_name={model_name}, opset={opset}, input_shapes={inputs}, "
        f"input_params={input_params}, "
        f"enable_double_precision={enable_double_precision}, "
        f"record_primitive_calls_file={record_primitive_calls_file}, "
        f"return_mode={return_mode}, output_path={output_path}, "
        f"inputs_as_nchw={inputs_as_nchw}, outputs_as_nchw={outputs_as_nchw}, "
        f"input_names={input_names}, output_names={output_names}, "
        f"export_mode={export_mode}"
    )

    # Normalize the return and export modes, then validate the file destination.
    normalized_mode = _normalize_return_mode(return_mode)
    normalized_export_mode = _normalize_export_mode(export_mode)

    file_path: Optional[str] = None
    if normalized_mode == "file":
        if output_path is None:
            raise ValueError(
                "`output_path` must be provided when return_mode is 'file'."
            )
        file_path = os.fspath(output_path)

    normalized_inputs: List[InputSpec] = []
    if inputs:
        normalized_inputs = _normalize_input_specs(inputs)

    normalized_input_names = _normalize_io_names(input_names, kind="input_names")
    normalized_output_names = _normalize_io_names(output_names, kind="output_names")

    if normalized_input_names is not None and len(normalized_input_names) != len(
        normalized_inputs
    ):
        raise ValueError(
            f"input_names length ({len(normalized_input_names)}) must match "
            f"the number of positional inputs ({len(normalized_inputs)})."
        )

    param_map: Dict[str, object] = {}
    if input_params:
        for key, value in input_params.items():
            if not isinstance(key, str):
                raise TypeError(
                    "input_params must use string keys; "
                    f"received key of type {type(key)}."
                )
            param_map[key] = value

    if normalized_input_names is not None:
        collisions = sorted(set(normalized_input_names).intersection(param_map.keys()))
        if collisions:
            names = ", ".join(collisions)
            raise ValueError(
                f"input_names collide with names reserved by input_params: {names}."
            )

    if normalized_output_names is not None:
        collisions = sorted(set(normalized_output_names).intersection(param_map.keys()))
        if collisions:
            names = ", ".join(collisions)
            raise ValueError(
                f"output_names collide with names reserved by input_params: {names}."
            )

    with _temporary_x64(enable_double_precision):
        result = to_onnx_impl(
            fn=fn,
            inputs=normalized_inputs,
            input_params=param_map,
            model_name=model_name,
            opset=opset,
            enable_double_precision=enable_double_precision,
            record_primitive_calls_file=record_primitive_calls_file,
            protective_clone=(normalized_mode == "ir"),
            inputs_as_nchw=inputs_as_nchw,
            outputs_as_nchw=outputs_as_nchw,
            input_names=normalized_input_names,
            output_names=normalized_output_names,
        )

        postprocess_ir_model(
            result,
            promote_to_double=enable_double_precision,
        )

    def _save_model_proto(
        model_proto: onnx.ModelProto,
        dest: str,
        *,
        mode: ExportMode,
        external_threshold: int = 1_048_576,  # 1 MiB default before spilling to .data
    ) -> str:
        dest_dir = os.path.dirname(dest)
        if dest_dir:
            os.makedirs(dest_dir, exist_ok=True)
        data_location = os.path.basename(dest) + ".data"
        data_path = os.path.join(dest_dir or ".", data_location)

        if mode == "web":
            onnx.save_model(model_proto, dest, save_as_external_data=False)
            # A previous standard export to the same path may have left a sidecar.
            try:
                if os.path.exists(data_path):
                    os.remove(data_path)
            except OSError:
                pass
            return dest

        onnx.save_model(
            model_proto,
            dest,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=data_location,
            size_threshold=external_threshold,
            convert_attribute=False,
        )
        # Only keep the .data sidecar if the export actually referenced external data.
        if not any(init.external_data for init in model_proto.graph.initializer):
            # No external payloads; remove an empty sidecar if one was produced.
            try:
                if os.path.exists(data_path) and os.path.getsize(data_path) == 0:
                    os.remove(data_path)
            except OSError:
                pass
        return dest

    _materialize_input_params_on_ir(result, param_map)
    _apply_custom_io_names_on_ir(
        result,
        input_names=normalized_input_names,
        output_names=normalized_output_names,
        positional_input_count=len(normalized_inputs),
    )
    if normalized_mode == "ir":
        return result

    model_proto = ir.to_proto(result)
    if normalized_mode == "file":
        assert file_path is not None
        return _save_model_proto(model_proto, file_path, mode=normalized_export_mode)
    return model_proto
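
The name checks that `to_onnx` performs before tracing can be reproduced in isolation, which may help when validating export arguments ahead of a long conversion. The following is a minimal sketch using only the standard library. The helper name `check_io_names` and its signature are hypothetical and are not part of jax2onnx; the rules it applies are the ones visible in the source above (string keys for `input_params`, one input name per positional input, and no overlap between custom names and `input_params` keys).

```python
from typing import Mapping, Optional, Sequence


def check_io_names(
    input_names: Optional[Sequence[str]],
    output_names: Optional[Sequence[str]],
    positional_input_count: int,
    input_params: Mapping[object, object],
) -> None:
    """Hypothetical helper mirroring the name validation shown above."""
    # input_params keys become model input names, so they must be strings.
    for key in input_params:
        if not isinstance(key, str):
            raise TypeError(
                "input_params must use string keys; "
                f"received key of type {type(key)}."
            )

    # input_names only cover positional inputs, one name per input.
    if input_names is not None and len(input_names) != positional_input_count:
        raise ValueError(
            f"input_names length ({len(input_names)}) must match "
            f"the number of positional inputs ({positional_input_count})."
        )

    # Neither custom inputs nor custom outputs may reuse an input_params key.
    for kind, names in (("input_names", input_names), ("output_names", output_names)):
        if names is None:
            continue
        collisions = sorted(set(names).intersection(input_params))
        if collisions:
            raise ValueError(
                f"{kind} collide with names reserved by input_params: "
                f"{', '.join(collisions)}."
            )


# Passes silently: two positional inputs, two names, no overlap with input_params.
check_io_names(["x", "w"], ["y"], 2, {"deterministic": True})
```

Passing `None` for either name list skips the corresponding checks, matching the optional nature of `input_names` and `output_names`.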