This walkthrough shows how to add a new plugin to jax2onnx. There are three common flavours:

| Plugin flavour | Purpose | Canonical example |
|---|---|---|
| Low-level primitive | Wrap an existing JAX primitive such as `jax.lax.abs`. Lowerings generally emit a straight ONNX op. | `jax2onnx/plugins/jax/lax/abs.py` |
| High-level primitive / function | Provide a composed op (e.g. `jax.nn.dot_product_attention`, `MultiHeadAttention`) or a custom `@onnx_function`. Often manages RNG helpers, symbol binding, or multiple ONNX ops. | `jax2onnx/plugins/jax/nn/dot_product_attention.py` |
| Example plugin | Expose an end-to-end regression example for docs/tests; lives under the `examples.*` namespace. | `jax2onnx/plugins/examples/jnp/select.py` |

Whichever flavour you choose, the contract is identical, including an `expect_graph` snippet so structural regressions stay locked down.

The walkthrough below uses a low-level primitive (`abs`) because it is the smallest template. For high-level or example plugins, follow the same steps and refer back to the table above for richer real-world samples.
Start from a plugin that matches the flavour you need:

- `lax/abs.py` for a low-level primitive.
- `jax/nn/dot_product_attention.py` for a larger pattern with RNG helpers and multiple ONNX ops.
- `examples/jnp/select.py` for an example plugin.

Create a new file under `jax2onnx/plugins/<namespace>/...` and rename the class and metadata appropriately.

# jax2onnx/plugins/jax/lax/my_primitive.py
import jax
from jax2onnx.plugins._post_check_onnx_graph import expect_graph as EG
from jax2onnx.plugins.plugin_system import PrimitiveLeafPlugin, register_primitive

Tip: keep the namespace consistent (`primitives.lax`, `primitives.nn`, etc.). It drives the autogenerated docs and test layout.
Fill in the `@register_primitive` decorator:

- `jaxpr_primitive`: the name JAX uses in the traced `ClosedJaxpr`
- `onnx`: docs for the ONNX op(s) the lowering will emit
- `testcases`: at least one entry for the test generator

@register_primitive(
jaxpr_primitive=jax.lax.abs_p.name,
context="primitives.lax",
component="abs",
onnx=[{"component": "Abs", "doc": "https://onnx.ai/..."}],
testcases=[
{
"testcase": "abs",
"callable": lambda x: jax.lax.abs(x),
"input_shapes": [(3,)],
"post_check_onnx_graph": EG(["Abs:3"], no_unused_inputs=True),
},
],
)
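
To double-check the string that belongs in `jaxpr_primitive`, you can trace the callable with `jax.make_jaxpr` and read the primitive names off the resulting `ClosedJaxpr`. A minimal sketch in plain JAX, independent of jax2onnx (`primitive_names` is a hypothetical helper):

```python
import jax
import jax.numpy as jnp


def primitive_names(fn, *example_args):
    """Trace fn and return the primitive name of every top-level equation."""
    closed = jax.make_jaxpr(fn)(*example_args)  # a ClosedJaxpr
    return [eqn.primitive.name for eqn in closed.jaxpr.eqns]


names = primitive_names(lambda x: jax.lax.abs(x), jnp.ones((3,), dtype=jnp.float32))
print(names)  # ['abs']

# The registration uses the same string via the primitive object.
print(jax.lax.abs_p.name)  # abs
```

Only top-level equations are listed; primitives nested inside call primitives (for example ones introduced by `jax.jit`) would need a recursive walk over the inner jaxprs.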

Use `construct_and_call(...).with_requested_dtype(...).with_rng_seed(...)` when the primitive needs deterministic module construction or RNG split helpers. See `jax2onnx/plugins/jax/nn/dot_product_attention.py` for a larger example.

ONNX function plugins now keep the original callable/class name as the node `op_type`. Uniqueness lives in the call-site metadata instead:

- Node names follow `<Callable>_N` (1-indexed) so graphs remain human readable.
- Domains are `custom` for the first instance and `custom.Callable_2` for the second. The pair `(op_type, domain)` stays stable across exports and test runs.
- Function opset imports are added automatically, so ONNX Runtime receives the same domain the call-site advertises; no manual opset bookkeeping is required.

Update structural expectations (`expect_graph`, ORT checks, etc.) to key off the `op_type` when you want to match all instances, and fall back to the full `node.name` only when a specific call-site matters. Older expectations that referenced `Callable_1` continue to pass because the checker strips numeric suffixes when comparing `op_type`.
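
The naming rules above can be modelled in a few lines of plain Python to build intuition. This is a hypothetical sketch, not jax2onnx's implementation: `call_site_metadata` and `strip_numeric_suffix` are illustrative names, and continuing the `custom.<Callable>_N` domain pattern past the second instance is an assumption extrapolated from the two examples given.

```python
import re
from collections import Counter


def call_site_metadata(callable_names):
    """Return (op_type, node_name, domain) per call site. Hypothetical model."""
    seen = Counter()
    sites = []
    for name in callable_names:
        seen[name] += 1
        n = seen[name]  # 1-indexed counter per callable
        node_name = f"{name}_{n}"
        # First instance uses plain "custom"; later ones are qualified
        # (assumed to continue as custom.<Callable>_N).
        domain = "custom" if n == 1 else f"custom.{node_name}"
        # op_type always keeps the original callable/class name.
        sites.append((name, node_name, domain))
    return sites


def strip_numeric_suffix(name):
    """Drop a trailing _<digits>, e.g. 'Callable_1' -> 'Callable'."""
    return re.sub(r"_\d+$", "", name)


sites = call_site_metadata(["Callable", "Callable"])
print(sites)
# [('Callable', 'Callable_1', 'custom'), ('Callable', 'Callable_2', 'custom.Callable_2')]
print(strip_numeric_suffix("Callable_1"))  # Callable
```

Matching on `op_type` catches every instance, while matching on the node name pins a single call site; stripping the suffix is what lets an old `Callable_1` expectation line up with the shared `op_type`.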

In `lower`, fetch inputs and pre-allocated outputs via the lowering context, then emit the ONNX op through `ctx.builder`:

class AbsPlugin(PrimitiveLeafPlugin):
def lower(self, ctx, eqn):
x_var = eqn.invars[0]
out_var = eqn.outvars[0]
x_val = ctx.get_value_for_var(x_var, name_hint=ctx.fresh_name("abs_in"))
out_val = ctx.get_value_for_var(out_var, name_hint=ctx.fresh_name("abs_out"))
result = ctx.builder.Abs(x_val, _outputs=[out_val.name or ctx.fresh_name("abs_out")])
# Stamp metadata if the pre-allocated value already knows its type/shape
if getattr(out_val, "type", None) is not None:
result.type = out_val.type
if getattr(out_val, "shape", None) is not None:
result.shape = out_val.shape
ctx.bind_value_for_var(out_var, result)
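
The `eqn` passed to `lower` is an equation from the traced jaxpr. Inspecting one directly in plain JAX (outside jax2onnx) shows what `eqn.invars[0]` and `eqn.outvars[0]` carry for `abs`; this only illustrates the JAX side, not the lowering context API:

```python
import jax
import jax.numpy as jnp

closed = jax.make_jaxpr(jax.lax.abs)(jnp.ones((3,), dtype=jnp.float32))
eqn = closed.jaxpr.eqns[0]  # the single `abs` equation

# abs is unary: one input variable and one output variable, each with an
# abstract value (aval) that records shape and dtype.
in_aval = eqn.invars[0].aval
out_aval = eqn.outvars[0].aval
print(eqn.primitive.name, in_aval.shape, out_aval.shape, out_aval.dtype)
# abs (3,) (3,) float32
```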

Key guardrails:

- Emit ops through the builder (`ctx.builder.Abs(...)`). Avoid constructing `ir.Node` directly unless you need advanced features.
- `_outputs` must be a sequence (`["name"]`, not `"name"`).
- Bind every `eqn.outvars[i]` using `ctx.bind_value_for_var(...)`.
- Create constants with helpers such as `_const_i64` that route through the builder so initializers stay registered.

The ONNX IR Builder Guide lists every policy enforced by the automated checks.
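
One likely reason a bare string is a trap for `_outputs` is plain Python behaviour: a `str` is itself a sequence, so code that iterates over the argument sees one "name" per character. The hypothetical `as_output_names` guard below (not part of jax2onnx) shows the difference and one way to reject the mistake early:

```python
from collections.abc import Sequence


def as_output_names(outputs):
    """Validate an _outputs-style argument. Hypothetical guard for illustration."""
    if isinstance(outputs, str):
        # A str is a Sequence too; iterating it would yield single characters.
        raise TypeError(f"_outputs must be a sequence of names, got str {outputs!r}")
    if not isinstance(outputs, Sequence):
        raise TypeError("_outputs must be a sequence of names")
    return list(outputs)


print(list("abs_out"))               # ['a', 'b', 's', '_', 'o', 'u', 't']
print(as_output_names(["abs_out"]))  # ['abs_out']
```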

The `post_check_onnx_graph` entry in the testcase calls the structural checker (`expect_graph`). Use `scripts/emit_expect_graph.py` to capture the snippet:

poetry run python scripts/emit_expect_graph.py abs

Paste the output into the testcase metadata, and rerun the command whenever the lowering changes the structure of the emitted graph.
For more involved graphs, consult `docs/dev_guides/expect_graph_reference.md` for matching tips.

At minimum, run:

poetry run python scripts/check_ir_builder_usage.py --diff
poetry run pytest -q tests/primitives/test_lax.py -k abs
poetry run pytest -q
The pre-commit hooks execute the same checks, but running them locally keeps the feedback loop tight.
Include the plugin file and the updated tests. Reference the example you copied from in your PR description so reviewers know the baseline.
For deeper dives:

- `jax2onnx/plugins/jax/lax/abs.py` – full, working example
- `jax2onnx/plugins/jax/nn/dot_product_attention.py` – high-level primitive with RNG + multi-op lowering
- `jax2onnx/plugins/examples/jnp/select.py` – minimal example plugin
- `docs/design.md` – architecture overview and plugin roles

Happy lowering!