Architecture Overview¶
jax2onnx is a library for converting JAX functions and Flax modules to the ONNX format. It enables deployment of JAX-based models to any runtime that supports ONNX, such as ONNX Runtime, TensorRT, or CoreML.
The library is designed around two core principles:
- Plugin-based extensibility: All operator-specific logic lives in plugins. The core converter knows nothing about specific operations like convolutions or attention—it only orchestrates the conversion process.
- Minimal core, maximal flexibility: The converter is a thin, generic engine that traces JAX programs and delegates lowering to plugins, keeping the architecture clean and maintainable.
Big idea¶
The converter is a tiny, generic JAXPR → IR engine. It knows nothing about NNX, Conv, Pool, or any specific op. Its only job is:
- Discover plugins (inversion of control),
- Activate whatever they declare (monkey-patching to produce crisp primitives),
- Trace your function to a ClosedJaxpr,
- Lower each equation by handing it to a plugin that claimed that primitive,
- Assemble an IR graph,
- Optimize the IR graph with a small, safe, plugin-agnostic pass,
- Finalize a valid ONNX model (stamp shapes/dtypes, prune, serialize).
Everything op-specific — layouts, padding math, attribute shapes, NHWC↔NCHW, etc. — stays in plugins.
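To make the dispatch loop concrete, here is a minimal, runnable illustration of the tracing step using plain JAX; the registry lookup in the comment is a stand-in for the actual plugin registry described below.

```python
import jax
import jax.numpy as jnp

def fn(x):
    return jnp.tanh(x) + 1.0

# Trace to a ClosedJaxpr (step 3 of the list above).
closed = jax.make_jaxpr(fn)(jnp.zeros((2, 3)))

# Steps 4-5: the core's entire dispatch logic is a loop over equations,
# keyed only by the primitive's *name* (a string).
for eqn in closed.jaxpr.eqns:
    print(eqn.primitive.name)  # e.g. "tanh", "add" -> PLUGIN_REGISTRY[name].lower(ctx, eqn)
```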
Related documentation¶
- Plugin System Guide – detailed guide on writing plugins.
- ONNX IR Builder Guide – canonical builder guardrails and examples.
- Expect Graph Reference – structural test patterns for `expect_graph`.
- Subgraph Input Handling – control-flow body wiring (If/Loop/Scan).
- Supported Components – autogenerated support matrices for primitives/examples.
- Past Versions – changelog snapshots for each jax2onnx release.
Roles & responsibilities¶
Core (plugin-agnostic)¶
- Plugin discovery. Recursively import the plugins directory (pattern: `plugins/*`). Plugins self-register into a registry keyed by primitive name (string). The core never sees concrete classes like `nnx.Conv`. (A sketch of this mechanism follows the list.)
- Activation window. Core enters a context that applies whatever patches plugins declare. This context wraps tracing so patched high-level calls (e.g., `nnx.Conv.__call__`) emit the right primitive names. No allowlists; no special cases.
- Tracing. `make_jaxpr(fn)(*shape_specs)` yields a ClosedJaxpr: `(constvars, invars, eqns, outvars)`.
- IR assembly. Walk equations in order; for each equation:
  - Look up `PLUGIN_REGISTRY[eqn.primitive.name]`.
  - Give it the equation and a lowering context; it emits IR nodes/values.
  - Assert that every `eqn.outvars[i]` is bound to an IR value before moving on (generic guardrail).
  - Converters/plugins emit new ops through `ctx.builder` so constants and `_outputs` stay consistent across ONNX IR variants (see ONNX IR Builder Guide).
- IR optimization (safe, structure-only). Run small, local rewrites that don't encode op semantics, e.g. folding redundant layout ping-pongs (see below).
- Finalize. Add model inputs/outputs, stamp symbolic dim labels (e.g. `"B"`), prune dead nodes/initializers, serialize `ModelProto`.
Plugin (op-specific)¶
Each plugin describes one primitive (or one high-level function). It has three standard pieces:
- Binding specs (monkey-patching). “When a user calls X, bind primitive named P.” Example: patch `flax.nnx.Conv.__call__` so the traced program contains `primitive.name == "nnx.conv"`. If NNX exposes multiple symbols, the plugin lists them all. The core just applies what's declared.
- Abstract eval (shape/dtype). Given JAX abstract values (`ShapedArray`), return the result's abstract value (or tuple). No real compute; just shape math (use `lax.*` if helpful). This is used by JAX during tracing.
- Lowering (IR emission). Given a `LoweringContext` and the equation:
  - Pull IR inputs via `ctx.get_value_for_var(eqn.invars[i])`.
  - Create IR nodes (Conv, Transpose, Reshape, …).
  - Produce IR outputs and bind them to `eqn.outvars[i]` via `ctx.bind_value_for_var(...)`.
  - Return nothing (binding suffices) or return the produced values (the core will bind any unbound outvars generically).
That’s it. The contract is tiny and uniform across all primitives.
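For illustration, here is roughly what a binding spec amounts to once applied. The primitive name follows the example above; the shim body and the parameters it forwards are assumptions about what a conv patch would carry.

```python
from jax.extend.core import Primitive  # jax.core.Primitive on older JAX

# A crisp primitive that the patched call will bind.
nnx_conv_p = Primitive("nnx.conv")

# Abstract eval so tracing can proceed without real compute (identity shape
# is a placeholder; a real conv plugin does the padding/stride math).
nnx_conv_p.def_abstract_eval(lambda x, **params: x)

def _patched_conv_call(self, x):
    # What the monkey-patch installs over flax.nnx.Conv.__call__: the traced
    # program now records primitive.name == "nnx.conv" instead of raw lax ops.
    return nnx_conv_p.bind(x, strides=self.strides, padding=self.padding)
```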
Plugin guardrails (must-follow)¶
To keep conversions portable across onnx_ir variants, every plugin lowering must observe these project-wide rules:
- Builder-first: emit ops via `ctx.builder` (see ONNX IR Builder Guide). `_outputs` must always be a sequence; constants come from `builder.initializer(...)` or `ctx.bind_const_for_var(...)`.
- Metadata stamping: after every builder call, stamp dtype/shape on the produced value and run `_ensure_value_metadata(...)` to normalize the `ir.Value` metadata (there is no separate `value_info` registry). Legacy reminder: the converter removed `builder.value_info`; all shape/type metadata must travel with the values themselves.
- Single-use RNG / module construction: never seed at import time. Expose stochastic callables with `construct_and_call(...).with_rng_seed(...)` / `.with_requested_dtype(...)` so the test harness can rebuild modules for both f32/f64 variants without clashes (see agent guidelines).
- No protobuf in converters/plugins: only the top-level adapters touch `onnx` (protobuf types). Policy tests enforce this.
Data flow end-to-end¶
User fn + shape specs
│
▼
[Activation Context] ←— plugins declare patches; core applies them (no names)
│
▼
ClosedJaxpr = make_jaxpr(fn)(*specs)
│
├── constvars → ctx.bind_const_for_var(...)
├── invars → ctx.add_input_for_invar(...)
│
└── eqns: e₀, e₁, …, eₙ
│
├─ core reads eᵢ.primitive.name (string)
├─ plugin = REGISTRY[name]
├─ plugin.lower(ctx, eᵢ) ← emits IR nodes
└─ core asserts eᵢ.outvars all bound
│
└── outvars → ctx.add_outputs_from_vars(...)
│
▼
IR graph → **IR optimizer** → stamp shapes/symbols → prune → ONNX ModelProto
No step above references “Conv”, “Tanh”, or any specific op in the core. All knowledge sits behind the primitive name string chosen by the plugin.
Conversion pipeline (detailed)¶
The full conversion pipeline spans two modules: Conversion API (core conversion) and User Interface (export facade). The following table shows the exact order:
| Step | Module | Location | Purpose |
|---|---|---|---|
| 1 | `conversion_api` | `to_onnx` | Build raw IR: trace JAXPR, lower equations to nodes |
| 2 | `conversion_api` | `optimize_graph` | Structural optimization: dead node removal, CSE, constant lifting, reshape folding |
| 3 | `conversion_api` | Late overrides | Apply user attribute patches to surviving nodes; fix Concat axis |
| 4 | `conversion_api` | `run_optional_shape_inference` | Reserved for future shape inference; currently a no-op |
| 5 | `conversion_api` | `_finalize_model_value_shapes` | Normalize symbolic dims to `ir.SymbolicDim` objects |
| 6 | `conversion_api` | Return | Model has precise shapes preserved |
| 7 | `user_interface` | `postprocess_ir_model` | Shape loosening: replace intermediate value shapes with dynamic dims for ORT flexibility |
| 8 | `user_interface` | `_materialize_input_params_on_ir` | Expose `input_params` as explicit graph inputs |
| 9 | `user_interface` | Serialize | Convert to proto / save to file |
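A hypothetical end-to-end call, assuming the public `to_onnx` facade named in the table; the argument spelling is illustrative, not a signature reference.

```python
import jax.numpy as jnp
from jax2onnx import to_onnx

def model(x):
    return jnp.tanh(x)

# Steps 1-6 run inside conversion_api; steps 7-9 in the user_interface facade.
onnx_model = to_onnx(model, inputs=[("B", 32)])  # symbolic batch dim "B"
```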
Why this order?¶
- Optimize before patching: Dead node removal runs first so we don't waste time patching nodes that will be deleted.
- Finalize before loosening: `conversion_api` normalizes shapes while they are precise. Loosening (Step 7) intentionally runs AFTER, to preserve accuracy for shape inference and finalization.
- Loosening is export-only: `postprocess_ir_model` is called only by the user-facing `to_onnx` function, not by internal pipelines.
Module responsibilities¶
| Module | Responsibility |
|---|---|
| IR Optimizer | Pure optimization passes (DCE, CSE, constant lifting, reshape folding) |
| Conversion API | Core conversion + optimization + finalization (returns precise-shape model) |
| IR Postprocess | Export preparation: shape loosening for runtime flexibility |
| User Interface | Public API facade: orchestrates conversion → postprocess → serialize |
The lowering context (what plugins see)¶
A small, stable API:
- `get_value_for_var(var, *, name_hint=None) -> IRValue`: materialize (or retrieve) the IR value corresponding to a JAX var (const/invar/intermediate). Handles literals by creating constant initializers.
- `bind_value_for_var(var, value: IRValue)`: declare that this IR value is the output of `var` (an equation outvar). This is the only binding contract the core depends on.
- Minimal utilities the plugin can rely on (implemented once in the core):
  - Name generator (`fresh_name`),
  - Helpers for constants and attributes,
  - Optionally a couple of generic IR helpers (`emit_node`, tiny wrappers for Shape/Gather/Unsqueeze where dynamic dims are needed).
Output binding pattern (must follow)¶
When lowering, always reuse the IRValue pre-allocated for each equation outvar. The canonical flow is:
- Fetch inputs via `ctx.get_value_for_var(eqn.invars[i])`.
- Pre-allocate outputs with `ctx.get_value_for_var(eqn.outvars[i])`.
- Emit nodes whose `outputs=[...]` point to those pre-allocated values (write final results directly into them).
- Optionally stamp dtype/shape metadata on the same values.
Avoid producing temporary outputs and “attaching” them afterwards; that pattern bypasses the var→value map and leads to orphaned tensors. Keeping the contract tight here means downstream equations always receive the correct tensor without extra bookkeeping.
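A minimal sketch of that flow, assuming the context and builder API described in this guide (the `Relu` op and the exact builder call are illustrative stand-ins):

```python
def lower(ctx, eqn):
    # 1-2. Fetch the input and pre-allocate the output for the outvar.
    x = ctx.get_value_for_var(eqn.invars[0])
    y = ctx.get_value_for_var(eqn.outvars[0])
    # 3. Emit the node so it writes directly into the pre-allocated value;
    #    never create a temporary output and "attach" it afterwards.
    ctx.builder.Relu(x, _outputs=[y])
    # 4. (Optional) stamp dtype/shape metadata on y here.
```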
Builder vs. Tape¶
Every lowering runs against the same IR backend, but plugins are expected to work through `ctx.builder` whenever possible:
- The builder records nodes/initializers in one place and mirrors `_tape.Builder` semantics across converters, plugins, and tests. The policy suite (IR builder contract test and IR builder usage checker) assumes `_outputs`/initializer calls flow through the builder helpers.
- Drop down to `onnx_ir.tape.Tape` (or construct `ir.Node` manually) only when you need features the builder cannot express: overload selection, custom output reuse, or metadata props. When you do, restore dtype/shape metadata manually and write back to all graph mirrors (`graph.nodes`, `_nodes`, etc.).
- The ONNX IR Builder Guide collects the concrete guardrails and example snippets. Treat it as the source of truth for `_outputs`, initializer naming, and RNG/dtype conventions referenced by the policy tests.
Functions (converter-owned call boundaries)¶
Decorators such as `@onnx_function` register a plugin that lowers the call into a `FunctionScope`. At runtime the handler:
- Builds a `FunctionKey` from the qualified target name, input aval signature, and capture signature (class instance id/config).
- Reuses an existing definition if the key is cached; otherwise it opens a `FunctionScope`, maps parent inputs to fresh function inputs, and recursively lowers the original callable inside the child IRContext (constants are emitted as `Constant` nodes because FunctionProto bodies cannot own initializers).
- Seals the scope into an `onnx_ir.Function`, records any attribute overrides, and caches the result.
- Emits a call-site node in the parent graph with `domain`/`op_type` matching the function and wires the original in/out values to that node.
Because the definition is keyed on avals and capture signature, identical calls share a single function body, while shape/config changes trigger new entries.
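As a sketch of the intended usage (the decorator name comes from this section; the import path and module body are illustrative):

```python
import jax.numpy as jnp
from jax2onnx import onnx_function

@onnx_function
def residual_block(x):
    # Lowered into a FunctionScope; both call sites below share one
    # function body because their aval signatures match.
    return x + jnp.tanh(x)

def model(x):
    return residual_block(residual_block(x))
```

Changing the input shape (or the captured config) at a call site would produce a new `FunctionKey` and hence a second function definition.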
Shapes & symbolic dims¶
- Inputs. If the user gives symbolic strings (e.g., `"B"`), the core creates JAX symbolic dims so the jaxpr records symbols instead of numbers.
- Abstract eval. Preserve `_DimExpr` symbols: call `jax.eval_shape` on the original callable (with `ShapeDtypeStruct` inputs) instead of doing manual shape math. Never cast symbolic dims to ints. (A runnable illustration follows this list.)
- Dynamic shapes in IR. When an IR op needs runtime sizes (e.g., a flatten), plugins use:
  - `Shape(x)` → shape vector,
  - `Gather(shape, axis=i)` → i-th dimension,
  - `Unsqueeze`/`Concat` → assemble a runtime shape tensor,
  - `Reshape(x, shape_tensor)`.
- Output stamping. After lowering, the core restamps inputs/outputs so symbolic labels survive through ONNX `ValueInfo` (only where no concrete size is present). `IRContext` tracks the origin tensor/axis for each symbol so helpers like `dim_as_value` can materialize runtime shapes via Shape → Gather → Squeeze.
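A runnable illustration of the recommended abstract-eval style using plain JAX: `jax.eval_shape` propagates the symbolic dimension without manual shape math.

```python
import jax
import jax.numpy as jnp
from jax import export

B, = export.symbolic_shape("B")  # symbolic batch dimension
spec = jax.ShapeDtypeStruct((B, 784), jnp.float32)

def fn(x):
    return jnp.reshape(x, (x.shape[0], 28, 28, 1))

out = jax.eval_shape(fn, spec)   # no real compute; symbols preserved
print(out.shape)                 # (B, 28, 28, 1)
```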
Determinism & graph hygiene¶
- Deterministic names. The core's `fresh_name` yields deterministic per-graph names; initializers keep stable names based on plugin hints (when feasible).
- Single-node policy. If a plugin needs `Reshape`, it emits one `Reshape` and a single constant shape initializer if static; it avoids const-only `Concat`.
- Pruning. A simple backwards mark from graph outputs removes dead nodes and unused initializers; this keeps `match="exact"` tests strict. (A sketch of the mark follows this list.)
- No dangling inputs. The core asserts every outvar is bound; the graph is built in jaxpr order so edges are naturally well-formed.
Graph structure specs¶
- Post-conversion graph checks live directly in metadata. Use `expect_graph([...], …)` so the contract is visible next to each testcase.
- When the lowered structure changes, run `poetry run python scripts/emit_expect_graph.py <testcase>` to regenerate a canonical snippet; paste the result into the plugin metadata (script: expect_graph emitter).
- The helper relies on `auto_expect_graph_spec` internally, so it always reflects the current ONNX graph without persisting extra fixtures.
IR optimizer (plugin-agnostic)¶
Before serialization we run a tiny, structure-only optimization sweep. The canonical rules and implementation notes live in IR Optimizer Guide. Today the only pass folds redundant Transpose → [elementwise]* → Transpose pairs when their permutations compose to identity; future passes must follow the same IR-only, backend-agnostic constraints.
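The heart of that pass is a pure permutation check; a sketch only, since the real pass also verifies that the intervening ops are elementwise.

```python
def composes_to_identity(perm_a, perm_b) -> bool:
    # Transpose(perm_a) followed by Transpose(perm_b) is a no-op exactly
    # when chaining the permutations maps every axis back to itself.
    return all(perm_b[a] == i for i, a in enumerate(perm_a))

# NHWC -> NCHW followed by NCHW -> NHWC cancels out:
assert composes_to_identity([0, 3, 1, 2], [0, 2, 3, 1])
```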
Testing expectations (how “exact” works)¶
- Anchored path checks. Tests can say: `Transpose → Conv → Relu → AveragePool → …`. With `match="exact"`, the test fails if required ops are missing or extra ops are present between anchors.
- CNN sentinel. The CNN static test is a canary: if Conv doesn't lower, flatten will see `{B,14,14,1}` and fail a later `Reshape(B,3136)`. With the optimizer, extra `Transpose…Transpose` pairs around Relu are eliminated; pairs around AveragePool remain (by design).

See Expect Graph Reference for a focused reference on writing `expect_graph` checks (shapes, counts, symbols, and helper flags).
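A hypothetical metadata fragment in that style; `expect_graph` is the helper named above, but the exact argument shapes and keywords live in the Expect Graph Reference.

```python
testcase_metadata = {
    # Hypothetical: anchored-path contract colocated with the testcase.
    "expect_graph": expect_graph(
        ["Transpose -> Conv -> Relu -> AveragePool"],
        match="exact",  # fail on missing anchors or extra ops between them
    ),
}
```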
Typical plugin lifecycle (concrete but generic)¶
- Register. `@register_primitive(jaxpr_primitive="…")` puts a `PrimitiveLeafPlugin` instance in the registry under that string key. The core will later match that key with `eqn.primitive.name`.
- Patch. `binding_specs()` (optional) returns `MonkeyPatchSpec`s: “replace `module.symbol` with a `prim.bind(…)` shim”. If there are multiple aliases, the plugin lists them. The core just applies them all.
- Abstract eval. `def abstract_eval(*avals, **params):` returns a `ShapedArray` (or tuple) describing the outputs. Use `jax.eval_shape` on `lax.*` helpers if that's easier, but never call the patched function (it would recurse).
- Lower. `def lower(ctx, eqn):`
  - `x = ctx.get_value_for_var(eqn.invars[0])`,
  - create IR nodes (e.g., Transpose, Conv, CastLike, …),
  - `ctx.bind_value_for_var(eqn.outvars[0], y)`.
That’s the whole contract.
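Putting the hooks together, a sketch of a complete (if trivial) plugin; the decorator, base class, and context names follow this guide, and exact signatures may differ.

```python
from jax.core import ShapedArray

@register_primitive(jaxpr_primitive="my.tanh")
class TanhPlugin(PrimitiveLeafPlugin):
    def abstract_eval(self, x, **params):
        # Elementwise: output shape/dtype mirror the input. No real compute.
        return ShapedArray(x.shape, x.dtype)

    def lower(self, ctx, eqn):
        x = ctx.get_value_for_var(eqn.invars[0])
        y = ctx.get_value_for_var(eqn.outvars[0])  # pre-allocated outvar value
        ctx.builder.Tanh(x, _outputs=[y])          # builder-first emission
```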
Failure modes & how the architecture contains them¶
- Patch activation window too late. If activation doesn't wrap tracing, the jaxpr will never contain the plugin's primitive names. The core still doesn't special-case anything; you just see “no plugin for primitive ‘foo’”. Fix = activate around `make_jaxpr`.
- Plugin forgets to bind the output. The core's generic guardrail catches it and fails the build at the exact primitive, without central knowledge of op names.
- Multiple symbols for the same high-level op. Plugins add multiple patch specs. The core applies them all; still no names or allow-lists in the core.
- An unfinished plugin gets imported. If it also patches a runtime path, it can trip tracing/lowering. Fix in the plugin: either complete it or don't patch until ready. The core does not and should not maintain an allow/deny list.
Architectural guarantees¶
- Inversion of control. The only dynamic choice the core makes is: “Given `eqn.primitive.name` (a string), ask the registry for a handler.” There is zero knowledge of concrete ops or frameworks.
- Uniform contracts. Every plugin implements the same three hooks. The core only provides generic services (var→value map, name generator, constant creation, and a place to put nodes).
- No central policy on which plugins are ‘on’. Activation applies whatever plugins declare. If a plugin shouldn't change tracing yet, it shouldn't publish a monkey-patch; that decision is local to the plugin, not the core.
TL;DR blueprint (for maintainers)¶
- Core = small, generic: discover, activate, trace, loop eqns, call `plugin.lower`, assert outputs bound, run IR optimizer, finalize ONNX. No plugin names. Ever.
- Plugins = specific: declare patches (all aliases), implement `abstract_eval`, implement `lower` (bind outvars), own all op semantics.
- Context = minimal API for plugins: `get_value_for_var`, `bind_value_for_var`, `fresh_name`, plus a couple of IR conveniences; no framework knowledge.
- Optimizer = tiny, safe, and IR-only: fold layout ping-pongs across pure elementwise ops; match by name or object; never mutate `Node.inputs` directly, use `replace_input_with`.