
API Reference

jax2onnx.user_interface

User-facing API for converting JAX functions and models to ONNX.

This module provides the primary interface for exporting JAX/Flax models to the ONNX format. It supports dynamic shapes, runtime parameters, and numerical validation against ONNX Runtime.

Key Functions:

  • to_onnx: Convert a JAX function or Flax module to an ONNX model.
  • onnx_function: Decorator to mark a function or class as an ONNX function node.
  • allclose: Validate numerical equivalence between JAX and ONNX Runtime outputs.
Example

from jax2onnx import to_onnx
import jax.numpy as jnp

def my_model(x):
    return jnp.sin(x)

to_onnx(my_model, inputs=[('B', 10)], return_mode="file", output_path="model.onnx")
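The exported file is a standard ONNX model, so it can be executed directly with ONNX Runtime. A minimal sketch, assuming onnxruntime is installed and that model.onnx was produced by the call above (the input name is looked up rather than hard-coded):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx")
input_name = sess.get_inputs()[0].name

# A concrete batch of 4 replaces the dynamic 'B' dimension.
x = np.random.rand(4, 10).astype(np.float32)
(y,) = sess.run(None, {input_name: x})
print(y.shape)  # (4, 10)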

allclose(fn, onnx_model_path, inputs, input_params=None, rtol=0.001, atol=1e-05, *, enable_double_precision=False, inputs_as_nchw=None, outputs_as_nchw=None)

Checks if JAX and ONNX Runtime outputs remain numerically close.

Parameters:

  • fn (Callable, required): JAX callable to compare against the exported ONNX model.
  • onnx_model_path (str, required): Path to a serialized model that ONNX Runtime can execute.
  • inputs (List[Any], required): Concrete input arrays (or shape tuples, which will be sampled).
  • input_params (Optional[Dict[str, Any]], default None): Optional keyword arguments applied to both call sites.
  • rtol (float, default 1e-3): Relative tolerance for floating-point comparisons.
  • atol (float, default 1e-5): Absolute tolerance for floating-point comparisons.
  • enable_double_precision (bool, default False): Temporarily enable jax_enable_x64 while running the comparison.
  • inputs_as_nchw (Optional[Sequence[int]], default None): Input indices (0-based) to treat as NCHW layout, mirroring the to_onnx option of the same name.
  • outputs_as_nchw (Optional[Sequence[int]], default None): Output indices (0-based) to treat as NCHW layout, mirroring the to_onnx option of the same name.

Returns:

  • Tuple[bool, str]: (is_match, message), where is_match indicates success and message provides context when a mismatch occurs.

Example

import jax.numpy as jnp
from jax2onnx import to_onnx, allclose

# 1. Define and export
def my_func(x):
    return jnp.sin(x)

model_path = to_onnx(
    my_func,
    inputs=[('B', 10)],
    return_mode="file",
    output_path="my_model.onnx"
)

# 2. Validate
# Provide concrete shapes for validation (replacing dynamic dim 'B')
validation_inputs = [(5, 10)]
is_match, msg = allclose(my_func, model_path, inputs=validation_inputs, atol=1e-5)

assert is_match, f"Validation failed: {msg}"
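Because inputs also accepts concrete arrays, the comparison can run on specific data instead of randomly sampled tensors. A minimal sketch continuing the example above (the linspace data is illustrative, not part of the API):

import numpy as np

# Concrete batch of 5 samples for the dynamic 'B' dimension; arrays are
# passed through unchanged rather than being sampled from shape tuples.
x = np.linspace(-1.0, 1.0, 50, dtype=np.float32).reshape(5, 10)
is_match, msg = allclose(my_func, model_path, inputs=[x], rtol=1e-3, atol=1e-5)
assert is_match, msg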

Source code in jax2onnx/user_interface.py
def allclose(
    fn: Callable,
    onnx_model_path: str,
    inputs: List[Any],
    input_params: Optional[Dict[str, Any]] = None,
    rtol: float = 1e-3,
    atol: float = 1e-5,
    *,
    enable_double_precision: bool = False,
    inputs_as_nchw: Optional[Sequence[int]] = None,
    outputs_as_nchw: Optional[Sequence[int]] = None,
) -> Tuple[bool, str]:
    """
    Checks if JAX and ONNX Runtime outputs remain numerically close.

    Args:
        fn: JAX callable to compare against the exported ONNX model.
        onnx_model_path: Path to a serialized model that ORT can execute.
        inputs: Concrete input arrays (or shape tuples, which will be sampled).
        input_params: Optional keyword arguments applied to both call sites.
        rtol: Relative tolerance for floating-point comparisons.
        atol: Absolute tolerance for floating-point comparisons.
        enable_double_precision: Temporarily enable `jax_enable_x64` while running the
            comparison. Defaults to False.

    Returns:
        `(is_match, message)` where `is_match` indicates success and `message`
        provides context when a mismatch occurs.

    Example:
        >>> import jax.numpy as jnp
        >>> from jax2onnx import to_onnx, allclose
        >>>
        >>> # 1. Define and Export
        >>> def my_func(x):
        ...     return jnp.sin(x)
        >>>
        >>> model_path = to_onnx(
        ...     my_func,
        ...     inputs=[('B', 10)],
        ...     return_mode="file",
        ...     output_path="my_model.onnx"
        ... )
        >>>
        >>> # 2. Validate
        >>> # Provide concrete shapes for validation (replacing dynamic dim 'B')
        >>> validation_inputs = [(5, 10)]
        >>> is_match, msg = allclose(my_func, model_path, inputs=validation_inputs, atol=1e-5)
        >>>
        >>> assert is_match, f"Validation failed: {msg}"
    """

    logging.info(
        "Comparing JAX and ONNX outputs (path=%s, rtol=%s, atol=%s)",
        onnx_model_path,
        rtol,
        atol,
    )

    def _is_shape(x: Any) -> bool:
        return isinstance(x, (tuple, list)) and all(
            isinstance(dim, (int, str)) for dim in x
        )

    xs: List[Any]
    if all(_is_shape(x) for x in inputs):
        xs = [
            np.random.rand(*[d if isinstance(d, int) else 2 for d in shape]).astype(
                np.float32
            )
            for shape in inputs
        ]
    else:
        xs = list(inputs)

    params = dict(input_params or {})
    with _temporary_x64(enable_double_precision):
        with jax.default_matmul_precision("float32"):
            return _run_allclose(
                fn,
                onnx_model_path,
                xs,
                params,
                rtol=rtol,
                atol=atol,
                inputs_as_nchw=inputs_as_nchw,
                outputs_as_nchw=outputs_as_nchw,
            )

onnx_function(target=None, *, unique=False, namespace=None, name=None, type=None)

Decorator to mark a function or class as an ONNX function.

This decorator is used to indicate that a function or class should be converted to an ONNX function node when included in a model. It allows the function to be traced and exported as a reusable component with its own namespace in the ONNX graph.

Parameters:

  • target (Optional[Union[Callable, type]], default None): The target function or class to decorate. When omitted, the decorator must be called with parentheses.
  • unique (bool, default False): If True, reuse a single ONNX Function definition for all call sites that share the same callable type and captured parameters.
  • namespace (Optional[str], default None): Custom domain prefix for the emitted FunctionProto. Defaults to "custom" when omitted.
  • name (Optional[str], default None): Optional human-readable base name for the ONNX function. When set, this overrides the callable's Python name for the function op_type and FunctionProto name; the domain still derives from namespace.
  • type (Optional[str], default None): Alias for name; preferred keyword for setting the function op_type/display name in ONNX.

Returns:

  • Union[Callable, type]: The decorated function or class with ONNX function capabilities.

Example

from jax2onnx import onnx_function
import jax.numpy as jnp

@onnx_function
def my_custom_op(x, y):
    return jnp.sin(x) * y

Also works with Flax modules:

from flax import nnx

@onnx_function
class MLPBlock(nnx.Module):
    def __init__(self, features, rngs):
        self.dense = nnx.Linear(features, rngs=rngs)
        self.activation = nnx.relu

    def __call__(self, x):
        return self.activation(self.dense(x))
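The decorator also accepts keyword arguments when called with parentheses; a hedged sketch combining them (the namespace "mylib" and name "MLP" are illustrative values, not defaults):

from flax import nnx
from jax2onnx import onnx_function

# Reuse one ONNX Function definition for all MLPBlock call sites that share
# the same captured parameters, emit it under the "mylib" domain, and use
# "MLP" as the op_type instead of the Python class name.
@onnx_function(unique=True, namespace="mylib", name="MLP")
class MLPBlock(nnx.Module):
    def __init__(self, features, rngs):
        self.dense = nnx.Linear(features, features, rngs=rngs)
        self.activation = nnx.relu

    def __call__(self, x):
        return self.activation(self.dense(x))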
Source code in jax2onnx/user_interface.py
def onnx_function(
    target: Optional[Union[Callable, type]] = None,
    *,
    unique: bool = False,
    namespace: Optional[str] = None,
    name: Optional[str] = None,
    type: Optional[str] = None,  # noqa: A002 - user-facing keyword
) -> Union[Callable, type]:
    """
    Decorator to mark a function or class as an ONNX function.

    This decorator is used to indicate that a function or class should be converted to
    an ONNX function node when included in a model. It allows the function to be traced
    and exported as a reusable component with its own namespace in the ONNX graph.

    Args:
        target: The target function or class to decorate. When omitted, the decorator
            must be called with parentheses.
        unique: If True, reuse a single ONNX Function definition for all call sites
            that share the same callable type and captured parameters.
        namespace: Custom domain prefix for the emitted FunctionProto. Defaults to
            ``"custom"`` when omitted.
        name: Optional human-readable base name for the ONNX function. When set,
            this overrides the callable's Python name for the function `op_type`
            and FunctionProto name; the domain still derives from ``namespace``.
        type: Alias for ``name``; preferred keyword for setting the function
            `op_type`/display name in ONNX.

    Returns:
        The decorated function or class with ONNX function capabilities.

    Example:
        >>> from jax2onnx import onnx_function
        >>> import jax.numpy as jnp
        >>>
        >>> @onnx_function
        >>> def my_custom_op(x, y):
        ...     return jnp.sin(x) * y
        >>>
        >>> # Also works with Flax modules:
        >>> from flax import nnx
        >>>
        >>> @onnx_function
        >>> class MLPBlock(nnx.Module):
        >>>     def __init__(self, features, rngs):
        >>>         self.dense = nnx.Linear(features, rngs=rngs)
        >>>         self.activation = nnx.relu
        >>>
        >>>     def __call__(self, x):
        >>>         return self.activation(self.dense(x))
    """

    # Prefer the explicit `type` override; fall back to `name` for BC.
    display = type if isinstance(type, str) and type else name
    return onnx_function_impl(
        target, unique=unique, namespace=namespace, name=display, type=display
    )

to_onnx(fn, inputs, input_params=None, model_name='jax_model', opset=23, *, enable_double_precision=False, record_primitive_calls_file=None, return_mode='proto', output_path=None, inputs_as_nchw=None, outputs_as_nchw=None)

to_onnx(
    fn: Callable,
    inputs: Sequence[UserInputSpec],
    input_params: Optional[Mapping[str, object]] = ...,
    model_name: str = ...,
    opset: int = ...,
    *,
    enable_double_precision: bool = ...,
    record_primitive_calls_file: Optional[str] = ...,
    return_mode: Literal["proto"] = ...,
    output_path: None = ...,
    inputs_as_nchw: Optional[Sequence[int]] = ...,
    outputs_as_nchw: Optional[Sequence[int]] = ...
) -> onnx.ModelProto
to_onnx(
    fn: Callable,
    inputs: Sequence[UserInputSpec],
    input_params: Optional[Mapping[str, object]] = ...,
    model_name: str = ...,
    opset: int = ...,
    *,
    enable_double_precision: bool = ...,
    record_primitive_calls_file: Optional[str] = ...,
    return_mode: Literal["ir"],
    output_path: Optional[PathLikeStr] = ...,
    inputs_as_nchw: Optional[Sequence[int]] = ...,
    outputs_as_nchw: Optional[Sequence[int]] = ...
) -> ir.Model
to_onnx(
    fn: Callable,
    inputs: Sequence[UserInputSpec],
    input_params: Optional[Mapping[str, object]] = ...,
    model_name: str = ...,
    opset: int = ...,
    *,
    enable_double_precision: bool = ...,
    record_primitive_calls_file: Optional[str] = ...,
    return_mode: Literal["file"],
    output_path: PathLikeStr,
    inputs_as_nchw: Optional[Sequence[int]] = ...,
    outputs_as_nchw: Optional[Sequence[int]] = ...
) -> str

Converts a JAX function or model into an ONNX model.

This function serves as the main entry point for converting JAX/Flax models to ONNX format. It supports dynamic shapes and additional runtime parameters.

Parameters:

  • fn (Callable, required): The JAX function or Flax module to convert.
  • inputs (Sequence[UserInputSpec], required): Sequence of input specifications. Each entry may be:
      • a jax.ShapeDtypeStruct (or jax.core.ShapedArray);
      • any array-like object exposing .shape and .dtype (e.g. jax.Array, np.ndarray);
      • a tuple/list of ints/strs describing the desired shape. Example: [('B', 128), (1, 10)] implies two inputs, the first with a dynamic batch dimension 'B' and fixed size 128.
  • input_params (Optional[Mapping[str, object]], default None): Optional mapping of string keys to runtime parameters that should be exposed as inputs in the ONNX model rather than baked into the export (e.g. "deterministic" flags).
  • model_name (str, default 'jax_model'): Name to give the ONNX model.
  • opset (int, default 23): ONNX opset version to target.
  • enable_double_precision (bool, default False): If True, export tensors as tensor(double); otherwise tensor(float) is used.
  • record_primitive_calls_file (Optional[str], default None): Optional path to a file. If provided, details of each JAX primitive encountered during conversion are recorded to this file; the log can be used by developers to create new test cases manually.
  • return_mode (ReturnMode, default 'proto'): Output mode. "proto" (default) returns an ONNX ModelProto, "ir" returns the intermediate onnx_ir.Model, and "file" serialises directly to disk.
  • output_path (Optional[PathLikeStr], default None): Destination path (str or PathLike), required when return_mode is "file". Ignored otherwise.
  • inputs_as_nchw (Optional[Sequence[int]], default None): Input indices (0-based) that should be treated as NCHW layout. If specified for an input, jax2onnx assumes the external input is NCHW and automatically transposes it to NHWC before feeding it to the JAX graph (which typically expects NHWC for images). This allows the exported ONNX model to accept NCHW inputs while preserving correct graph semantics.
  • outputs_as_nchw (Optional[Sequence[int]], default None): Output indices (0-based) that should be treated as NCHW layout. If specified for an output, jax2onnx assumes the external output should be NCHW and automatically transposes the NHWC output of the JAX graph to NCHW before returning it.

Returns:

Union[ModelProto, Model, str]:

  • If return_mode="proto" (default): Returns an onnx.ModelProto object.
  • If return_mode="ir": Returns an onnx_ir.Model object (intermediate representation).
  • If return_mode="file": Returns the string path to the saved file.

Raises:

  • ValueError: If return_mode is "file" but output_path is not provided.
  • ValueError: If return_mode is invalid.
  • TypeError: If input_params keys are not strings.

Example

import jax
import jax.numpy as jnp
from jax2onnx import to_onnx

# Define a simple JAX function
def linear(x, w, b):
    return jnp.dot(x, w) + b

# Define input shapes: 'B' indicates a dynamic batch dimension
input_specs = [
    ('B', 32),  # x: [Batch, 32]
    (32, 10),   # w: [32, 10]
    (10,)       # b: [10]
]

# Convert and save to file directly (recommended)
to_onnx(
    linear,
    inputs=input_specs,
    model_name="linear_model",
    return_mode="file",
    output_path="linear_model.onnx"
)
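Input specifications are not limited to plain shape tuples, and the default return mode keeps the model in memory. A minimal sketch using jax.ShapeDtypeStruct entries and return_mode="proto" (the affine function and dtypes are illustrative, not part of the API):

import jax
import jax.numpy as jnp
import numpy as np
import onnx
from jax2onnx import to_onnx

def affine(x, w):
    return jnp.tanh(x @ w)

# Fully specified input signatures; array-like objects with .shape/.dtype
# would work here as well.
specs = [
    jax.ShapeDtypeStruct((8, 16), np.float32),  # x
    jax.ShapeDtypeStruct((16, 4), np.float32),  # w
]

# Default return_mode="proto" yields an onnx.ModelProto in memory, which can
# be inspected or serialised later with onnx.save_model.
model_proto = to_onnx(affine, inputs=specs, model_name="affine_model")
assert isinstance(model_proto, onnx.ModelProto)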

Source code in jax2onnx/user_interface.py
def to_onnx(
    fn: Callable,
    inputs: Sequence[UserInputSpec],
    input_params: Optional[Mapping[str, object]] = None,
    model_name: str = "jax_model",
    opset: int = 23,
    *,  # All arguments after this must be keyword-only
    enable_double_precision: bool = False,
    record_primitive_calls_file: Optional[str] = None,
    return_mode: ReturnMode = "proto",
    output_path: Optional[PathLikeStr] = None,
    inputs_as_nchw: Optional[Sequence[int]] = None,
    outputs_as_nchw: Optional[Sequence[int]] = None,
) -> Union[onnx.ModelProto, ir.Model, str]:
    """
    Converts a JAX function or model into an ONNX model.

    This function serves as the main entry point for converting JAX/Flax models to ONNX format.
    It supports dynamic shapes and additional runtime parameters.

    Args:
        fn: The JAX function or Flax module to convert.
        inputs: Sequence of input specifications. Each entry may be:
            * a `jax.ShapeDtypeStruct` (or `jax.core.ShapedArray`);
            * any array-like object exposing `.shape` and `.dtype`
              (e.g. `jax.Array`, `np.ndarray`);
            * a tuple/list of ints/strs describing the desired shape.
              Example: `[('B', 128), (1, 10)]` implies two inputs, the first with a dynamic batch dimension 'B' and fixed size 128.
        input_params: Optional mapping of string keys to runtime parameters that
            should be exposed as inputs in the ONNX model rather than baked into
            the export (e.g. `"deterministic"` flags).
        model_name: Name to give the ONNX model. Defaults to "jax_model".
        opset: ONNX opset version to target. Defaults to 23.
        enable_double_precision: If True, export tensors as tensor(double). Defaults to False (use tensor(float)).
        record_primitive_calls_file: Optional path to a file. If provided,
            details of each JAX primitive encountered during conversion will be
            recorded to this file. This log can be used by developers to manually
            create new test cases. Defaults to None (disabled).
        return_mode: Output mode. `"proto"` (default) returns an ONNX ModelProto,
            `"ir"` returns the intermediate onnx_ir.Model, and `"file"`
            serialises directly to disk.
        output_path: Destination path (str or PathLike) required when `return_mode` is
            `"file"`. Ignored otherwise.
        inputs_as_nchw: Optional sequence of input indices (0-based) that should be treated as NCHW layout.
            If specified for an input, jax2onnx assumes the external input is NCHW and will automatically
            transpose it to NHWC before feeding it to the JAX graph (which typically expects NHWC for images).
            This allows the exported ONNX model to accept NCHW inputs while preserving correct graph semantics.
        outputs_as_nchw: Optional sequence of output indices (0-based) that should be treated as NCHW layout.
            If specified for an output, jax2onnx assumes the external output should be NCHW and will automatically
            transpose the NHWC output derived from JAX graph to NCHW before returning it.

    Returns:
        * If `return_mode="proto"` (default): Returns an `onnx.ModelProto` object.
        * If `return_mode="ir"`: Returns an `onnx_ir.Model` object (intermediate representation).
        * If `return_mode="file"`: Returns the string path to the saved file.

    Raises:
        ValueError: If `return_mode` is "file" but `output_path` is not provided.
        ValueError: If `return_mode` is invalid.
        TypeError: If `input_params` keys are not strings.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> from jax2onnx import to_onnx
        >>>
        >>> # Define a simple JAX function
        >>> def linear(x, w, b):
        ...     return jnp.dot(x, w) + b
        >>>
        >>> # Define input shapes: 'B' indicates a dynamic batch dimension
        >>> input_specs = [
        ...     ('B', 32),  # x: [Batch, 32]
        ...     (32, 10),   # w: [32, 10]
        ...     (10,)       # b: [10]
        ... ]
        >>>
        >>> # Convert and save to file directly (Recommended)
        >>> to_onnx(
        ...     linear,
        ...     inputs=input_specs,
        ...     model_name="linear_model",
        ...     return_mode="file",
        ...     output_path="linear_model.onnx"
        ... )
    """

    logging.info(
        f"Converting JAX function to ONNX model with parameters: "
        f"model_name={model_name}, opset={opset}, input_shapes={inputs}, "
        f"input_params={input_params}, "
        f"enable_double_precision={enable_double_precision}, "
        f"record_primitive_calls_file={record_primitive_calls_file}, "
        f"return_mode={return_mode}, output_path={output_path}, "
        f"inputs_as_nchw={inputs_as_nchw}, outputs_as_nchw={outputs_as_nchw}"
    )

    # Determine the nature of the 'inputs' argument to prepare for to_onnx_impl
    normalized_mode = _normalize_return_mode(return_mode)

    file_path: Optional[str] = None
    if normalized_mode == "file":
        if output_path is None:
            raise ValueError(
                "`output_path` must be provided when return_mode is 'file'."
            )
        path_value = os.fspath(output_path)
        if isinstance(path_value, bytes):
            path_value = path_value.decode()
        file_path = cast(str, path_value)

    normalized_inputs: List[InputSpec] = []
    if inputs:
        normalized_inputs = _normalize_input_specs(inputs)

    param_map: Dict[str, object] = {}
    if input_params:
        for key, value in input_params.items():
            if not isinstance(key, str):
                raise TypeError(
                    "input_params must use string keys; "
                    f"received key of type {type(key)}."
                )
            param_map[key] = value

    with _temporary_x64(enable_double_precision):
        result = to_onnx_impl(
            fn=fn,
            inputs=normalized_inputs,
            input_params=param_map,
            model_name=model_name,
            opset=opset,
            enable_double_precision=enable_double_precision,
            record_primitive_calls_file=record_primitive_calls_file,
            protective_clone=(normalized_mode == "ir"),
            inputs_as_nchw=inputs_as_nchw,
            outputs_as_nchw=outputs_as_nchw,
        )

        postprocess_ir_model(
            result,
            promote_to_double=enable_double_precision,
        )

    def _save_model_proto(
        model_proto: onnx.ModelProto,
        dest: str,
        *,
        external_threshold: int = 1_048_576,  # 1 MB default before spilling to .data
    ) -> str:
        dest_dir = os.path.dirname(dest)
        if dest_dir:
            os.makedirs(dest_dir, exist_ok=True)
        data_location = os.path.basename(dest) + ".data"
        data_path = os.path.join(dest_dir or ".", data_location)
        onnx.save_model(
            model_proto,
            dest,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=data_location,
            size_threshold=external_threshold,
            convert_attribute=False,
        )
        # Only keep the .data sidecar if the export actually referenced external data.
        if not any(init.external_data for init in model_proto.graph.initializer):
            # No external payloads; remove an empty sidecar if one was produced.
            try:
                if os.path.exists(data_path) and os.path.getsize(data_path) == 0:
                    os.remove(data_path)
            except OSError:
                pass
        return dest

    _materialize_input_params_on_ir(result, param_map)
    if normalized_mode == "ir":
        return result

    model_proto = ir.to_proto(result)
    if normalized_mode == "file":
        assert file_path is not None
        return _save_model_proto(model_proto, file_path)
    return model_proto