API Reference
jax2onnx exposes three public entry points: to_onnx(...) for export, @onnx_function for reusable subgraphs, and allclose(...) for JAX-vs-ONNX validation.
Common Export Flow
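```python
import jax.numpy as jnp
from jax2onnx import to_onnx

def fn(x):  # placeholder: any JAX-traceable callable
    return jnp.tanh(x)

# "B" becomes a symbolic (dynamic) batch axis; the model is written to disk.
model_path = to_onnx(
    fn,
    inputs=[("B", 128)],
    return_mode="file",
    output_path="model.onnx",
)
```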
Use string dimensions such as "B" when you want symbolic dynamic axes. For direct file export, set return_mode="file" and provide output_path.
Parameters To Reach For First
- `inputs`: Positional input specs: concrete arrays, `ShapeDtypeStruct` values, or shape tuples such as `("B", 128)`.
- `input_params`: Runtime flags or keyword-like values that should stay model inputs instead of being baked into the export.
- `return_mode`: `"proto"` for an `onnx.ModelProto`, `"ir"` for the intermediate `onnx_ir.Model`, or `"file"` to serialize directly to disk.
- `enable_double_precision`: Temporarily enables x64 export and emits `tensor(double)` where appropriate.
- `inputs_as_nchw` / `outputs_as_nchw`: Adapt the external ONNX interface to NCHW while keeping the traced JAX computation in its original layout.
- `input_names` / `output_names`: Apply stable user-facing names after conversion.
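The NCHW options amount to a transpose at the model boundary. The NumPy sketch below shows the two axis permutations involved for a 4-D image tensor; it illustrates the layout change only and is not jax2onnx's internal implementation.

```python
import numpy as np

# NCHW -> NHWC: move the channel axis (1) to the end, as an inputs_as_nchw input
# is rearranged before it reaches a JAX graph that expects NHWC.
NCHW_TO_NHWC = (0, 2, 3, 1)
# NHWC -> NCHW: move the channel axis (3) back to position 1, as for outputs_as_nchw.
NHWC_TO_NCHW = (0, 3, 1, 2)

x_nchw = np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)  # N=2, C=3, H=4, W=5
x_nhwc = np.transpose(x_nchw, NCHW_TO_NHWC)
print(x_nhwc.shape)  # (2, 4, 5, 3)

# The two permutations are inverses, so a round trip restores the original tensor.
roundtrip = np.transpose(x_nhwc, NHWC_TO_NCHW)
print(np.array_equal(roundtrip, x_nchw))  # True
```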
Use @onnx_function when repeated callables should become reusable ONNX functions, and allclose(...) when you want a quick numerical check against ONNX Runtime after export.
jax2onnx.user_interface
User-facing API for converting JAX functions and models to ONNX.
This module provides the primary interface for exporting JAX/Flax models to the ONNX format. It supports dynamic shapes, runtime parameters, and numerical validation against ONNX Runtime.
Key Functions:
- to_onnx: Convert a JAX function or Flax module to an ONNX model.
- onnx_function: Decorator to mark a function or class as an ONNX function node.
- allclose: Validate numerical equivalence between JAX and ONNX Runtime outputs.
Example
```python
import jax.numpy as jnp
from jax2onnx import to_onnx

def my_model(x):
    return jnp.sin(x)

to_onnx(my_model, inputs=[("B", 10)], return_mode="file", output_path="model.onnx")
```
allclose(fn, onnx_model_path, inputs, input_params=None, rtol=0.001, atol=1e-05, *, enable_double_precision=False, inputs_as_nchw=None, outputs_as_nchw=None)
Checks if JAX and ONNX Runtime outputs remain numerically close.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `fn` | `Callable` | JAX callable to compare against the exported ONNX model. | required |
| `onnx_model_path` | `str` | Path to a serialized model that ONNX Runtime (ORT) can execute. | required |
| `inputs` | `List[Any]` | Concrete input arrays, or shape tuples from which inputs are sampled. | required |
| `input_params` | `Optional[Dict[str, Any]]` | Optional keyword arguments applied to both the JAX call and the ONNX Runtime call. | `None` |
| `rtol` | `float` | Relative tolerance for floating-point comparisons. | `0.001` |
| `atol` | `float` | Absolute tolerance for floating-point comparisons. | `1e-05` |
| `enable_double_precision` | `bool` | Temporarily enable … | `False` |
Returns:
| Type | Description |
|---|---|
| `bool` | Whether the JAX and ONNX Runtime outputs match within `rtol`/`atol`. |
| `str` | A message that provides context when a mismatch occurs. |
Example
```python
import jax.numpy as jnp
from jax2onnx import to_onnx, allclose

# 1. Define and export
def my_func(x):
    return jnp.sin(x)

model_path = to_onnx(my_func, inputs=[("B", 10)], return_mode="file", output_path="my_model.onnx")

# 2. Validate: provide concrete shapes (replacing the dynamic dim "B")
validation_inputs = [(5, 10)]
is_match, msg = allclose(my_func, model_path, inputs=validation_inputs, atol=1e-5)
assert is_match, f"Validation failed: {msg}"
```
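The parameter table does not spell out how `rtol` and `atol` combine. A common convention, and the one `numpy.allclose` uses, accepts an element when |actual - expected| <= atol + rtol * |expected|. The sketch below applies that rule with `allclose`'s default tolerances; the helper `within_tolerance` is hypothetical, and that jax2onnx's check follows exactly this formula is an assumption rather than something stated above.

```python
import numpy as np

RTOL, ATOL = 1e-3, 1e-5  # defaults of allclose(...) per the parameter table


def within_tolerance(actual, expected, rtol=RTOL, atol=ATOL):
    """Hypothetical helper: elementwise |actual - expected| <= atol + rtol * |expected|."""
    actual = np.asarray(actual, dtype=np.float64)
    expected = np.asarray(expected, dtype=np.float64)
    return bool(np.all(np.abs(actual - expected) <= atol + rtol * np.abs(expected)))


jax_out = np.array([1.0, 100.0, 0.0])
ok_ort = np.array([1.0005, 100.05, 5e-6])  # every error falls inside its band
bad_ort = np.array([1.0, 100.0, 5e-5])     # 5e-5 exceeds atol where expected == 0

print(within_tolerance(ok_ort, jax_out))   # True
print(within_tolerance(bad_ort, jax_out))  # False
```

Near zero the relative term vanishes and `atol` alone decides, which is why the last element of `bad_ort` fails.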
Source code in jax2onnx/user_interface.py
onnx_function(target=None, *, unique=False, namespace=None, name=None, type=None)
Decorator to mark a function or class as an ONNX function.
This decorator is used to indicate that a function or class should be converted to an ONNX function node when included in a model. It allows the function to be traced and exported as a reusable component with its own namespace in the ONNX graph.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `target` | `Optional[Union[Callable, type]]` | The function or class to decorate. When omitted, the decorator must be called with parentheses. | `None` |
| `unique` | `bool` | If `True`, reuse a single ONNX Function definition for all call sites that share the same callable type and captured parameters. | `False` |
| `namespace` | `Optional[str]` | Custom domain prefix for the emitted `FunctionProto`. Defaults to … | `None` |
| `name` | `Optional[str]` | Optional human-readable base name for the ONNX function. When set, it overrides the callable's Python name for the function … | `None` |
| `type` | `Optional[str]` | Alias for … | `None` |
Returns:
| Type | Description |
|---|---|
| `Union[Callable, type]` | The decorated function or class with ONNX function capabilities. |
Example
```python
import jax.numpy as jnp
from flax import nnx
from jax2onnx import onnx_function

@onnx_function
def my_custom_op(x, y):
    return jnp.sin(x) * y

# Also works with Flax modules:
@onnx_function
class MLPBlock(nnx.Module):
    def __init__(self, features, rngs):
        self.dense = nnx.Linear(features, features, rngs=rngs)  # (in_features, out_features)
        self.activation = nnx.relu

    def __call__(self, x):
        return self.activation(self.dense(x))
```
Source code in jax2onnx/user_interface.py
to_onnx(fn, inputs, input_params=None, model_name='jax_model', opset=23, *, enable_double_precision=False, record_primitive_calls_file=None, return_mode='proto', output_path=None, inputs_as_nchw=None, outputs_as_nchw=None, input_names=None, output_names=None)
```python
to_onnx(
    fn: Callable,
    inputs: Sequence[UserInputSpec],
    input_params: Optional[Mapping[str, object]] = ...,
    model_name: str = ...,
    opset: int = ...,
    *,
    enable_double_precision: bool = ...,
    record_primitive_calls_file: Optional[str] = ...,
    return_mode: Literal["proto"] = ...,
    output_path: None = ...,
    inputs_as_nchw: Optional[Sequence[int]] = ...,
    outputs_as_nchw: Optional[Sequence[int]] = ...,
    input_names: Optional[Sequence[str]] = ...,
    output_names: Optional[Sequence[str]] = ...
) -> onnx.ModelProto

to_onnx(
    fn: Callable,
    inputs: Sequence[UserInputSpec],
    input_params: Optional[Mapping[str, object]] = ...,
    model_name: str = ...,
    opset: int = ...,
    *,
    enable_double_precision: bool = ...,
    record_primitive_calls_file: Optional[str] = ...,
    return_mode: Literal["ir"],
    output_path: Optional[PathLikeStr] = ...,
    inputs_as_nchw: Optional[Sequence[int]] = ...,
    outputs_as_nchw: Optional[Sequence[int]] = ...,
    input_names: Optional[Sequence[str]] = ...,
    output_names: Optional[Sequence[str]] = ...
) -> ir.Model

to_onnx(
    fn: Callable,
    inputs: Sequence[UserInputSpec],
    input_params: Optional[Mapping[str, object]] = ...,
    model_name: str = ...,
    opset: int = ...,
    *,
    enable_double_precision: bool = ...,
    record_primitive_calls_file: Optional[str] = ...,
    return_mode: Literal["file"],
    output_path: PathLikeStr,
    inputs_as_nchw: Optional[Sequence[int]] = ...,
    outputs_as_nchw: Optional[Sequence[int]] = ...,
    input_names: Optional[Sequence[str]] = ...,
    output_names: Optional[Sequence[str]] = ...
) -> str
```
Converts a JAX function or model into an ONNX model.
This function serves as the main entry point for converting JAX/Flax models to ONNX format. It supports dynamic shapes and additional runtime parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `fn` | `Callable` | The JAX function or Flax module to convert. | required |
| `inputs` | `Sequence[UserInputSpec]` | Sequence of input specifications. Each entry may be a concrete array, a `ShapeDtypeStruct`, or a shape tuple such as `("B", 32)`, where string entries become symbolic dimensions. | required |
| `input_params` | `Optional[Mapping[str, object]]` | Optional mapping of string keys to runtime parameters that should be exposed as inputs in the ONNX model rather than baked into the export. | `None` |
| `model_name` | `str` | Name to give the ONNX model. | `'jax_model'` |
| `opset` | `int` | ONNX opset version to target. | `23` |
| `enable_double_precision` | `bool` | If `True`, export tensors as `tensor(double)`; otherwise `tensor(float)` is used. | `False` |
| `record_primitive_calls_file` | `Optional[str]` | Optional path to a file. If provided, details of each JAX primitive encountered during conversion are recorded to this file; developers can use this log to create new test cases manually. | `None` (disabled) |
| `return_mode` | `ReturnMode` | Output mode: `"proto"` returns an `onnx.ModelProto`, `"ir"` returns the intermediate `onnx_ir.Model`, and `"file"` serializes the model to `output_path`. | `'proto'` |
| `output_path` | `Optional[PathLikeStr]` | Destination path (`str` or `PathLike`); required when `return_mode="file"`. | `None` |
| `inputs_as_nchw` | `Optional[Sequence[int]]` | Optional sequence of input indices (0-based) to treat as NCHW. For each listed input, jax2onnx assumes the external input is NCHW and transposes it to NHWC before it reaches the JAX graph (which typically expects NHWC for images), so the exported model accepts NCHW inputs while preserving the graph's semantics. | `None` |
| `outputs_as_nchw` | `Optional[Sequence[int]]` | Optional sequence of output indices (0-based) to treat as NCHW. For each listed output, jax2onnx transposes the NHWC result of the JAX graph to NCHW before returning it. | `None` |
| `input_names` | `Optional[Sequence[str]]` | Optional names for positional model inputs, assigned by positional index (0-based) after conversion. These names do not apply to entries supplied through … | `None` |
| `output_names` | `Optional[Sequence[str]]` | Optional names for model outputs, assigned in output order (0-based) after conversion. | `None` |
Returns:
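| Type | Description |
|---|---|
| `onnx.ModelProto` | Returned when `return_mode="proto"` (the default). |
| `onnx_ir.Model` | Returned when `return_mode="ir"`. |
| `str` | Returned when `return_mode="file"`: the path of the written model. |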
Raises:
| Type | Description |
|---|---|
| `ValueError` | If … |
| `ValueError` | If … |
| `TypeError` | If … |
Example
```python
import jax.numpy as jnp
from jax2onnx import to_onnx

# Define a simple JAX function
def linear(x, w, b):
    return jnp.dot(x, w) + b

# Define input shapes: "B" indicates a dynamic batch dimension
input_specs = [("B", 32), (32, 10), (10,)]  # x: [B, 32], w: [32, 10], b: [10]

# Convert and save to file directly (recommended)
to_onnx(linear, inputs=input_specs, model_name="linear_model",
        return_mode="file", output_path="linear_model.onnx")
```
Source code in jax2onnx/user_interface.py