jax2onnx

Plugin Quickstart

This walkthrough shows how to add a new plugin to jax2onnx. There are three common flavours:

| Plugin flavour | Purpose | Canonical example |
| --- | --- | --- |
| Low-level primitive | Wrap an existing JAX primitive such as jax.lax.abs. Lowerings generally emit a single, direct ONNX op. | jax2onnx/plugins/jax/lax/abs.py |
| High-level primitive / function | Provide a composed op (e.g. jax.nn.dot_product_attention, MultiHeadAttention) or a custom @onnx_function. These often manage RNG helpers or symbol binding, or emit multiple ONNX ops. | jax2onnx/plugins/jax/nn/dot_product_attention.py |
| Example plugin | Expose an end-to-end regression example for docs and tests; lives under the examples.* namespace. | jax2onnx/plugins/examples/jnp/select.py |

Whichever flavour you choose, the contract is identical:

  1. register metadata so the test generator knows how to rebuild the callable,
  2. implement a lowering that emits ONNX IR via the shared builder helpers, and
  3. add an expect_graph snippet so structural regressions stay locked down.

The walkthrough below uses a low-level primitive (abs) because it is the smallest template. For high-level or example plugins, follow the same steps and refer back to the table above for richer real-world samples.
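
To make the three-part contract concrete, here is a self-contained toy model in plain Python. Every name in it (REGISTRY, register, ToyAbsPlugin, check_structure) is hypothetical. The "graph" is just a list of op_type strings. The model only mirrors the shape of the contract and is not jax2onnx's plugin_system.

```python
from typing import Callable, Dict, List

# Hypothetical registry keyed by primitive name (a stand-in, not jax2onnx's real one).
REGISTRY: Dict[str, type] = {}


def register(primitive: str, testcases: List[dict]) -> Callable[[type], type]:
    """Step 1: attach metadata to a plugin class and record it in the registry."""
    def decorator(cls: type) -> type:
        cls.primitive = primitive
        cls.testcases = testcases
        REGISTRY[primitive] = cls
        return cls
    return decorator


@register(
    primitive="abs",
    testcases=[{"testcase": "abs", "callable": abs, "expected_ops": ["Abs"]}],
)
class ToyAbsPlugin:
    def lower(self, graph: List[str]) -> None:
        """Step 2: emit the ONNX op (in this toy, just its op_type string)."""
        graph.append("Abs")


def check_structure(primitive: str) -> bool:
    """Step 3: lower once per testcase and compare against the expected op sequence."""
    plugin_cls = REGISTRY[primitive]
    for case in plugin_cls.testcases:
        graph: List[str] = []
        plugin_cls().lower(graph)
        if graph != case["expected_ops"]:
            return False
    return True


print(check_structure("abs"))  # True
```

The decorator form keeps metadata next to the lowering code. A test generator can then discover everything it needs from the registry alone.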


1. Pick a Template

Start from a plugin that matches the flavour you need; the canonical examples in the table above are good starting points. Create a new file under jax2onnx/plugins/<namespace>/..., copy the template into it, and rename the class and metadata appropriately.

# jax2onnx/plugins/jax/lax/my_primitive.py
import jax

from jax2onnx.plugins._post_check_onnx_graph import expect_graph as EG
from jax2onnx.plugins.plugin_system import PrimitiveLeafPlugin, register_primitive

Tip: keep the context namespace (primitives.lax, primitives.nn, etc.) consistent across related plugins. It drives the autogenerated docs and test layout.
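
To see why consistency matters, the sketch below derives a test-module path from a context string. The mapping rule (primitives.lax to tests/primitives/test_lax.py) is an assumption inferred from the test command in step 5. It is not the generator's documented behaviour, and module_path_for is a hypothetical helper.

```python
from pathlib import PurePosixPath


def module_path_for(context: str) -> str:
    """Hypothetical mapping: 'primitives.lax' -> 'tests/primitives/test_lax.py'."""
    *parents, leaf = context.split(".")
    if not parents:
        raise ValueError(f"context {context!r} needs at least two dot-separated parts")
    return str(PurePosixPath("tests", *parents, f"test_{leaf}.py"))


print(module_path_for("primitives.lax"))  # tests/primitives/test_lax.py
print(module_path_for("primitives.nn"))   # tests/primitives/test_nn.py
```

Under this assumed rule, a typo such as primitive.lax silently produces tests/primitive/test_lax.py. That is the kind of drift the tip warns against.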


2. Register the Primitive

Fill in the @register_primitive decorator:

@register_primitive(
    jaxpr_primitive=jax.lax.abs_p.name,
    context="primitives.lax",
    component="abs",
    onnx=[{"component": "Abs", "doc": "https://onnx.ai/..."}],
    testcases=[
        {
            "testcase": "abs",
            "callable": lambda x: jax.lax.abs(x),
            "input_shapes": [(3,)],
            "post_check_onnx_graph": EG(["Abs:3"], no_unused_inputs=True),
        },
    ],
)
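
The test generator rebuilds each callable from these dicts, so a malformed entry tends to surface only when the generated tests run. A quick local sanity check could look like the following. The required-key set is an assumption taken from the keys used in the abs testcase above, not an exhaustive schema. validate_testcases is a hypothetical helper, and the builtin abs stands in for jax.lax.abs so the sketch runs without JAX.

```python
from typing import List

# Assumed minimal schema, taken from the keys in the "abs" testcase above.
REQUIRED_KEYS = {"testcase", "callable", "input_shapes"}


def validate_testcases(testcases: List[dict]) -> List[str]:
    """Return human-readable problems; an empty list means every entry looks sane."""
    problems: List[str] = []
    seen_names = set()
    for i, case in enumerate(testcases):
        missing = REQUIRED_KEYS - set(case)
        if missing:
            problems.append(f"entry {i}: missing {sorted(missing)}")
        name = case.get("testcase")
        if name in seen_names:
            problems.append(f"entry {i}: duplicate testcase name {name!r}")
        seen_names.add(name)
        if "callable" in case and not callable(case["callable"]):
            problems.append(f"entry {i}: 'callable' is not callable")
    return problems


cases = [
    {"testcase": "abs", "callable": lambda x: abs(x), "input_shapes": [(3,)]},
    {"testcase": "abs", "callable": None},  # deliberately broken entry
]
for problem in validate_testcases(cases):
    print(problem)
```

For the broken second entry, this reports the missing input_shapes, the duplicate name and the non-callable value.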

Use construct_and_call(...).with_requested_dtype(...).with_rng_seed(...) when the primitive needs deterministic module construction or RNG split helpers. See jax2onnx/plugins/jax/nn/dot_product_attention.py for a larger example.
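
Deterministic construction means that rebuilding the callable yields identical parameters every time, which the test generator relies on when it rebuilds callables. The plain-Python sketch below shows that property. It is not the construct_and_call implementation: build_toy_module is hypothetical, and random.Random stands in for JAX PRNG keys.

```python
import random
from typing import Callable, List


def build_toy_module(seed: int, size: int = 3) -> Callable[[List[float]], List[float]]:
    """Hypothetical module whose weights are drawn once, at construction time."""
    rng = random.Random(seed)  # dedicated seeded generator; no global RNG state involved
    weights = [rng.uniform(-1.0, 1.0) for _ in range(size)]

    def apply(x: List[float]) -> List[float]:
        return [w * v for w, v in zip(weights, x)]

    return apply


x = [1.0, 2.0, 3.0]
same_a = build_toy_module(seed=0)(x)
same_b = build_toy_module(seed=0)(x)
different = build_toy_module(seed=1)(x)
print(same_a == same_b)     # True: same seed, same weights
print(same_a == different)  # False: different seed, different weights
```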

Function plugin naming invariants

ONNX function plugins now keep the original callable/class name as the node op_type; uniqueness lives in per-call-site metadata such as the node name instead.

Update structural expectations (expect_graph, ORT checks, etc.) to key off the op_type when you want to match all instances, and fall back to the full node.name only when a specific call-site matters. Older expectations that referenced Callable_1 continue to pass because the checker strips numeric suffixes when comparing op_type.
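
The suffix-stripping rule can be pictured with a small helper. The regular expression below is an assumption about what "numeric suffix" means: an underscore followed by digits at the end of the name. The real checker may apply a different rule. normalize_op_type and op_types_match are hypothetical names.

```python
import re

# Assumed rule: a trailing "_<digits>" is a call-site counter, not part of the op_type.
_NUMERIC_SUFFIX = re.compile(r"_\d+$")


def normalize_op_type(op_type: str) -> str:
    """Strip one trailing numeric suffix, e.g. 'Callable_1' -> 'Callable'."""
    return _NUMERIC_SUFFIX.sub("", op_type)


def op_types_match(expected: str, actual: str) -> bool:
    """Compare op_types after normalising both sides, as a legacy-tolerant checker might."""
    return normalize_op_type(expected) == normalize_op_type(actual)


print(op_types_match("Callable_1", "Callable"))  # True: legacy expectation still passes
print(op_types_match("Callable", "Callable"))    # True
print(op_types_match("Callable_1", "Other"))     # False
```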


3. Implement lower

Fetch inputs and pre-allocated outputs via the lowering context, then emit the ONNX op through ctx.builder:

class AbsPlugin(PrimitiveLeafPlugin):
    def lower(self, ctx, eqn):
        x_var = eqn.invars[0]
        out_var = eqn.outvars[0]

        x_val = ctx.get_value_for_var(x_var, name_hint=ctx.fresh_name("abs_in"))
        out_val = ctx.get_value_for_var(out_var, name_hint=ctx.fresh_name("abs_out"))

        result = ctx.builder.Abs(x_val, _outputs=[out_val.name or ctx.fresh_name("abs_out")])

        # Stamp metadata if the pre-allocated value already knows its type/shape
        if getattr(out_val, "type", None) is not None:
            result.type = out_val.type
        if getattr(out_val, "shape", None) is not None:
            result.shape = out_val.shape

        ctx.bind_value_for_var(out_var, result)
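
To follow the data flow of lower without installing jax2onnx, here is a toy re-enactment using hypothetical stand-ins (ToyValue, ToyBuilder, ToyContext, lower_abs). They mimic only the calls used above: get_value_for_var, fresh_name, builder.Abs(..., _outputs=[...]) and bind_value_for_var. They make no claim about how the real lowering context is implemented. Variables are plain strings here rather than jaxpr vars.

```python
import itertools
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Dict, Iterator, List, Optional, Tuple


@dataclass
class ToyValue:
    """Stand-in for an IR value: a name plus optional dtype/shape metadata."""
    name: str
    type: Optional[str] = None
    shape: Optional[Tuple[int, ...]] = None


@dataclass
class ToyBuilder:
    """Records emitted nodes as (op_type, input names, output names)."""
    nodes: List[Tuple[str, List[str], List[str]]] = field(default_factory=list)

    def Abs(self, x: ToyValue, _outputs: List[str]) -> ToyValue:
        self.nodes.append(("Abs", [x.name], list(_outputs)))
        return ToyValue(name=_outputs[0])  # fresh value without type/shape metadata


@dataclass
class ToyContext:
    builder: ToyBuilder = field(default_factory=ToyBuilder)
    values: Dict[str, ToyValue] = field(default_factory=dict)
    counter: Iterator[int] = field(default_factory=itertools.count)

    def fresh_name(self, base: str) -> str:
        return f"{base}_{next(self.counter)}"

    def get_value_for_var(self, var: str, name_hint: str) -> ToyValue:
        # Reuse a pre-allocated value if one exists, otherwise create one from the hint.
        return self.values.setdefault(var, ToyValue(name=name_hint))

    def bind_value_for_var(self, var: str, value: ToyValue) -> None:
        self.values[var] = value


def lower_abs(ctx: ToyContext, eqn) -> None:
    x_val = ctx.get_value_for_var(eqn.invars[0], name_hint=ctx.fresh_name("abs_in"))
    out_val = ctx.get_value_for_var(eqn.outvars[0], name_hint=ctx.fresh_name("abs_out"))
    result = ctx.builder.Abs(x_val, _outputs=[out_val.name or ctx.fresh_name("abs_out")])
    # Carry over the metadata the pre-allocated output already knew.
    if out_val.type is not None:
        result.type = out_val.type
    if out_val.shape is not None:
        result.shape = out_val.shape
    ctx.bind_value_for_var(eqn.outvars[0], result)


ctx = ToyContext()
# Pretend the converter pre-allocated the output "y" with known dtype and shape.
ctx.values["y"] = ToyValue(name="y", type="FLOAT", shape=(3,))
lower_abs(ctx, SimpleNamespace(invars=["x"], outvars=["y"]))
print(ctx.builder.nodes)  # [('Abs', ['abs_in_0'], ['y'])]
print(ctx.values["y"])    # ToyValue(name='y', type='FLOAT', shape=(3,))
```

The node writes into the pre-allocated output name, and the bound result keeps the known dtype and shape. Without that copy step, the result returned by the builder would carry no metadata.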

For the key guardrails, see the ONNX IR Builder Guide, which lists every policy enforced by the automated checks.


4. Add a Structural Assertion

The post_check_onnx_graph entry in the testcase calls the structural checker (expect_graph). Use scripts/emit_expect_graph.py to capture the snippet:

poetry run python scripts/emit_expect_graph.py abs

Paste the output into the testcase metadata, and rerun the command whenever a change to the lowering alters the emitted graph structure.

For more involved graphs, consult docs/dev_guides/expect_graph_reference.md for matching tips.
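
As a mental model for what a structural assertion checks, the toy below inspects a graph given as plain tuples. toy_expect_graph is hypothetical. It checks that each pattern's op_type occurs in the graph. Optionally, it also checks that every graph input feeds some node, which is one plausible reading of no_unused_inputs=True. It ignores whatever follows the colon in a pattern such as Abs:3; the expect_graph reference documents the real pattern syntax.

```python
from typing import List, Sequence, Tuple

# Toy graph: each node is (op_type, input names, output names).
Node = Tuple[str, List[str], List[str]]


def toy_expect_graph(
    patterns: Sequence[str],
    nodes: Sequence[Node],
    graph_inputs: Sequence[str],
    no_unused_inputs: bool = False,
) -> bool:
    """Hypothetical check: each pattern's op_type is present; optionally no input is unused."""
    present = {op for op, _, _ in nodes}
    for pattern in patterns:
        op = pattern.split(":", 1)[0]  # this toy ignores everything after ':'
        if op not in present:
            return False
    if no_unused_inputs:
        consumed = {name for _, inputs, _ in nodes for name in inputs}
        if any(name not in consumed for name in graph_inputs):
            return False
    return True


graph = [("Abs", ["x"], ["y"])]
print(toy_expect_graph(["Abs:3"], graph, ["x"], no_unused_inputs=True))            # True
print(toy_expect_graph(["Abs:3"], graph, ["x", "unused"], no_unused_inputs=True))  # False
print(toy_expect_graph(["Neg:3"], graph, ["x"]))                                   # False
```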


5. Run the Tests

At minimum:

poetry run python scripts/check_ir_builder_usage.py --diff
poetry run pytest -q tests/primitives/test_lax.py -k abs
poetry run pytest -q
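
If you prefer a single entry point that stops at the first failure, a small wrapper such as the hypothetical run_in_order below works. It is a local convenience, not a script shipped with jax2onnx. Commands are passed as argument lists, so no shell is involved. A missing executable counts as a failure rather than crashing the wrapper.

```python
import subprocess
import sys
from typing import List, Optional, Sequence

# The three checks from above, as argument lists.
CHECKS: List[List[str]] = [
    ["poetry", "run", "python", "scripts/check_ir_builder_usage.py", "--diff"],
    ["poetry", "run", "pytest", "-q", "tests/primitives/test_lax.py", "-k", "abs"],
    ["poetry", "run", "pytest", "-q"],
]


def run_in_order(commands: Sequence[Sequence[str]]) -> Optional[int]:
    """Run commands sequentially; return the index of the first failure, or None if all pass."""
    for index, command in enumerate(commands):
        print("$", " ".join(command), flush=True)
        try:
            completed = subprocess.run(list(command))
        except FileNotFoundError:
            print(f"command not found: {command[0]}", file=sys.stderr)
            return index
        if completed.returncode != 0:
            return index
    return None
```

From the repository root, call run_in_order(CHECKS). A return value of None means every check passed.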

The pre-commit hooks execute the same checks, but running them by hand while you iterate keeps the feedback loop tight.


6. Submit the PR

Include the plugin file and the updated tests. Reference the example you copied from in your PR description so reviewers know the baseline.

For deeper dives, see the ONNX IR Builder Guide and docs/dev_guides/expect_graph_reference.md.

Happy lowering!