API Reference
jax2onnx exposes four public entry points:
- to_onnx(...) for export
- @onnx_function for reusable subgraphs
- allclose(...) for JAX-vs-ONNX validation
- allclose_onnxruntime_web(...) for ONNX Runtime Web/WASM parity checks
Common Export Flow
from jax2onnx import to_onnx

model_path = to_onnx(
    fn,
    inputs=[("B", 128)],
    return_mode="file",
    output_path="model.onnx",
)
Use string dimensions such as "B" when you want symbolic dynamic axes.
For direct file export, set return_mode="file" and provide output_path.
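To make the role of a symbolic dimension concrete, the sketch below turns a shape spec such as ("B", 128) into a concrete shape for a single validation run. The helper `concretize_shape` and its `bindings` argument are hypothetical names used only for illustration; they are not part of jax2onnx.

```python
from typing import Mapping, Sequence, Tuple, Union

Dim = Union[int, str]


def concretize_shape(spec: Sequence[Dim], bindings: Mapping[str, int]) -> Tuple[int, ...]:
    """Replace every symbolic (string) dimension with a concrete size.

    Hypothetical helper for illustration; jax2onnx does not expose it.
    """
    shape = []
    for dim in spec:
        if isinstance(dim, str):
            if dim not in bindings:
                raise KeyError(f"no concrete size given for symbolic dimension {dim!r}")
            shape.append(bindings[dim])
        else:
            shape.append(int(dim))
    return tuple(shape)


# The same spec can describe any batch size at run time.
small_batch = concretize_shape(("B", 128), {"B": 4})
large_batch = concretize_shape(("B", 128), {"B": 64})
print(small_batch, large_batch)
```

The point of the sketch is that only the string entry varies between runs; the integer entries stay fixed in the exported model.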
For a full export validation checklist, see Validation & Deployment Readiness.
For browser deployment with onnxruntime-web, use the Web export profile:
model_path = to_onnx(
    fn,
    inputs=[("B", 128)],
    return_mode="file",
    output_path="model.web.onnx",
    export_mode="web",
)
export_mode="web" keeps the ONNX graph semantics unchanged, but serializes a
single self-contained .onnx file instead of spilling large initializers into a
.onnx.data sidecar. That is the easiest artifact shape to serve to
onnxruntime-web/wasm.
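One way to confirm that an export directory really contains a single-file artifact is to check for a sidecar next to the model. The sketch below assumes the sidecar is named after the model file with ".data" appended (for example model.web.onnx.data); that naming rule and the helper names `find_sidecar` and `is_self_contained` are assumptions for illustration, not jax2onnx API. The demo uses empty placeholder files rather than real ONNX models.

```python
import tempfile
from pathlib import Path


def find_sidecar(model_path):
    """Return the assumed '<model file name>.data' sidecar, or None if absent."""
    model = Path(model_path)
    sidecar = model.with_name(model.name + ".data")
    return sidecar if sidecar.exists() else None


def is_self_contained(model_path):
    """True when the model file exists and no sidecar sits next to it."""
    model = Path(model_path)
    return model.is_file() and find_sidecar(model) is None


with tempfile.TemporaryDirectory() as tmp:
    model = Path(tmp) / "model.web.onnx"
    model.write_bytes(b"")  # placeholder bytes, not a real ONNX model
    alone = is_self_contained(model)

    # Simulate a standard export that spilled initializers into a sidecar.
    Path(str(model) + ".data").write_bytes(b"")
    with_sidecar = is_self_contained(model)

print(alone, with_sidecar)
```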
Parameters To Reach For First
- inputs: Positional input specs, either concrete arrays, ShapeDtypeStruct values, or shape tuples like ("B", 128).
- input_params: Runtime flags or keyword-like values that should stay model inputs instead of being baked into the export.
- return_mode: "proto" for an onnx.ModelProto, "ir" for the intermediate onnx_ir.Model, or "file" to serialize directly to disk.
- export_mode: "standard" for normal serialization, or "web" for single-file browser/WASM artifacts.
- enable_double_precision: Temporarily enables x64 export and emits tensor(double) where appropriate.
- inputs_as_nchw / outputs_as_nchw: Adapt the external ONNX interface to NCHW while keeping the traced JAX computation in its original layout.
- input_names / output_names: Apply stable user-facing names after conversion.
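The layout adaptation behind inputs_as_nchw and outputs_as_nchw comes down to an axis permutation on 4-D tensors. The sketch below shows that permutation on shapes only; the constants and the `permute` helper are illustrative names, and the code does not call jax2onnx.

```python
# Axis permutations for 4-D image tensors.
NCHW_TO_NHWC = (0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
NHWC_TO_NCHW = (0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)


def permute(shape, perm):
    """Reorder the entries of a shape tuple according to perm."""
    return tuple(shape[axis] for axis in perm)


external_input = (1, 3, 224, 224)                      # what an NCHW caller feeds
traced_input = permute(external_input, NCHW_TO_NHWC)   # what the JAX code sees
round_trip = permute(traced_input, NHWC_TO_NCHW)       # back to the external layout
print(traced_input, round_trip)
```

The two permutations are inverses of each other, which is why the external interface can change layout without altering the traced computation.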
Browser/WASM Validation
The generated test harness can optionally validate exported models with
onnxruntime-web/wasm in Node.js or Chrome/Chromium:
For the browser runner, add:
npx playwright install chromium
JAX2ONNX_VALIDATE_ONNXRUNTIME_WEB=1 \
JAX2ONNX_ONNXRUNTIME_WEB_RUNNER=chrome \
poetry run pytest -q tests/primitives/test_nn.py
When this flag is enabled, generated tests export with export_mode="web", keep
the existing JAX-vs-Python-ONNX-Runtime CPU check, then compare the same ONNX
model and inputs against onnxruntime-web/wasm.
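When the same run is launched from Python instead of a shell, the two environment variables can be set on a copy of the current environment. This is a minimal sketch: `web_validation_env` is a hypothetical helper, and the actual subprocess call is left commented out because it only makes sense inside the repository checkout.

```python
import os


def web_validation_env(runner=None, base=None):
    """Return a copy of the environment with Web validation switched on.

    Hypothetical helper; only the variable names come from the text above.
    """
    env = dict(os.environ if base is None else base)
    env["JAX2ONNX_VALIDATE_ONNXRUNTIME_WEB"] = "1"
    if runner is not None:
        env["JAX2ONNX_ONNXRUNTIME_WEB_RUNNER"] = runner
    return env


command = ["poetry", "run", "pytest", "-q", "tests/primitives/test_nn.py"]
env = web_validation_env(runner="chrome")

# Inside the repository, the run could then be started with:
# import subprocess
# subprocess.run(command, env=env, check=True)
print(env["JAX2ONNX_VALIDATE_ONNXRUNTIME_WEB"], env["JAX2ONNX_ONNXRUNTIME_WEB_RUNNER"])
```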
For a smaller local smoke run that covers the Quickstart Web model plus representative generated LAX/JAX NumPy examples, run the explicit smoke scripts:
When Web runtime validation is requested through the central repository check
runner, it runs the full pytest suite with export_mode="web" and the selected
runtime runner:
JAX2ONNX_RUN_ONNXRUNTIME_WEB=1 ./scripts/run_all_checks.sh
JAX2ONNX_RUN_ONNXRUNTIME_WEB_CHROME=1 ./scripts/run_all_checks.sh
For browser loading code, validation helpers, CI usage, and troubleshooting, see Browser/WASM Deployment.
Use @onnx_function when repeated callables should become
reusable ONNX functions, and allclose(...) when you want a quick numerical
check against ONNX Runtime after export.
jax2onnx.user_interface
User-facing API for converting JAX functions and models to ONNX.
This module provides the primary interface for exporting JAX/Flax models to the ONNX format. It supports dynamic shapes, runtime parameters, file-oriented export profiles, and numerical validation against ONNX Runtime.
Key Functions:
- to_onnx: Convert a JAX function or Flax module to an ONNX model.
- onnx_function: Decorator to mark a function or class as an ONNX function node.
- allclose: Validate numerical equivalence between JAX and ONNX Runtime outputs.
- allclose_onnxruntime_web: Validate ONNX Runtime Web/WASM output against Python ONNX Runtime CPU output.
Example
allclose(fn, onnx_model_path, inputs, input_params=None, rtol=0.001, atol=1e-05, *, enable_double_precision=False, inputs_as_nchw=None, outputs_as_nchw=None)
Checks if JAX and ONNX Runtime outputs remain numerically close.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[..., Any] | JAX callable to compare against the exported ONNX model. | required |
| onnx_model_path | str | Path to a serialized model that ORT can execute. | required |
| inputs | Sequence[Any] | Concrete input arrays (or shape tuples, which will be sampled). | required |
| input_params | Optional[Dict[str, Any]] | Optional keyword arguments applied to both call sites. | None |
| rtol | float | Relative tolerance for floating-point comparisons. | 0.001 |
| atol | float | Absolute tolerance for floating-point comparisons. | 1e-05 |
| enable_double_precision | bool | Temporarily enable | False |
Returns:
| Type | Description |
|---|---|
| bool | |
| str | provides context when a mismatch occurs. |
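The text does not spell out how rtol and atol are combined. The sketch below assumes the conventional NumPy-style criterion, |actual - expected| <= atol + rtol * |expected|, applied here to scalars with the defaults from the signature. `is_close` is an illustrative helper, not the comparison jax2onnx is known to use.

```python
def is_close(actual, expected, rtol=1e-3, atol=1e-5):
    """NumPy-style closeness test for two scalars (assumed criterion)."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


within_rtol = is_close(1.0005, 1.0)   # error 5e-4 is below 1e-5 + 1e-3 * 1.0
outside_rtol = is_close(1.002, 1.0)   # error 2e-3 exceeds the same bound
near_zero = is_close(2e-5, 0.0)       # at expected == 0 only atol applies
print(within_rtol, outside_rtol, near_zero)
```

Under this criterion, atol dominates for outputs near zero and rtol dominates for large magnitudes, which is a reason to tune them separately when a comparison fails.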
Example
import jax.numpy as jnp
from jax2onnx import allclose, to_onnx

def my_func(x):
    return jnp.sin(x)

# Export with a symbolic batch dimension "B".
model_path = to_onnx(
    my_func,
    inputs=[("B", 10)],
    return_mode="file",
    output_path="my_model.onnx",
)

# Validate with a concrete shape; shape tuples are sampled.
validation_inputs = [(5, 10)]
is_match, msg = allclose(
    my_func,
    model_path,
    inputs=validation_inputs,
    atol=1e-5,
)
assert is_match, f"Validation failed: {msg}"
Source code in jax2onnx/user_interface.py
allclose_onnxruntime_web(onnx_model_path, inputs, input_params=None, rtol=0.001, atol=1e-05, *, inputs_as_nchw=None, node_command='node', runner=None)
Validate that onnxruntime-web/wasm matches Python ONNX Runtime CPU output.
This helper is intended for deployment-readiness checks. It executes the
serialized model once with Python onnxruntime on CPUExecutionProvider,
then executes the same model and feeds through onnxruntime-web/wasm.
The default runner is Node.js; pass runner="chrome" or set
JAX2ONNX_ONNXRUNTIME_WEB_RUNNER=chrome to validate in a real browser via
Playwright/Chromium.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| onnx_model_path | str | Path to a serialized ONNX model, typically produced with | required |
| inputs | Sequence[Any] | Concrete input arrays (or shape tuples, which will be sampled). | required |
| input_params | Optional[Dict[str, Any]] | Optional named ONNX inputs that correspond to runtime parameters materialized during export. | None |
| rtol | float | Relative tolerance for output comparison. | 0.001 |
| atol | float | Absolute tolerance for output comparison. | 1e-05 |
| inputs_as_nchw | Optional[Sequence[int]] | Optional sequence of input indices that should be transposed from NHWC validation arrays to NCHW ONNX feeds. | None |
| node_command | str | Node.js executable used to run the Web/WASM validator. | 'node' |
| runner | Optional[str] | Optional runtime selector. Use | None |
Returns:
| Type | Description |
|---|---|
| bool | |
| str | contains the Web runner result or failure details. |
Example
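The following is a hedged usage sketch rather than the library's own example. It assumes that allclose_onnxruntime_web can be imported from the jax2onnx package root, as allclose is, and that it returns a (bool, str) pair as the Returns table suggests. The wrapper `check_web_parity` and the stub `fake_validate` are hypothetical; the stub lets the sketch run without Node.js or an exported model.

```python
def check_web_parity(model_path, inputs, runner=None, validate=None):
    """Raise AssertionError when the Web/WASM output does not match.

    `validate` defaults to jax2onnx's helper and can be swapped for a stub.
    """
    if validate is None:
        # Assumed import path, mirroring `from jax2onnx import allclose`.
        from jax2onnx import allclose_onnxruntime_web as validate
    ok, message = validate(model_path, inputs, rtol=1e-3, atol=1e-5, runner=runner)
    if not ok:
        raise AssertionError(f"onnxruntime-web parity failed: {message}")
    return message


def fake_validate(model_path, inputs, rtol, atol, runner=None):
    """Stand-in with the same call shape, used only for this dry run."""
    return True, f"stub runner={runner}"


result = check_web_parity(
    "model.web.onnx",
    [(2, 128)],          # shape tuple; the real helper samples values for it
    runner="chrome",
    validate=fake_validate,
)
print(result)
```

In a real deployment check, dropping the `validate=fake_validate` argument would route the call to the library helper.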
Source code in jax2onnx/user_interface.py
onnx_function(target=None, *, unique=False, namespace=None, name=None, type=None)
Decorator to mark a function or class as an ONNX function.
This decorator is used to indicate that a function or class should be converted to an ONNX function node when included in a model. It allows the function to be traced and exported as a reusable component with its own namespace in the ONNX graph.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| target | Optional[OnnxFunctionTarget] | The target function or class to decorate. When omitted, the decorator must be called with parentheses. | None |
| unique | bool | If True, reuse a single ONNX Function definition for all call sites that share the same callable type and captured parameters. | False |
| namespace | Optional[str] | Custom domain prefix for the emitted FunctionProto. Defaults to | None |
| name | Optional[str] | Optional human-readable base name for the ONNX function. When set, this overrides the callable's Python name for the function | None |
| type | Optional[str] | Alias for | None |
Raises:
| Type | Description |
|---|---|
| ValueError | If the same target is decorated again with a conflicting namespace or display-name override. |
Returns:
| Type | Description |
|---|---|
| Union[OnnxFunctionTarget, OnnxFunctionDecorator] | The decorated function or class with ONNX function capabilities. |
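The optional target argument is what lets a decorator be written both bare and with parentheses. The sketch below illustrates that general Python pattern with a stand-in decorator named `mark`; it is not jax2onnx's implementation, and the `_marker_options` attribute is invented for the illustration.

```python
def mark(target=None, *, unique=False, namespace=None, name=None):
    """Stand-in decorator usable as @mark or @mark(...)."""

    def decorate(obj):
        # Record the options on the decorated object (illustrative only).
        obj._marker_options = {
            "unique": unique,
            "namespace": namespace,
            "name": name or obj.__name__,
        }
        return obj

    if target is None:
        # Called with parentheses: return the real decorator.
        return decorate
    # Called bare: target is the function or class itself.
    return decorate(target)


@mark
def plain(x):
    return x


@mark(unique=True, namespace="custom.domain", name="Block")
class Block:
    pass


print(plain._marker_options["name"], Block._marker_options)
```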
Example
import jax.numpy as jnp
from flax import nnx
from jax2onnx import onnx_function

@onnx_function
def my_custom_op(x, y):
    return jnp.sin(x) * y

@onnx_function
class MLPBlock(nnx.Module):
    def __init__(self, features, rngs):
        # nnx.Linear takes both in_features and out_features.
        self.dense = nnx.Linear(features, features, rngs=rngs)
        self.activation = nnx.relu

    def __call__(self, x):
        return self.activation(self.dense(x))
Source code in jax2onnx/user_interface.py
to_onnx(fn, inputs, input_params=None, model_name='jax_model', opset=23, *, enable_double_precision=False, record_primitive_calls_file=None, return_mode='proto', output_path=None, inputs_as_nchw=None, outputs_as_nchw=None, input_names=None, output_names=None, export_mode='standard')
to_onnx(
    fn: Callable[..., Any],
    inputs: Sequence[UserInputSpec],
    input_params: Optional[Mapping[str, object]] = ...,
    model_name: str = ...,
    opset: int = ...,
    *,
    enable_double_precision: bool = ...,
    record_primitive_calls_file: Optional[str] = ...,
    return_mode: Literal["proto"] = ...,
    output_path: None = ...,
    inputs_as_nchw: Optional[Sequence[int]] = ...,
    outputs_as_nchw: Optional[Sequence[int]] = ...,
    input_names: Optional[Sequence[str]] = ...,
    output_names: Optional[Sequence[str]] = ...,
    export_mode: ExportMode = ...
) -> onnx.ModelProto

to_onnx(
    fn: Callable[..., Any],
    inputs: Sequence[UserInputSpec],
    input_params: Optional[Mapping[str, object]] = ...,
    model_name: str = ...,
    opset: int = ...,
    *,
    enable_double_precision: bool = ...,
    record_primitive_calls_file: Optional[str] = ...,
    return_mode: Literal["ir"],
    output_path: Optional[PathLikeStr] = ...,
    inputs_as_nchw: Optional[Sequence[int]] = ...,
    outputs_as_nchw: Optional[Sequence[int]] = ...,
    input_names: Optional[Sequence[str]] = ...,
    output_names: Optional[Sequence[str]] = ...,
    export_mode: ExportMode = ...
) -> ir.Model

to_onnx(
    fn: Callable[..., Any],
    inputs: Sequence[UserInputSpec],
    input_params: Optional[Mapping[str, object]] = ...,
    model_name: str = ...,
    opset: int = ...,
    *,
    enable_double_precision: bool = ...,
    record_primitive_calls_file: Optional[str] = ...,
    return_mode: Literal["file"],
    output_path: PathLikeStr,
    inputs_as_nchw: Optional[Sequence[int]] = ...,
    outputs_as_nchw: Optional[Sequence[int]] = ...,
    input_names: Optional[Sequence[str]] = ...,
    output_names: Optional[Sequence[str]] = ...,
    export_mode: ExportMode = ...
) -> str
Converts a JAX function or model into an ONNX model.
This function serves as the main entry point for converting JAX/Flax models to ONNX format. It supports dynamic shapes and additional runtime parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[..., Any] | The JAX function or Flax module to convert. | required |
| inputs | Sequence[UserInputSpec] | Sequence of input specifications. Each entry may be: * a | required |
| input_params | Optional[Mapping[str, object]] | Optional mapping of string keys to runtime parameters that should be exposed as inputs in the ONNX model rather than baked into the export (e.g. | None |
| model_name | str | Name to give the ONNX model. Defaults to "jax_model". | 'jax_model' |
| opset | int | ONNX opset version to target. Defaults to 23. | 23 |
| enable_double_precision | bool | If True, export tensors as tensor(double). Defaults to False (use tensor(float)). | False |
| record_primitive_calls_file | Optional[str] | Optional path to a file. If provided, details of each JAX primitive encountered during conversion will be recorded to this file. This log can be used by developers to manually create new test cases. Defaults to None (disabled). | None |
| return_mode | ReturnMode | Output mode. | 'proto' |
| output_path | Optional[PathLikeStr] | Destination path (str or PathLike) required when | None |
| inputs_as_nchw | Optional[Sequence[int]] | Optional sequence of input indices (0-based) that should be treated as NCHW layout. If specified for an input, jax2onnx assumes the external input is NCHW and will automatically transpose it to NHWC before feeding it to the JAX graph (which typically expects NHWC for images). This allows the exported ONNX model to accept NCHW inputs while preserving correct graph semantics. | None |
| outputs_as_nchw | Optional[Sequence[int]] | Optional sequence of output indices (0-based) that should be treated as NCHW layout. If specified for an output, jax2onnx assumes the external output should be NCHW and will automatically transpose the NHWC output derived from the JAX graph to NCHW before returning it. | None |
| input_names | Optional[Sequence[str]] | Optional sequence of names for positional model inputs. When provided, names are assigned by positional index (0-based) after conversion. These names do not apply to entries supplied through | None |
| output_names | Optional[Sequence[str]] | Optional sequence of names for model outputs (0-based). Names are assigned after conversion in output order. | None |
| export_mode | ExportMode | Serialization profile. | 'standard' |
Returns:
| Type | Description |
|---|---|
| Union[ModelProto, Model, str] | |
| Union[ModelProto, Model, str] | |
| Union[ModelProto, Model, str] | |
Raises:
| Type | Description |
|---|---|
| ValueError | If |
| ValueError | If |
| TypeError | If |
Example
import jax.numpy as jnp
from jax2onnx import to_onnx

def linear(x, w, b):
    return jnp.dot(x, w) + b

input_specs = [
    ("B", 32),  # x: [Batch, 32]
    (32, 10),   # w: [32, 10]
    (10,),      # b: [10]
]

to_onnx(
    linear,
    inputs=input_specs,
    model_name="linear_model",
    return_mode="file",
    output_path="linear_model.onnx",
)
Source code in jax2onnx/user_interface.py