The converter is a tiny, generic JAXPR → IR engine. It knows nothing about NNX, Conv, Pool, or any specific op. Its only job is to activate plugin patches, trace the user function into a jaxpr, walk the equations, and hand each one to the registered plugin for lowering. Everything op-specific — layouts, padding math, attribute shapes, NHWC↔NCHW, etc. — stays in plugins.
Related guides:

- `docs/dev_guides/onnx_ir_builder.md` – canonical builder guardrails and examples.
- `docs/dev_guides/expect_graph_reference.md` – structural test patterns for `expect_graph`.
- `docs/dev_guides/subgraph_input_handling.md` – control-flow body wiring (If/Loop/Scan).
- `docs/readme/coverage_tables.md` – autogenerated support matrices for primitives/examples.
- `docs/readme/past_versions.md` – changelog snapshots for each jax2onnx release.
The pipeline, from the core's point of view:

1. Plugin discovery. Op-specific code lives in `plugins/*`. Plugins self-register into a registry keyed by primitive name (a string); the core never sees concrete classes like `nnx.Conv`.
2. Activation. Plugin-declared patches (e.g., on `nnx.Conv.__call__`) emit the right primitive names during tracing. No allowlists; no special cases.
3. Tracing. `make_jaxpr(fn)(*shape_specs)` yields a `ClosedJaxpr`: (constvars, invars, eqns, outvars).
4. IR assembly. Walk equations in order; for each equation: look up `PLUGIN_REGISTRY[eqn.primitive.name]`, call its lowering, and assert that every `eqn.outvars[i]` is bound to an IR value before moving on (generic guardrail). Node and constant creation goes through `ctx.builder` so constants and `_outputs` stay consistent across ONNX IR variants (see ONNX IR Builder Guide).
5. Finalization. Stamp symbolic dims (e.g., `"B"`), prune dead nodes/initializers, serialize the `ModelProto`.
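A minimal sketch of that dispatch loop follows. The registry key and the `ctx` helpers are the ones named in this guide and in the pipeline diagram below; the guardrail check helper is hypothetical:

```python
# Sketch only: helper names on ctx are illustrative, not the verified core API.
def lower_jaxpr(ctx, closed_jaxpr):
    jaxpr = closed_jaxpr.jaxpr
    for var, const in zip(jaxpr.constvars, closed_jaxpr.consts):
        ctx.bind_const_for_var(var, const)              # constants -> initializers
    for var in jaxpr.invars:
        ctx.add_input_for_invar(var)                    # graph inputs
    for eqn in jaxpr.eqns:
        plugin = PLUGIN_REGISTRY[eqn.primitive.name]    # the only dynamic choice: a string key
        plugin.lower(ctx, eqn)                          # plugin emits IR nodes
        for out in eqn.outvars:                         # generic guardrail (hypothetical helper)
            assert ctx.has_value_for_var(out), f"unbound outvar for {eqn.primitive.name}"
    ctx.add_outputs_from_vars(jaxpr.outvars)            # graph outputs
```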
Each plugin describes one primitive (or one high-level function). It has three standard pieces:
Binding specs (monkey-patching). “When a user calls X, bind primitive named P.”
Example: patch `flax.nnx.Conv.__call__` so the traced program contains `primitive.name == "nnx.conv"`. If NNX exposes multiple symbols, the plugin lists them all. The core just applies what's declared.
Abstract eval (shape/dtype). Given JAX abstract values (`ShapedArray`), return the result's abstract value (or tuple). No real compute; just shape math (use `lax.*` if helpful). This is used by JAX during tracing.
Lowering (IR emission). Given a `LoweringContext` and the equation: read inputs via `ctx.get_value_for_var(eqn.invars[i])`, emit the IR nodes, and bind each `eqn.outvars[i]` via `ctx.bind_value_for_var(...)`. That's it. The contract is tiny and uniform across all primitives.
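Put together, a plugin can be as small as the sketch below. The decorator, hooks, and context methods are the ones named in this guide; the `MonkeyPatchSpec` field names and the shape helper are illustrative:

```python
from jax.core import ShapedArray

@register_primitive(jaxpr_primitive="nnx.conv")
class ConvPlugin:
    def binding_specs(self):
        # "When a user calls flax.nnx.Conv.__call__, bind the primitive named nnx.conv."
        # (MonkeyPatchSpec field names are illustrative.)
        return [MonkeyPatchSpec(target="flax.nnx.Conv.__call__", primitive="nnx.conv")]

    def abstract_eval(self, x, w, **params):
        # Shape/dtype only; no real compute. conv_output_shape is a hypothetical helper.
        return ShapedArray(conv_output_shape(x.shape, w.shape, **params), x.dtype)

    def lower(self, ctx, eqn):
        x = ctx.get_value_for_var(eqn.invars[0])
        y = ctx.get_value_for_var(eqn.outvars[0])  # pre-allocated output value
        # ... emit Conv (plus any layout fixes) through ctx.builder, writing into y ...
        ctx.bind_value_for_var(eqn.outvars[0], y)
```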
To keep conversions portable across `onnx_ir` variants, every plugin lowering must observe these project-wide rules:

- Route node and constant creation through `ctx.builder` (see ONNX IR Builder Guide). `_outputs` must always be a sequence; constants come from `builder.initializer(...)` or `ctx.bind_const_for_var(...)`.
- Use `_ensure_value_metadata(...)` to normalize `ir.Value` metadata (there is no separate `value_info` registry). Legacy reminder: the converter removed `builder.value_info`; all shape/type metadata must travel with the values themselves.
- Use `construct_and_call(...).with_rng_seed(...)` / `.with_requested_dtype(...)` so the test harness can rebuild modules for both f32/f64 variants without clashes (see AGENTS.md).
- Stay inside the IR layer; plugin code does not reach for `onnx` (protobuf types). Policy tests under `tests/extra_tests` enforce this.
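A hedged sketch of a rule-abiding lowering step, assuming the builder exposes ONNX ops as tape-style methods and the `initializer(...)` helper mentioned above (argument details are illustrative):

```python
import numpy as np

def lower(ctx, eqn):
    x = ctx.get_value_for_var(eqn.invars[0])
    y = ctx.get_value_for_var(eqn.outvars[0])

    # Constants go through the builder (or ctx.bind_const_for_var), never raw protobuf.
    shape_const = ctx.builder.initializer(np.asarray([0, -1], dtype=np.int64))

    # _outputs is always a sequence, even for a single result.
    ctx.builder.Reshape(x, shape_const, _outputs=[y])
    ctx.bind_value_for_var(eqn.outvars[0], y)
```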
User fn + shape specs
│
▼
[Activation Context] ←— plugins declare patches; core applies them (no names)
│
▼
ClosedJaxpr = make_jaxpr(fn)(*specs)
│
├── constvars → ctx.bind_const_for_var(...)
├── invars → ctx.add_input_for_invar(...)
│
└── eqns: e₀, e₁, …, eₙ
│
├─ core reads eᵢ.primitive.name (string)
├─ plugin = REGISTRY[name]
├─ plugin.lower(ctx, eᵢ) ← emits IR nodes
└─ core asserts eᵢ.outvars all bound
│
└── outvars → ctx.add_outputs_from_vars(...)
│
▼
IR graph → **IR optimizer** → stamp shapes/symbols → prune → ONNX ModelProto
No step above references “Conv”, “Tanh”, or any specific op in the core. All knowledge sits behind the primitive name string chosen by the plugin.
The LoweringContext exposes a small, stable API:
- `get_value_for_var(var, *, name_hint=None) -> IRValue` – Materialize (or retrieve) the IR value corresponding to a JAX var (const/invar/intermediate). Handles literals by creating constant initializers.
- `bind_value_for_var(var, value: IRValue)` – Declare that this IR value is the output of `var` (an equation outvar). This is the only binding contract the core depends on.
Minimal utilities the plugin can rely on (implemented once in the core): naming (`fresh_name`), `emit_node`, and tiny wrappers for Shape/Gather/Unsqueeze where dynamic dims are needed.

When lowering, always reuse the IRValue pre-allocated for each equation outvar. The canonical flow is:
1. Fetch inputs with `ctx.get_value_for_var(eqn.invars[i])`.
2. Fetch the pre-allocated outputs with `ctx.get_value_for_var(eqn.outvars[i])`.
3. Make the node's `outputs=[...]` point to those pre-allocated values (write final results directly into them).

Avoid producing temporary outputs and "attaching" them afterwards; that pattern bypasses the var→value map and leads to orphaned tensors. Keeping the contract tight here means downstream equations always receive the correct tensor without extra bookkeeping.
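A minimal do/don't sketch of that flow, again assuming tape-style builder methods that accept an `_outputs` sequence:

```python
def lower(ctx, eqn):
    x = ctx.get_value_for_var(eqn.invars[0])
    out = ctx.get_value_for_var(eqn.outvars[0])  # pre-allocated IRValue for the outvar

    # Good: write the result directly into the pre-allocated value.
    ctx.builder.Relu(x, _outputs=[out])
    ctx.bind_value_for_var(eqn.outvars[0], out)

    # Bad: create a temporary output and "attach" it afterwards; this bypasses
    # the var -> value map and leaves orphaned tensors.
    # tmp = ctx.builder.Relu(x)
    # ... attach tmp to eqn.outvars[0] later ...
```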
Every lowering runs against the same IR backend, but plugins are expected to work through `ctx.builder` whenever possible:

- The builder keeps `_tape.Builder` semantics consistent across converters, plugins, and tests. The policy suite (`tests/extra_tests/framework/test_ir_builder_contracts.py` and `scripts/check_ir_builder_usage.py`) assumes `_outputs`/initializer calls flow through the builder helpers.
- Drop down to `onnx_ir.tape.Tape` (or construct `ir.Node` manually) only when you need features the builder cannot express: overload selection, custom output reuse, or metadata props. When you do, restore dtype/shape metadata manually and write back to all graph mirrors (`graph.nodes`, `_nodes`, etc.).
- The ONNX IR Builder Guide documents the `_outputs`, initializer naming, and RNG/dtype conventions referenced by the policy tests.

Decorators such as `@onnx_function`
register a plugin that lowers the call into a
FunctionScope. At runtime the handler:
1. builds a `FunctionKey` from the qualified target name, input aval signature, and capture signature (class instance id/config);
2. opens a `FunctionScope`, maps parent inputs to fresh function inputs, and recursively lowers the original callable inside the child IRContext (constants are emitted as `Constant` nodes because FunctionProto bodies cannot own initializers);
3. converts the scope into an `onnx_ir.Function`, records any attribute overrides, and caches the result;
4. emits a call node whose `domain/op_type` match the function and wires the original in/out values to that node.

Because the definition is keyed on avals and capture signature, identical calls share a single function body, while shape/config changes trigger new entries.
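A small usage sketch, assuming the `onnx_function` decorator and `to_onnx` entry point exported by jax2onnx and the shape-spec convention used elsewhere in this guide:

```python
import jax.numpy as jnp
from jax2onnx import onnx_function, to_onnx  # assumed public entry points

@onnx_function
def mlp_block(x):
    # Lowered into a FunctionScope; structurally identical calls (same avals,
    # same capture signature) reuse one cached function body.
    return jnp.tanh(x @ jnp.ones((64, 64)))

def model(x):
    # Two calls with identical avals -> one shared FunctionProto,
    # referenced by two call nodes in the parent graph.
    return mlp_block(mlp_block(x))

onnx_model = to_onnx(model, [("B", 64)])  # symbolic batch dim "B"
```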
"B"
), the core creates JAX symbolic dims so the jaxpr records symbols instead of numbers._DimExpr
symbols—call jax.eval_shape
on the
original callable (with ShapeDtypeStruct
inputs) instead of doing manual
shape math. Never cast symbolic dims to ints.Dynamic shapes in IR. When an IR op needs runtime sizes (e.g., a flatten), plugins use:
- `Shape(x)` → shape vector,
- `Gather(shape, axis=i)` → i-th dimension,
- `Unsqueeze`/`Concat` → assemble a runtime shape tensor,
- `Reshape(x, shape_tensor)`.

Symbolic dims are stamped into `ValueInfo` (only where no concrete size is present). IRContext tracks the origin tensor/axis for each symbol so helpers like `dim_as_value` can materialize runtime shapes via `Shape → Gather → Squeeze`.
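The sketch below shows the recommended abstract-eval pattern: let `jax.eval_shape` do the shape math so symbolic dims survive untouched (the primitive and its reference function are illustrative):

```python
import math
import jax
import jax.numpy as jnp
from jax.core import ShapedArray

def abstract_eval(x_aval, **params):
    """Shape rule for a hypothetical 'flatten' primitive."""
    def reference(x):
        # Same shape math as the real op; a symbolic leading dim (e.g. "B") flows through.
        return jnp.reshape(x, (x.shape[0], math.prod(x.shape[1:])))

    out = jax.eval_shape(reference, jax.ShapeDtypeStruct(x_aval.shape, x_aval.dtype))
    return ShapedArray(out.shape, out.dtype)
```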
yields deterministic per-graph names; initializers keep stable names based on plugin hints (when feasible).Reshape
, it emits one Reshape
and a single constant shape initializer if static; it avoids const-only Concat
.match="exact"
tests strict.expect_graph([...], …)
so the contract is visible next to each testcase.poetry run python scripts/emit_expect_graph.py <testcase>
to regenerate a canonical snippet; paste the result into the plugin metadata.auto_expect_graph_spec
internally, so it always reflects the current ONNX graph without persisting extra fixtures.Before serialization we run a tiny, structure-only optimization sweep. The canonical rules and implementation notes live in docs/dev_guides/ir_optimizer.md
. Today the only pass folds redundant Transpose → [elementwise]* → Transpose
pairs when their permutations compose to identity; future passes must follow the same IR-only, backend-agnostic constraints.
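A tiny, self-contained sketch of the identity check such a pass relies on (plain Python, no project APIs):

```python
def compose(perm_outer, perm_inner):
    # Transpose(perm_inner) followed by Transpose(perm_outer) equals one
    # Transpose with this composed permutation.
    return tuple(perm_inner[i] for i in perm_outer)

def is_identity(perm):
    return perm == tuple(range(len(perm)))

# NHWC->NCHW then NCHW->NHWC cancels out, so the pair (and any pure elementwise
# ops sandwiched between them) can be folded away.
assert is_identity(compose((0, 2, 3, 1), (0, 3, 1, 2)))
```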
Example: an `expect_graph` anchor chain like `Transpose → Conv → Relu → AveragePool → …` with `match="exact"` fails if required ops are missing or extra ops are present between anchors. Without the optimizer, leftover layout transposes could leave an intermediate at `{B,14,14,1}` and fail a later `Reshape(B,3136)`. With the optimizer, extra `Transpose…Transpose` pairs around Relu are eliminated; pairs around AveragePool remain (by design).

See `docs/dev_guides/expect_graph_reference.md` for a focused reference on writing `expect_graph` checks (shapes, counts, symbols, and helper flags).
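For orientation only, a hedged sketch of what such a declaration could look like next to a testcase; the list-of-op-names form and the keyword are inferred from the prose above, and the reference guide is authoritative for the real spec format:

```python
# Illustrative only; consult expect_graph_reference.md for the actual spec shape.
expect_graph(
    ["Transpose", "Conv", "Relu", "AveragePool"],  # required anchors, in order
    match="exact",  # fail on missing anchors or extra ops between them
)
```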
Register
`@register_primitive(jaxpr_primitive="…")` puts an instance in the registry under that string key. The core will later match that key with `eqn.primitive.name`.
Patch
`binding_specs()` returns `MonkeyPatchSpec`s: "replace `module.symbol` with a `prim.bind(…)` shim". If there are multiple aliases, the plugin lists them. The core just applies them all.
Abstract eval
`def abstract_eval(*avals, **params):` returns a `ShapedArray` (or tuple) describing the outputs. Use `jax.eval_shape` on `lax.*` helpers if that's easier, but never call the patched function (it would recurse).
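For instance, a convolution-style primitive can delegate its shape rule to `lax` under `jax.eval_shape` (parameter names here are illustrative):

```python
import jax
from jax import lax
from jax.core import ShapedArray

def abstract_eval(x, w, *, strides, padding):
    # Reuse lax's own shape rule instead of re-deriving conv output sizes.
    # Note: this calls lax directly, never the patched user-facing function.
    out = jax.eval_shape(
        lambda a, b: lax.conv_general_dilated(
            a, b, window_strides=strides, padding=padding,
            dimension_numbers=("NHWC", "HWIO", "NHWC"),
        ),
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        jax.ShapeDtypeStruct(w.shape, w.dtype),
    )
    return ShapedArray(out.shape, out.dtype)
```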
Lower

    def lower(ctx, eqn):
        x = ctx.get_value_for_var(eqn.invars[0])
        # ... emit IR nodes through ctx.builder that compute y from x ...
        ctx.bind_value_for_var(eqn.outvars[0], y)
That’s the whole contract.
Patch activation window opened too late. If activation doesn't wrap tracing, the jaxpr will never contain the plugin's primitive names. The core still doesn't special-case anything; you just see "no plugin for primitive 'foo'". Fix = activate around `make_jaxpr`.
Plugin forgets to bind the output. Then the core’s generic guardrail catches it and fails the build at the exact primitive, without central knowledge of op names.
Multiple symbols for the same high-level op. Plugins add multiple patch specs. The core applies them all — still no names or allow-lists in the core.
An unfinished plugin gets imported. If it also patches a runtime path it can trip tracing/lowering. Fix in the plugin: either complete it or don’t patch until ready. The core does not and should not maintain an allow/deny list.
Inversion of control. The only dynamic choice the core makes is:
“Given eqn.primitive.name
(a string), ask the registry for a handler.”
There is zero knowledge of concrete ops or frameworks.
Uniform contracts. Every plugin implements the same three hooks. The core only provides generic services (var→value map, name generator, constant creation, and a place to put nodes).
No central policy on which plugins are ‘on’. Activation applies whatever plugins declare. If a plugin shouldn’t change tracing yet, it shouldn’t publish a monkey-patch — that decision is local to the plugin, not the core.
Core = small, generic: discover, activate, trace, loop eqns, call `plugin.lower`, assert outputs bound, run IR optimizer, finalize ONNX.
No plugin names. Ever.
Plugins = specific: declare patches (all aliases), implement `abstract_eval`, implement `lower` (bind outvars), own all op semantics.
Context = minimal API for plugins: `get_value_for_var`, `bind_value_for_var`, `fresh_name`, plus a couple of IR conveniences; no framework knowledge.
Optimizer = tiny, safe, and IR-only: fold layout ping-pongs across pure elementwise ops; match by name or object; never mutate `Node.inputs` directly; use `replace_input_with`.