Plugin System
The plugin system is the extension point that keeps primitive-specific lowering
out of the converter core. A plugin teaches jax2onnx how to recognize a JAX
primitive or high-level callable during tracing, and how to lower the resulting
JAXPR equation into onnx_ir.
Read this page together with:
- Architecture Overview for the core/plugin boundary.
- ONNX IR Builder Guide for builder and metadata rules.
- Expect Graph Reference for structural testcase checks.
- ONNX Functions for user-facing @onnx_function behavior.
Responsibilities
| Layer | Owns |
|---|---|
| conversion_api | Plugin discovery/activation, JAXPR tracing, equation dispatch, value binding checks, model finalization. |
| PrimitiveLeafPlugin | Primitive metadata, optional tracing-time bindings, abstract eval, and ONNX IR lowering. |
| FunctionPlugin | @onnx_function registration, function-body tracing, FunctionScope lowering, call-site emission. |
| register_example | Example/test metadata that does not introduce a new primitive. |
| Generated tests/docs | Numeric checks, structural checks, component coverage, examples, and ONNX operator coverage derived from metadata. |
The converter dispatches by eqn.primitive.name. It does not choose how a
primitive maps to ONNX. That mapping belongs in the plugin.
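The split between dispatch and mapping can be pictured with a small stand-alone model. This is a sketch under stated assumptions, not the jax2onnx implementation: ToyPrimitive, ToyEqn, TOY_REGISTRY, and dispatch are hypothetical names, and a plain dict stands in for the real registry.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class ToyPrimitive:
    name: str


@dataclass
class ToyEqn:
    primitive: ToyPrimitive


# Hypothetical registry: primitive name -> lowering callable.
TOY_REGISTRY: Dict[str, Callable[[ToyEqn, List[str]], None]] = {}


def lower_abs(eqn: ToyEqn, emitted: List[str]) -> None:
    # The plugin, not the dispatcher, decides which ONNX op is emitted.
    emitted.append("Abs")


TOY_REGISTRY["abs"] = lower_abs


def dispatch(eqns: List[ToyEqn]) -> List[str]:
    """Look up each equation by primitive name and delegate to its plugin."""
    emitted: List[str] = []
    for eqn in eqns:
        plugin = TOY_REGISTRY.get(eqn.primitive.name)
        if plugin is None:
            raise NotImplementedError(
                f"No plugins registered for primitive {eqn.primitive.name!r}"
            )
        plugin(eqn, emitted)
    return emitted


ops = dispatch([ToyEqn(ToyPrimitive("abs"))])
```

The dispatcher never mentions Abs; swapping the lowering means editing only the registered callable.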
Registry and Activation
Plugins live under jax2onnx/plugins/. import_all_plugins() recursively
imports that tree once per process. Imports are important: registration happens
as a side effect of importing plugin modules.
There are three main registries:
| Registry | Key | Value |
|---|---|---|
| PLUGIN_REGISTRY | JAXPR primitive name string | PrimitiveLeafPlugin or FunctionPlugin instance |
| ONNX_FUNCTION_PLUGIN_REGISTRY | qualified Python target name | FunctionPlugin reference |
| EXAMPLE_REGISTRY | {context}::{component} | example metadata from register_example(...) |
During conversion, _activate_plugin_worlds() does this:
- import all plugins;
- apply registered function patches through apply_monkey_patches();
- enter every PrimitiveLeafPlugin.plugin_binding() context;
- bind leaf-plugin abstract eval where needed;
- backfill allowlisted transpose rules for custom primitives;
- trace with jax.make_jaxpr;
- restore all patches when tracing exits.
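The apply-then-restore part of this sequence follows a familiar context-manager pattern. The sketch below is an illustration only: scoped_patches and the (target, attr, make_value) tuple layout are assumptions made for this example, and it does not reproduce the internals of _activate_plugin_worlds() or apply_monkey_patches().

```python
import contextlib
import types
from typing import Any, Callable, Iterator, List, Tuple

# Hypothetical patch description: (target object, attribute name, wrapper factory).
PatchSpec = Tuple[Any, str, Callable[[Any], Any]]


@contextlib.contextmanager
def scoped_patches(specs: List[PatchSpec]) -> Iterator[None]:
    """Apply every patch on entry and restore the originals on exit."""
    saved: List[Tuple[Any, str, Any]] = []
    try:
        for target, attr, make_value in specs:
            original = getattr(target, attr)
            setattr(target, attr, make_value(original))
            saved.append((target, attr, original))
        yield
    finally:
        # Restore in reverse order, even if tracing raised.
        for target, attr, original in reversed(saved):
            setattr(target, attr, original)


api = types.SimpleNamespace(double=lambda x: 2 * x)
calls: List[int] = []


def make_traced(original: Callable[[int], int]) -> Callable[[int], int]:
    def traced(x: int) -> int:
        calls.append(x)  # stands in for "emit a primitive during tracing"
        return original(x)

    return traced


original_double = api.double
with scoped_patches([(api, "double", make_traced)]):
    inside = api.double(3)
after = api.double(4)
```

Only the call made inside the with block is recorded; the call afterwards reaches the untouched original.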
Nested ONNX Function body tracing uses a matching activation path so function bodies see the same primitive bindings as the top-level trace.
Primitive Plugins
Most primitive support is implemented as a subclass of PrimitiveLeafPlugin
registered with @register_primitive(...).
from typing import Any

import jax
import onnx_ir as ir
from jax import core

from jax2onnx.converter.typing_support import LoweringContextProtocol
from jax2onnx.plugins._post_check_onnx_graph import expect_graph as EG
from jax2onnx.plugins._ir_shapes import _ensure_value_metadata, _stamp_type_and_shape
from jax2onnx.plugins.plugin_system import PrimitiveLeafPlugin, register_primitive

JaxprEqn = getattr(core, "JaxprEqn", Any)


@register_primitive(
    jaxpr_primitive=jax.lax.abs_p.name,
    jax_doc="https://docs.jax.dev/en/latest/_autosummary/jax.lax.abs.html",
    onnx=[{"component": "Abs", "doc": "https://onnx.ai/onnx/operators/onnx__Abs.html"}],
    since="0.5.0",
    context="primitives.lax",
    component="abs",
    testcases=[
        {
            "testcase": "abs",
            "callable": lambda x: jax.lax.abs(x),
            "input_shapes": [(3,)],
            "post_check_onnx_graph": EG(["Abs:3"], no_unused_inputs=True),
        }
    ],
)
class AbsPlugin(PrimitiveLeafPlugin):
    def lower(self, ctx: LoweringContextProtocol, eqn: JaxprEqn) -> None:
        x_val = ctx.get_value_for_var(eqn.invars[0])
        out_val = ctx.get_value_for_var(eqn.outvars[0], name_hint=ctx.fresh_name("abs"))
        result = ctx.builder.Abs(x_val, _outputs=[out_val.name or ctx.fresh_name("abs")])
        result.type = out_val.type
        result.shape = out_val.shape
        _stamp_type_and_shape(result, tuple(getattr(eqn.outvars[0].aval, "shape", ())))
        _ensure_value_metadata(ctx, result)
        ctx.bind_value_for_var(eqn.outvars[0], result)
This is the common shape. More complex plugins may also define a custom JAX
primitive (_PRIM), abstract_eval, batching rules, JVP/transpose rules, and
binding specs.
Registering Metadata
@register_primitive(...) stores every keyword argument as plugin metadata and
registers the plugin instance under jaxpr_primitive.
Important top-level metadata fields:
| Field | Purpose |
|---|---|
| jaxpr_primitive | Registry key. Must match eqn.primitive.name. |
| jax_doc | Source documentation for the JAX/Flax API. Used by generated docs. |
| onnx | List of ONNX components this lowering emits or depends on. |
| since | First jax2onnx version that supports the component. |
| context | Grouping namespace such as primitives.lax or primitives.nnx. |
| component | Stable component name used for generated tests and docs. |
| description | Optional human-readable summary. |
| children | For examples, child components used by the example. |
| testcases | Generated test inputs, expectations, and structural checks. |
Avoid duplicate jaxpr_primitive keys. The registry is a dictionary, so a later
import can overwrite an earlier entry; that is not a supported extension model.
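The overwrite hazard is easy to reproduce with any decorator that writes into a dict. The following sketch is a stand-in, not the real register_primitive: sketch_register and SKETCH_REGISTRY are hypothetical names, and the metadata values are made up for illustration.

```python
from typing import Any, Callable, Dict

# Hypothetical stand-in for a primitive-name -> plugin-instance registry.
SKETCH_REGISTRY: Dict[str, Any] = {}


def sketch_register(**metadata: Any) -> Callable[[type], type]:
    """Store all keyword arguments as metadata and register an instance."""

    def decorator(cls: type) -> type:
        instance = cls()
        instance.metadata = dict(metadata)
        # Plain dict assignment: a second registration silently replaces the first.
        SKETCH_REGISTRY[metadata["jaxpr_primitive"]] = instance
        return cls

    return decorator


@sketch_register(jaxpr_primitive="abs", component="abs")
class FirstAbs:
    pass


@sketch_register(jaxpr_primitive="abs", component="abs_again")
class SecondAbs:
    pass


winner = SKETCH_REGISTRY["abs"]
```

Whichever class is defined last wins, so the outcome depends on import order. That is why duplicate keys are best avoided rather than relied on.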
Binding Specs
A plugin only needs binding specs when ordinary JAX tracing would not emit the
primitive you want to lower. Direct jax.lax primitives often do not need
patching. High-level APIs such as jax.numpy.*, Flax NNX modules, or custom
wrappers often do.
New primitive plugins should implement:
from jax2onnx.plugins._patching import AssignSpec, MonkeyPatchSpec


class SomePlugin(PrimitiveLeafPlugin):
    @classmethod
    def binding_specs(cls) -> list[AssignSpec | MonkeyPatchSpec]:
        return [
            AssignSpec("jax.numpy", "some_prim", cls._PRIM, delete_if_missing=True),
            MonkeyPatchSpec(
                target="jax.numpy",
                attr="some_function",
                make_value=lambda original: make_wrapper(original, cls._PRIM),
            ),
        ]
Use AssignSpec to install or expose a primitive object. Use MonkeyPatchSpec
to wrap an existing function or method so tracing calls cls._PRIM.bind(...).
Patch rules:
- Patches must be scoped. They are applied only while tracing and restored on exit.
- The wrapper must call the original implementation for unsupported modes, not silently lower an unsupported case.
- Do not seed, allocate modules, or capture mutable state inside the patch wrapper.
patch_info() exists for older compatibility paths. Prefer binding_specs() for new primitive plugins.
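The make_wrapper helper used in the binding example is not defined on this page. One possible shape for it, showing the "call the original for unsupported modes" rule, is sketched below. ToyPrim is a stand-in that records bind calls instead of tracing, and the mode keyword is an invented example of an unsupported option.

```python
from typing import Any, Callable, List, Tuple


class ToyPrim:
    """Stand-in for a JAX primitive; records bind calls instead of tracing."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.bound: List[Tuple[tuple, dict]] = []

    def bind(self, *args: Any, **params: Any) -> Tuple[str, str, tuple]:
        self.bound.append((args, params))
        return ("bound", self.name, args)


def make_wrapper(original: Callable[..., Any], prim: ToyPrim) -> Callable[..., Any]:
    """Route the supported call shape to prim.bind; defer everything else."""

    def wrapper(x: Any, *, mode: str = "default") -> Any:
        if mode != "default":
            # Unsupported mode: fall back to the original implementation
            # instead of lowering a case the plugin does not handle.
            return original(x, mode=mode)
        return prim.bind(x)

    return wrapper


def original_fn(x: Any, *, mode: str = "default") -> Tuple[str, Any, str]:
    return ("original", x, mode)


prim = ToyPrim("some_prim")
patched = make_wrapper(original_fn, prim)
supported = patched(1)
unsupported = patched(1, mode="exotic")
```

The wrapper keeps no state of its own; everything it needs arrives through its arguments or the closed-over original and primitive.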
Abstract Eval, Batching, and Autodiff
Custom primitives must be usable by JAX before the converter can lower them.
At minimum, a primitive that is emitted during tracing needs abstract eval.
If the primitive can appear under vmap, it also needs a batching rule.
If tests or examples cover gradients, JVP/transpose support must be registered.
Guidelines:
- Define abstract_eval(*avals, **params) on the plugin when the output shape or dtype can be inferred from abstract values.
- Do not call the patched public function from abstract_eval; that can recurse back into the same primitive.
- Register batching rules explicitly for custom primitives used under vmap.
- Use helpers from jax2onnx/plugins/jax/_autodiff_utils.py for JVP and transpose rules instead of writing directly into JAX AD registries.
- Conversion-time AD backfill is allowlisted. Use JAX2ONNX_AD_DEBUG=1 when debugging registration decisions, and JAX2ONNX_DISABLE_AD_BACKFILL=1 to disable backfill.
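An abstract eval built from pure shape logic avoids the recursion risk because it never calls back into a public API. The sketch below uses a stand-in ToyAval dataclass rather than JAX's own abstract value types, and the reduce-sum primitive is a hypothetical example chosen only because its output shape is easy to derive.

```python
from dataclasses import dataclass
from typing import Tuple, Union

Dim = Union[int, str]  # a str models a symbolic dimension such as "B"


@dataclass(frozen=True)
class ToyAval:
    shape: Tuple[Dim, ...]
    dtype: str


def reduce_sum_abstract_eval(
    x: ToyAval, *, axis: int, keepdims: bool = False
) -> ToyAval:
    """Infer the output aval from the input aval and static params only."""
    rank = len(x.shape)
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    axis = axis % rank  # normalize negative axes
    if keepdims:
        shape = x.shape[:axis] + (1,) + x.shape[axis + 1 :]
    else:
        shape = x.shape[:axis] + x.shape[axis + 1 :]
    return ToyAval(shape, x.dtype)


x = ToyAval(("B", 4, 8), "float32")
dropped = reduce_sum_abstract_eval(x, axis=-1)
kept = reduce_sum_abstract_eval(x, axis=1, keepdims=True)
```

The symbolic dimension passes through untouched, which is what a dynamic-shape testcase needs from abstract eval.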
Lowering Contract
The lowering receives a LoweringContextProtocol and a JAXPR equation. The
contract is strict:
- all non-literal inputs must already be bound before lower(...) runs;
- the plugin emits connected onnx_ir nodes and values;
- every non-drop output var must be bound before the equation completes;
- returned ir.Value objects may be bound generically, but explicit binding is clearer for non-trivial lowerings.
Core context methods used most often:
| API | Use |
|---|---|
| ctx.get_value_for_var(var, ...) | Fetch or allocate the IR value for a JAX var. Literals can become constants. |
| ctx.require_value_for_var(var, ...) | Require an existing binding; useful when allocation would hide a bug. |
| ctx.bind_value_for_var(var, value) | Bind an equation output var to a produced ir.Value. |
| ctx.bind_const_for_var(var, array) | Bind a JAXPR constvar. Top-level graphs use initializers; Function bodies use Constant nodes. |
| ctx.fresh_name(base) | Create deterministic per-context names. |
| ctx.builder | Emit ONNX IR nodes through the project builder wrapper. |
| ctx.try_evaluate_const(var) | Evaluate a known constant producer when a lowering needs static data. |
| ctx.ensure_external_flag(name, var) | Route dynamic call parameters into graph/function inputs where needed. |
Lowering guardrails:
- Emit normal nodes through ctx.builder.
- Pass _outputs=[...], never _outputs="name".
- Stamp or preserve dtype/shape metadata on produced values.
- Reuse pre-allocated output values when identity matters.
- Keep converter/plugins IR-only; do not import ONNX protobuf types in lowering code.
- Use shared shape/constant helpers when available instead of ad hoc tensor construction.
The core validates both sides of the equation. assert_eqn_inputs_bound catches
unbound inputs before the plugin runs. finalize_eqn_lowering_outputs catches
missing or disconnected outputs after it runs.
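The two-sided check can be modeled in a few lines. This is a simplified sketch and not the code behind assert_eqn_inputs_bound or finalize_eqn_lowering_outputs: ToyContext, lower_with_checks, and the string-based vars are hypothetical, and the connectivity check for disconnected values is left out.

```python
from typing import Callable, Dict, Sequence

DROP = "_"  # models a dropped output var that needs no binding


class ToyContext:
    def __init__(self, bound_inputs: Sequence[str]) -> None:
        self.bindings: Dict[str, str] = {
            name: f"value:{name}" for name in bound_inputs
        }

    def bind(self, var: str, value: str) -> None:
        self.bindings[var] = value


def lower_with_checks(
    ctx: ToyContext,
    invars: Sequence[str],
    outvars: Sequence[str],
    lower: Callable[[ToyContext], None],
) -> None:
    """Validate inputs, run the plugin lowering, then validate outputs."""
    missing = [v for v in invars if v not in ctx.bindings]
    if missing:
        raise RuntimeError(f"unbound input before lowering: {missing}")
    lower(ctx)
    unbound = [v for v in outvars if v != DROP and v not in ctx.bindings]
    if unbound:
        raise RuntimeError(f"did not bind output: {unbound}")


ctx = ToyContext(["x"])
lower_with_checks(ctx, ["x"], ["y", DROP], lambda c: c.bind("y", "Abs(x)"))
```

A lowering that forgets to bind y fails at the second check rather than producing a broken graph later.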
Function Plugins and @onnx_function
@onnx_function registers a FunctionPlugin under the primitive name
onnx_fn::<qualified-target-name>. It is user-facing, but its implementation is
part of the plugin system.
During tracing, the function plugin patches the target callable or class
__call__ so a function primitive appears in JAXPR. During lowering, it:
- resolves the original callable or instance;
- builds a FunctionKey from target name, input avals, and capture signature;
- reuses a cached body when the key matches;
- otherwise creates a FunctionScope child IRContext;
- maps parent inputs and dynamic call parameters to function inputs;
- retraces and lowers the callable body inside the child context;
- emits constants as Constant nodes because ONNX Function bodies cannot own initializers;
- attaches the resulting onnx_ir.Function to the parent model;
- emits a call-site node in the parent graph.
unique=True changes the capture signature from object identity to a stable
fingerprint of the callable/function state, so equivalent blocks can share one
Function body while different state produces separate bodies. Namespace and type
overrides are validated at registration time to avoid registry collisions.
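The effect of the capture signature on body reuse can be seen in a toy cache. Everything here is an assumption made for illustration: SketchFunctionKey, capture_signature, and the vars()-based fingerprint are hypothetical and only mirror the described behavior, not the actual FunctionKey implementation.

```python
from dataclasses import dataclass
from typing import Any, Dict, Hashable, Tuple

Aval = Tuple[Tuple[int, ...], str]  # (shape, dtype) stand-in for an abstract value


@dataclass(frozen=True)
class SketchFunctionKey:
    target: str
    input_avals: Tuple[Aval, ...]
    capture: Hashable


class Block:
    def __init__(self, features: int) -> None:
        self.features = features


def capture_signature(obj: Any, unique: bool) -> Hashable:
    if unique:
        # Fingerprint of the state: equal state gives an equal signature.
        return ("state", tuple(sorted(vars(obj).items())))
    # Object identity: every live instance gets its own signature.
    return ("id", id(obj))


BODY_CACHE: Dict[SketchFunctionKey, str] = {}


def get_or_build_body(
    target: str, avals: Tuple[Aval, ...], obj: Any, unique: bool
) -> str:
    key = SketchFunctionKey(target, avals, capture_signature(obj, unique))
    if key not in BODY_CACHE:
        # Placeholder for "trace and lower the body in a child context".
        BODY_CACHE[key] = f"{target}_{len(BODY_CACHE)}"
    return BODY_CACHE[key]


avals = (((1, 8), "float32"),)
a, b, c = Block(8), Block(8), Block(16)
shared = get_or_build_body("Block", avals, a, unique=True)
also_shared = get_or_build_body("Block", avals, b, unique=True)
separate = get_or_build_body("Block", avals, c, unique=True)
by_identity_a = get_or_build_body("Block", avals, a, unique=False)
by_identity_b = get_or_build_body("Block", avals, b, unique=False)
```

With the fingerprint, a and b share one body and c gets its own; with identity, a and b each get a separate body despite holding equal state.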
Keep user-facing usage details in ONNX Functions.
Keep implementation details here and in tests under tests/extra_tests/converter/.
Example Metadata
Use register_example(...) when you want generated tests and docs for a
composition of existing primitives, without introducing a new primitive.
from jax2onnx.plugins.plugin_system import (
    construct_and_call,
    register_example,
    with_rng_seed,
)

register_example(
    component="CNN",
    description="A simple convolutional neural network.",
    context="examples.nnx",
    children=["nnx.Conv", "nnx.Linear", "nnx.avg_pool", "nnx.relu"],
    since="0.2.0",
    testcases=[
        {
            "testcase": "simple_cnn",
            "callable": construct_and_call(CNN, rngs=with_rng_seed(0)),
            "input_shapes": [("B", 28, 28, 1)],
            "expected_output_shapes": [("B", 10)],
        }
    ],
)
Examples are stored in EXAMPLE_REGISTRY, then consumed by the same generated
test and documentation tooling as primitive metadata.
Testcase Metadata
Generated tests load all plugin and example metadata, expand dynamic/concrete
variants, expand f32/f64 variants, convert via public to_onnx(...), run
optional structural checks, and run numeric validation when concrete
input_values are available.
Common testcase fields:
| Field | Purpose |
|---|---|
| testcase | Stable generated test/model name. |
| callable | Function or construct_and_call(...) object. |
| input_shapes | Shape specs; "B" triggers dynamic plus concrete variants unless restricted. |
| input_dtypes | Dtypes paired with input_shapes. |
| input_values | Concrete values used for tracing specs and numeric validation. |
| input_params | Keyword parameters supplied to the callable and optionally materialized as graph inputs. |
| expected_output_shapes | Output shape assertions after export. |
| expected_output_dtypes | Output dtype assertions after export. |
| post_check_onnx_graph | expect_graph(...) structural predicate. |
| run_only_dynamic | Skip the concrete variant for symbolic-shape tests. |
| run_only_f32_variant | Skip generated f64 variant. |
| run_only_f64_variant | Generate only the f64 variant. |
| disable_float64_test | Keep f32 only for tests that cannot safely run in f64. |
| opset_version | Override the default opset for this testcase. |
| inputs_as_nchw / outputs_as_nchw | Request public IO layout bridges. |
| input_names / output_names | Public graph IO name overrides. |
| skip_numeric_validation | Keep export/structure checks but skip runtime allclose. |
| check_onnx_load | Run ONNX checker after serialization. |
| expected_number_of_function_instances | Assert expected FunctionProto count for ONNX Function tests. |
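One way to read the variant-related flags in the table is as filters over a small cross product of shape modes and dtypes. The sketch below is an interpretation, not the test generator itself: expand_variants is a hypothetical function, the variant labels are invented, and the precedence between conflicting flags is an assumption.

```python
from typing import Any, Dict, List, Tuple


def expand_variants(testcase: Dict[str, Any]) -> List[Tuple[str, str]]:
    """Return (shape_mode, dtype) pairs implied by the testcase flags."""
    has_symbolic = any(
        isinstance(dim, str)
        for shape in testcase.get("input_shapes", [])
        for dim in shape
    )
    if not has_symbolic:
        shape_modes = ["concrete"]
    elif testcase.get("run_only_dynamic"):
        shape_modes = ["dynamic"]
    else:
        shape_modes = ["dynamic", "concrete"]

    # Assumed precedence: an explicit f64-only request wins over f32-only flags.
    if testcase.get("run_only_f64_variant"):
        dtypes = ["f64"]
    elif testcase.get("run_only_f32_variant") or testcase.get("disable_float64_test"):
        dtypes = ["f32"]
    else:
        dtypes = ["f32", "f64"]

    return [(mode, dtype) for mode in shape_modes for dtype in dtypes]


full = expand_variants({"input_shapes": [("B", 28, 28, 1)]})
restricted = expand_variants(
    {
        "input_shapes": [("B", 10)],
        "run_only_dynamic": True,
        "run_only_f32_variant": True,
    }
)
```

An unrestricted symbolic testcase yields four variants, while the restricted one yields a single dynamic f32 variant.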
callable_factory is no longer supported. Use construct_and_call(...) so the
test generator can instantiate a fresh callable under the requested dtype.
Deterministic Callable Helpers
Use these helpers from jax2onnx.plugins.plugin_system in metadata:
| Helper | Use |
|---|---|
| construct_and_call(Constructor, *args, **kwargs) | Store a module constructor and instantiate it per generated dtype. |
| with_requested_dtype() | Placeholder argument resolved to the generated dtype. |
| with_rng_seed(seed) | Placeholder that builds fresh nnx.Rngs(seed) per instantiation. |
| with_prng_key(seed) | Placeholder that builds a fresh JAX PRNG key per instantiation. |
Do not treat dtype/RNG placeholders as methods on the factory object; they are constructor arguments resolved later by the generated test harness.
Pass them as arguments to construct_and_call(...), as in construct_and_call(CNN, rngs=with_rng_seed(0)) from the example metadata above.
This keeps module construction deterministic across f32/f64 variants and avoids import-time RNG state.
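The deferred-resolution idea behind these helpers can be modeled without any jax2onnx imports. The names below (DtypePlaceholder, SeedPlaceholder, sketch_construct_and_call, ToyModule) are stand-ins invented for this sketch, and random.Random replaces the real RNG objects; the actual helpers in jax2onnx.plugins.plugin_system may work differently.

```python
import random
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass(frozen=True)
class DtypePlaceholder:
    """Marks an argument to be replaced by the requested dtype."""


@dataclass(frozen=True)
class SeedPlaceholder:
    """Marks an argument to be replaced by a fresh seeded RNG."""

    seed: int


def sketch_construct_and_call(
    constructor: Callable[..., Any], **kwargs: Any
) -> Callable[[str], Any]:
    """Return a factory that builds a fresh instance for each requested dtype."""

    def build(dtype: str) -> Any:
        resolved: Dict[str, Any] = {}
        for name, value in kwargs.items():
            if isinstance(value, DtypePlaceholder):
                resolved[name] = dtype
            elif isinstance(value, SeedPlaceholder):
                # A new RNG per instantiation, never one created at import time.
                resolved[name] = random.Random(value.seed)
            else:
                resolved[name] = value
        return constructor(**resolved)

    return build


class ToyModule:
    def __init__(self, features: int, dtype: str, rng: random.Random) -> None:
        self.dtype = dtype
        self.weights: List[float] = [rng.random() for _ in range(features)]


factory = sketch_construct_and_call(
    ToyModule, features=3, dtype=DtypePlaceholder(), rng=SeedPlaceholder(0)
)
m32 = factory("float32")
m64 = factory("float64")
```

Both instances start from the same seed, so the f32 and f64 variants see identical initial weights while remaining separate objects.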
Generated Coverage
Primitive and example metadata is the source for:
- generated pytest classes under tests/primitives/ and tests/examples/;
- docs/user_guide/supported_components.md;
- docs/user_guide/examples.md;
- docs/user_guide/onnx_operator_coverage.md.
Keep metadata user-oriented. It should describe supported components, examples, ONNX operators, and durable tests. Release plans, temporary audits, and one-off migration notes belong outside generated user documentation.
Review Checklist
Before adding or changing a plugin:
- Verify the JAXPR primitive name.
- Keep primitive semantics in the plugin, not in conversion_api.
- Use binding_specs() only when tracing needs a custom primitive.
- Implement abstract eval, batching, and AD support needed by the testcase.
- Emit IR through ctx.builder with sequence-valued _outputs.
- Preserve dtype/shape metadata on produced values.
- Add focused expect_graph(...) checks for durable structure.
- Use construct_and_call(...) and RNG/dtype placeholders for module examples.
- Run the focused generated pytest target for the component.
- Run builder/policy checks when touching lowering infrastructure.
Failure Modes
| Symptom | Likely cause | Fix |
|---|---|---|
| No plugins registered for primitive ... | Plugin module was not imported, wrong jaxpr_primitive, or binding did not emit the intended primitive. | Check plugin import path, registry key, and binding_specs(). |
| JAX tracing recurses or hangs | abstract_eval or a patch wrapper calls the patched public function again. | Call a lax.* helper, the stored original implementation, or pure shape logic. |
| unbound input before lowering | A nested body/function input was not mapped, or a literal/const was not bound. | Fix body input wiring or use require_value_for_var to expose the missing var. |
| did not bind output | Lowering forgot to bind or return a produced value. | Bind every non-drop outvar or return the exact ir.Value sequence. |
| disconnected value | Lowering bound a placeholder that is not produced by a graph node/input/initializer. | Emit the node with that value as output, or bind the actual produced value. |
| Extra graph inputs remain | Lowering introduced unused inputs or function params. | Add focused structural checks and inspect pruning/signature rules. |
| f64 variant fails unexpectedly | Callable construction or metadata ignores generated dtype. | Use construct_and_call(...) with with_requested_dtype() placeholders. |