jax2onnx

Big idea

The converter is a tiny, generic JAXPR → IR engine. It knows nothing about NNX, Conv, Pool, or any specific op. Its only job is:

  1. Discover plugins (inversion of control),
  2. Activate whatever they declare (monkey-patching to produce crisp primitives),
  3. Trace your function to a ClosedJaxpr,
  4. Lower each equation by handing it to a plugin that claimed that primitive,
  5. Assemble an IR graph,
  6. Optimize the IR graph with a small, safe, plugin-agnostic pass,
  7. Finalize a valid ONNX model (stamp shapes/dtypes, prune, serialize).
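The middle of that pipeline (trace, dispatch, assemble) can be sketched as a single loop. This is a toy stand-in, not the project's actual code: `Ctx`, `lower_add`, and `convert` are illustrative names, but the shape of the loop mirrors steps 3–5, and the dispatch really does key on nothing but the primitive-name string.

```python
# Minimal sketch of the core dispatch loop; Ctx and the registry contents
# are toy stand-ins, not jax2onnx's real classes.
import jax
import jax.numpy as jnp

class Ctx:
    def __init__(self):
        self.env, self.nodes = {}, []  # var -> value map, emitted IR nodes

def lower_add(ctx, eqn):
    # A toy "plugin": emits one IR node and binds every equation outvar.
    ctx.nodes.append(("Add", [str(v) for v in eqn.invars]))
    for ov in eqn.outvars:
        ctx.env[ov] = f"t{len(ctx.env)}"

REGISTRY = {"add": lower_add}  # primitive-name string -> plugin

def convert(fn, *specs):
    closed = jax.make_jaxpr(fn)(*specs)        # 3. trace to ClosedJaxpr
    ctx = Ctx()
    for eqn in closed.jaxpr.eqns:              # 4. lower each equation
        plugin = REGISTRY[eqn.primitive.name]  #    dispatch by name only
        plugin(ctx, eqn)
        assert all(v in ctx.env for v in eqn.outvars)
    return ctx                                 # 5. the assembled "IR graph"

ctx = convert(lambda x: x + x, jnp.ones((2,)))
```

Note the core never inspects what the plugin emitted; it only asserts that every outvar ended up bound.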

Everything op-specific — layouts, padding math, attribute shapes, NHWC↔NCHW, etc. — stays in plugins.


Related documentation


Roles & responsibilities

Core (plugin-agnostic)

Plugin (op-specific)

Each plugin describes one primitive (or one high-level function). It has three standard pieces: a patch declaration (binding_specs()), an abstract_eval, and a lower implementation — see the lifecycle section below.

That’s it. The contract is tiny and uniform across all primitives.

Plugin guardrails (must-follow)

To keep conversions portable across onnx_ir variants, every plugin lowering must observe these project-wide rules:


Data flow end-to-end

User fn + shape specs
       │
       ▼
[Activation Context]  ←— plugins declare patches; core applies them (no names)
       │
       ▼
ClosedJaxpr = make_jaxpr(fn)(*specs)
       │
       ├── constvars  → ctx.bind_const_for_var(...)
       ├── invars     → ctx.add_input_for_invar(...)
       │
       └── eqns:  e₀, e₁, …, eₙ
                  │
                  ├─ core reads eᵢ.primitive.name (string)
                  ├─ plugin = REGISTRY[name]
                  ├─ plugin.lower(ctx, eᵢ)   ← emits IR nodes
                  └─ core asserts eᵢ.outvars all bound
       │
       └── outvars    → ctx.add_outputs_from_vars(...)
       │
       ▼
IR graph → IR optimizer → stamp shapes/symbols → prune → ONNX ModelProto

No step above references “Conv”, “Tanh”, or any specific op in the core. All knowledge sits behind the primitive name string chosen by the plugin.


The lowering context (what plugins see)

A small, stable API:

Output binding pattern (must follow)

When lowering, always reuse the IRValue pre-allocated for each equation outvar. The canonical flow is:

  1. Fetch inputs via ctx.get_value_for_var(eqn.invars[i]).
  2. Pre-allocate outputs with ctx.get_value_for_var(eqn.outvars[i]).
  3. Emit nodes whose outputs=[...] point to those pre-allocated values (write final results directly into them).
  4. Optionally stamp dtype/shape metadata on the same values.

Avoid producing temporary outputs and “attaching” them afterwards; that pattern bypasses the var→value map and leads to orphaned tensors. Keeping the contract tight here means downstream equations always receive the correct tensor without extra bookkeeping.
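The four steps above can be made concrete with a toy context. Value, Ctx, and add_node are illustrative stand-ins (only get_value_for_var is named by the doc itself); the point is that step 2's pre-allocation and step 3's direct write go through the same var→value map.

```python
# Toy context illustrating the output binding pattern; Value, Ctx, and
# add_node are stand-ins, not jax2onnx's real API.
from types import SimpleNamespace

class Value:
    def __init__(self, name):
        self.name, self.dtype, self.shape = name, None, None

class Ctx:
    def __init__(self):
        self._vals, self.nodes = {}, []
    def get_value_for_var(self, var):
        # One Value per jaxpr var, pre-allocated on first request and
        # memoized so every later lookup returns the same object.
        return self._vals.setdefault(var, Value(f"v{len(self._vals)}"))
    def add_node(self, op, inputs, outputs):
        self.nodes.append((op, inputs, outputs))

def lower(ctx, eqn):
    x = ctx.get_value_for_var(eqn.invars[0])   # 1. fetch input
    y = ctx.get_value_for_var(eqn.outvars[0])  # 2. pre-allocate output
    ctx.add_node("Tanh", [x.name], [y.name])   # 3. write directly into y
    y.dtype, y.shape = "float32", (2, 3)       # 4. stamp metadata

ctx = Ctx()
lower(ctx, SimpleNamespace(invars=["a"], outvars=["b"]))
```

Because the downstream equation will call get_value_for_var on the same outvar, it receives exactly the Value the node wrote into — no re-attachment step, no orphaned tensor.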

Builder vs. Tape

Every lowering runs against the same IR backend, but plugins are expected to work through ctx.builder whenever possible:

Functions (converter-owned call boundaries)

Decorators such as @onnx_function register a plugin that lowers the call into a FunctionScope. At runtime the handler:

  1. Builds a FunctionKey from the qualified target name, input aval signature, and capture signature (class instance id/config).
  2. Reuses an existing definition if the key is cached; otherwise it opens a FunctionScope, maps parent inputs to fresh function inputs, and recursively lowers the original callable inside the child IRContext (constants are emitted as Constant nodes because FunctionProto bodies cannot own initializers).
  3. Seals the scope into an onnx_ir.Function, records any attribute overrides, and caches the result.
  4. Emits a call-site node in the parent graph with domain/op_type matching the function and wires the original in/out values to that node.

Because the definition is keyed on avals and capture signature, identical calls share a single function body, while shape/config changes trigger new entries.


Shapes & symbolic dims


Determinism & graph hygiene


Graph structure specs


IR optimizer (plugin-agnostic)

Before serialization we run a tiny, structure-only optimization sweep. The canonical rules and implementation notes live in docs/dev_guides/ir_optimizer.md. Today the only pass folds redundant Transpose → [elementwise]* → Transpose pairs when their permutations compose to identity; future passes must follow the same IR-only, backend-agnostic constraints.
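The legality condition for that fold is pure permutation algebra: Transpose(q) applied after Transpose(p) equals one Transpose whose perm is r[i] = p[q[i]], and the pair is redundant exactly when r is the identity. A minimal check (helper names are ours, not the optimizer's):

```python
def compose(p, q):
    # Perm of Transpose(q) applied after Transpose(p): r[i] = p[q[i]],
    # matching numpy's a.transpose(p).transpose(q) == a.transpose(r).
    return tuple(p[i] for i in q)

def folds_to_identity(p, q):
    return compose(p, q) == tuple(range(len(p)))

# The classic layout ping-pong: NHWC -> NCHW and back cancel out.
nhwc_to_nchw = (0, 3, 1, 2)
nchw_to_nhwc = (0, 2, 3, 1)
```

Elementwise ops in between are transparent to the fold because they commute with any axis permutation.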


Testing expectations (how “exact” works)

See docs/dev_guides/expect_graph_reference.md for a focused reference on writing expect_graph checks (shapes, counts, symbols, and helper flags).


Typical plugin lifecycle (concrete but generic)

  1. Register: @register_primitive(jaxpr_primitive="…") puts an instance in the registry under that string key. The core later matches that key against eqn.primitive.name.

  2. Patch: binding_specs() returns MonkeyPatchSpecs — “replace module.symbol with a prim.bind(…) shim”. If there are multiple aliases, the plugin lists them all; the core just applies them.

  3. Abstract eval: def abstract_eval(*avals, **params) returns a ShapedArray (or tuple of them) describing the outputs. Use jax.eval_shape on lax.* helpers if that’s easier, but never call the patched function (it would recurse).

  4. Lower: def lower(ctx, eqn):

    • x = ctx.get_value_for_var(eqn.invars[0])
    • create IR nodes (e.g., Transpose, Conv, CastLike, …)
    • ctx.bind_value_for_var(eqn.outvars[0], y)

That’s the whole contract.
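Put together, a plugin is just a class with those pieces. A skeleton under assumed internals — register_primitive, binding_specs, abstract_eval, and lower are the names the lifecycle above uses, but the decorator body here is a toy stand-in:

```python
# Skeleton of the plugin contract; the registry/decorator internals are
# toy stand-ins for jax2onnx's real machinery.
REGISTRY = {}

def register_primitive(jaxpr_primitive):
    # Files the plugin instance under the primitive-name string that the
    # core will later read from eqn.primitive.name.
    def deco(cls):
        REGISTRY[jaxpr_primitive] = cls()
        return cls
    return deco

@register_primitive(jaxpr_primitive="tanh")
class TanhPlugin:
    def binding_specs(self):
        # "replace module.symbol with a prim.bind(...) shim" (all aliases).
        return []

    def abstract_eval(self, x, **params):
        # Tanh is shape/dtype preserving, so the input aval passes through.
        return x

    def lower(self, ctx, eqn):
        x = ctx.get_value_for_var(eqn.invars[0])
        y = ctx.get_value_for_var(eqn.outvars[0])
        ctx.add_node("Tanh", inputs=[x], outputs=[y])
```

Nothing else is required: the core discovers the class, applies its patches, and calls lower once per matching equation.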


Failure modes & how the architecture contains them


Architectural Guarantees


TL;DR blueprint (for maintainers)

  1. Core = small, generic: discover, activate, trace, loop eqns, call plugin.lower, assert outputs bound, run IR optimizer, finalize ONNX. No plugin names. Ever.

  2. Plugins = specific: declare patches (all aliases), implement abstract_eval, implement lower (bind outvars), own all op semantics.

  3. Context = minimal API for plugins: get_value_for_var, bind_value_for_var, fresh_name, plus a couple IR conveniences; no framework knowledge.

  4. Optimizer = tiny, safe, and IR-only: fold layout ping-pongs across pure elementwise ops; match by name or object; never mutate Node.inputs directly — use replace_input_with.