
Plugin System

jax2onnx uses a plugin-based architecture to support JAX primitives: each primitive is converted by its own plugin. This keeps the converter modular, so support can be added as JAX evolves or when you need to handle custom primitives.

Core Concepts

The system is built around the PrimitiveLeafPlugin class and the @register_primitive decorator found in jax2onnx/plugins/plugin_system.py.

PrimitiveLeafPlugin

To support a JAX primitive (e.g., jax.lax.abs_p), you create a class that inherits from PrimitiveLeafPlugin and implements a lower method. This method is responsible for "lowering" the JAX primitive into a sequence of one or more ONNX operations.

@register_primitive

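The @register_primitive decorator registers your plugin with the system and attaches metadata such as the JAX primitive name (jaxpr_primitive), documentation links for the JAX function and the ONNX operators used (jax_doc, onnx), and inline test cases (testcases).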
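To make this division of labor concrete, here is a simplified, stdlib-only sketch of how a decorator-based registry of this kind can work: the decorator records metadata on the class and stores one instance per primitive name, and a converter later looks plugins up by that name. Everything in it (PLUGIN_REGISTRY, LeafPlugin, register, lookup, ToyAbsPlugin) is hypothetical and is not jax2onnx's actual implementation.

```python
# Hypothetical, simplified model of a decorator-based plugin registry.
# Not jax2onnx's implementation: every name below is illustrative.
from typing import Any, Callable, Dict, Type


class LeafPlugin:
    """Stand-in for a base class like PrimitiveLeafPlugin."""

    metadata: Dict[str, Any] = {}

    def lower(self, ctx: Any, eqn: Any) -> None:
        raise NotImplementedError


PLUGIN_REGISTRY: Dict[str, LeafPlugin] = {}


def register(
    *, jaxpr_primitive: str, **metadata: Any
) -> Callable[[Type[LeafPlugin]], Type[LeafPlugin]]:
    """Attach metadata to the class and store one instance per primitive name."""

    def wrap(cls: Type[LeafPlugin]) -> Type[LeafPlugin]:
        if jaxpr_primitive in PLUGIN_REGISTRY:
            raise ValueError(f"duplicate plugin for {jaxpr_primitive!r}")
        cls.metadata = {"jaxpr_primitive": jaxpr_primitive, **metadata}
        PLUGIN_REGISTRY[jaxpr_primitive] = cls()
        return cls  # the decorated class stays usable as an ordinary class

    return wrap


def lookup(primitive_name: str) -> LeafPlugin:
    """Find the plugin responsible for a primitive, as a converter would."""
    if primitive_name not in PLUGIN_REGISTRY:
        raise NotImplementedError(f"no plugin for primitive {primitive_name!r}")
    return PLUGIN_REGISTRY[primitive_name]


@register(jaxpr_primitive="abs", component="abs")
class ToyAbsPlugin(LeafPlugin):
    def lower(self, ctx: Any, eqn: Any) -> None:
        # Here ctx is just a list that collects (op_type, eqn) pairs.
        ctx.append(("Abs", eqn))


nodes: list = []
lookup("abs").lower(nodes, "eqn0")
print(nodes)
```

Keying the registry by primitive name means dispatch is a single dictionary lookup, and a missing plugin surfaces as a clear error rather than a silent fallback.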

Creating a Plugin

Here is an example of how to implement a plugin for jax.lax.abs.

1. Import Dependencies

from typing import Any
from jax import core
import jax
from jax2onnx.plugins.plugin_system import PrimitiveLeafPlugin, register_primitive
from jax2onnx.converter.typing_support import LoweringContextProtocol
from jax2onnx.plugins._post_check_onnx_graph import expect_graph as EG

# Fall back to Any for type hints if this JAX version does not expose core.JaxprEqn
JaxprEqn = getattr(core, "JaxprEqn", Any)

2. Implement the Plugin

Define a class that inherits from PrimitiveLeafPlugin and decorate it.

@register_primitive(
    jaxpr_primitive=jax.lax.abs_p.name,
    jax_doc="https://docs.jax.dev/en/latest/_autosummary/jax.lax.abs.html",
    onnx=[
        {
            "component": "Abs",
            "doc": "https://onnx.ai/onnx/operators/onnx__Abs.html",
        }
    ],
    since="0.5.0",
    context="primitives.lax",
    component="abs",
    testcases=[
        {
            "testcase": "abs",
            "callable": lambda x: jax.lax.abs(x),
            "input_shapes": [(3,)],
            "post_check_onnx_graph": EG(
                ["Abs:3"],
                no_unused_inputs=True,
            ),
        }
    ],
)
class AbsPlugin(PrimitiveLeafPlugin):
    def lower(self, ctx: LoweringContextProtocol, eqn: JaxprEqn) -> None:
        # 1. Access input and output variables from the equation
        x_var = eqn.invars[0]
        out_var = eqn.outvars[0]

        # 2. Get the ONNX value for the input
        x_val = ctx.get_value_for_var(x_var, name_hint=ctx.fresh_name("abs_in"))

        # 3. Prepare the output specification (name, type, shape)
        out_spec = ctx.get_value_for_var(out_var, name_hint=ctx.fresh_name("abs_out"))
        desired_name = getattr(out_spec, "name", None) or ctx.fresh_name("abs_out")

        # 4. Create the ONNX node via the builder
        #    ctx.builder.<OpName>(inputs..., _outputs=[names...])
        result = ctx.builder.Abs(x_val, _outputs=[desired_name])

        # 5. Set type/shape info if available
        if getattr(out_spec, "type", None) is not None:
            result.type = out_spec.type
        if getattr(out_spec, "shape", None) is not None:
            result.shape = out_spec.shape

        # 6. Bind the ONNX result back to the JAX output variable
        ctx.bind_value_for_var(out_var, result)
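The eqn argument that lower receives is typed as a jaxpr equation. One way to see what such an equation contains is to trace the same function with jax.make_jaxpr and inspect the single equation it produces. This sketch uses only public JAX APIs and assumes nothing about jax2onnx internals.

```python
# Inspect the jaxpr equation for jax.lax.abs using public JAX APIs.
import jax
import jax.numpy as jnp

closed = jax.make_jaxpr(lambda x: jax.lax.abs(x))(jnp.ones((3,), dtype=jnp.float32))
eqn = closed.jaxpr.eqns[0]  # this jaxpr contains exactly one equation

print(closed)               # the whole traced program
print(eqn.primitive.name)   # the name a plugin registers under via jaxpr_primitive
print(eqn.invars[0].aval)   # input abstract value (dtype and shape)
print(eqn.outvars[0].aval)  # output abstract value
```

In AbsPlugin, eqn.invars[0] and eqn.outvars[0] correspond to the input and output variables printed here.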

Key Components

LoweringContextProtocol (ctx)

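  • ctx.get_value_for_var(var, name_hint=...): Retrieves the ONNX value corresponding to a JAX variable. The example above also calls it on the output variable to obtain the expected name, type, and shape.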
  • ctx.bind_value_for_var(var, value): Associates an ONNX value with a JAX output variable.
  • ctx.builder: An interface to create ONNX nodes (e.g., ctx.builder.Abs, ctx.builder.Add).
  • ctx.fresh_name(hint): Generates a unique name for ONNX tensors.
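The following stdlib-only toy models these four operations so their interplay is visible, including how one primitive can lower to a sequence of ONNX nodes (here a hypothetical y = -|x| lowered to Abs followed by Neg). ToyContext, ToyBuilder, and Value are illustrative stand-ins, not jax2onnx classes; the real context also tracks types and shapes and builds a real ONNX graph.

```python
# Hypothetical toy model of the lowering context; not the jax2onnx API.
import itertools
from typing import Any, Callable, Dict, List, Tuple


class Value:
    """A named tensor value in the toy graph."""

    def __init__(self, name: str) -> None:
        self.name = name


class ToyBuilder:
    def __init__(self) -> None:
        self.nodes: List[Tuple[str, List[str], List[str]]] = []

    def __getattr__(self, op_type: str) -> Callable[..., Value]:
        # builder.Abs, builder.Neg, ... each return a node factory, mimicking
        # the ctx.builder.<OpName>(inputs..., _outputs=[names...]) call shape.
        if op_type.startswith("_"):
            raise AttributeError(op_type)

        def make_node(*inputs: Value, _outputs: List[str]) -> Value:
            self.nodes.append((op_type, [v.name for v in inputs], list(_outputs)))
            return Value(_outputs[0])

        return make_node


class ToyContext:
    def __init__(self) -> None:
        self.builder = ToyBuilder()
        self._env: Dict[Any, Value] = {}  # JAX variable -> graph value
        self._counter = itertools.count()

    def fresh_name(self, hint: str) -> str:
        return f"{hint}_{next(self._counter)}"

    def get_value_for_var(self, var: Any, name_hint: str = "v") -> Value:
        # Return the bound value, creating a fresh named one on first use.
        if var not in self._env:
            self._env[var] = Value(self.fresh_name(name_hint))
        return self._env[var]

    def bind_value_for_var(self, var: Any, value: Value) -> None:
        self._env[var] = value


# Lower a hypothetical primitive y = -|x| into two nodes: Abs, then Neg.
ctx = ToyContext()
x = ctx.get_value_for_var("x", name_hint="x")
tmp = ctx.builder.Abs(x, _outputs=[ctx.fresh_name("abs_out")])
out = ctx.builder.Neg(tmp, _outputs=[ctx.fresh_name("neg_out")])
ctx.bind_value_for_var("y", out)
print(ctx.builder.nodes)
```

Binding the final node's result to the output variable is what lets later equations that consume y find the right graph value.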

Testing

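The testcases list in the decorator defines inline tests. Each entry gives a name (testcase), the function to convert (callable), and its input shapes (input_shapes); post_check_onnx_graph can additionally assert on the structure of the exported graph. The test runner picks these entries up automatically, verifies the plugin's output against the actual JAX behavior, and runs the graph check.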
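A data-driven loop over such entries might look like the sketch below. It is hypothetical and not jax2onnx's runner: run_converted stands in for exporting the callable to ONNX and executing the result, and inputs are flat lists of floats rather than arrays, so only the shape-to-input and compare-with-tolerance steps are illustrated.

```python
# Hypothetical sketch of a metadata-driven test loop (not jax2onnx's runner).
import math
import random
from typing import Any, Callable, Dict, List, Tuple

Shape = Tuple[int, ...]
TestCase = Dict[str, Any]


def make_input(shape: Shape, rng: random.Random) -> List[float]:
    # Flattened input: one random float per element of the requested shape.
    return [rng.uniform(-1.0, 1.0) for _ in range(math.prod(shape))]


def check_testcases(
    testcases: List[TestCase],
    run_converted: Callable[..., List[float]],
    seed: int = 0,
    atol: float = 1e-6,
) -> List[str]:
    """Run each case's callable and the 'converted' version; compare outputs."""
    rng = random.Random(seed)  # fixed seed so failures are reproducible
    passed: List[str] = []
    for case in testcases:
        args = [make_input(shape, rng) for shape in case["input_shapes"]]
        expected = case["callable"](*args)  # reference behavior
        actual = run_converted(case, *args)  # stand-in for running the ONNX model
        if len(expected) != len(actual) or not all(
            math.isclose(e, a, abs_tol=atol) for e, a in zip(expected, actual)
        ):
            raise AssertionError(f"mismatch in testcase {case['testcase']!r}")
        passed.append(case["testcase"])
    return passed


cases = [
    {
        "testcase": "abs",
        "callable": lambda x: [abs(v) for v in x],
        "input_shapes": [(3,)],
    }
]
# Pretend conversion: an independent implementation of the same operation.
print(check_testcases(cases, lambda case, x: [math.fabs(v) for v in x]))
```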

Higher-Level Functions

For higher-level functions (like jax.nn.softmax), jax2onnx supports function plugins via @onnx_function or FunctionPlugin. These allow you to map a Python function directly to an ONNX FunctionProto, encapsulating complex logic as a reusable component.

See jax2onnx/plugins/plugin_system.py for more details on function plugins.
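The core idea behind encapsulating a higher-level function can be sketched without ONNX itself: the function body is defined once, and each call site becomes a single node that refers to it by name. The sketch below is a hypothetical stdlib model (ToyModel, Softmax_fn, and SOFTMAX_BODY are illustrative names, not jax2onnx APIs); it writes softmax in the numerically stable max-subtraction form and omits ONNX attributes such as axes and keepdims.

```python
# Hypothetical model of "define once, call as one node" for a function body.
from typing import Dict, List, Tuple

Node = Tuple[str, List[str], List[str]]  # (op_type, inputs, outputs)


class ToyModel:
    def __init__(self) -> None:
        self.functions: Dict[str, List[Node]] = {}  # function name -> body
        self.graph: List[Node] = []  # top-level nodes
        self._n = 0

    def define_function(self, name: str, body: List[Node]) -> None:
        # Store the body only once, however many call sites use it.
        self.functions.setdefault(name, body)

    def call(self, name: str, x: str) -> str:
        # Each call site adds a single node whose op_type is the function name.
        self._n += 1
        out = f"{name}_out_{self._n}"
        self.graph.append((name, [x], [out]))
        return out


# softmax(X) = exp(X - max(X)) / sum(exp(X - max(X)))
SOFTMAX_BODY: List[Node] = [
    ("ReduceMax", ["X"], ["m"]),
    ("Sub", ["X", "m"], ["shifted"]),
    ("Exp", ["shifted"], ["e"]),
    ("ReduceSum", ["e"], ["s"]),
    ("Div", ["e", "s"], ["Y"]),
]

model = ToyModel()
h = "x"
for _ in range(2):  # two softmax call sites
    model.define_function("Softmax_fn", SOFTMAX_BODY)
    h = model.call("Softmax_fn", h)
print(model.graph)
```

The top-level graph stays short (one node per call), while the five-op body is stored once and shared.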