Plugin System
jax2onnx uses a plugin-based architecture to support JAX primitives: each primitive is handled by its own plugin. This keeps the converter modular, so support can be extended as JAX evolves or when you need to handle custom primitives.
Core Concepts
The system is built around the PrimitiveLeafPlugin class and the @register_primitive decorator found in jax2onnx/plugins/plugin_system.py.
PrimitiveLeafPlugin
To support a JAX primitive (e.g., jax.lax.abs_p), you create a class that inherits from PrimitiveLeafPlugin. This class is responsible for "lowering" the JAX primitive into a sequence of ONNX operations.
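To make "lowering" concrete, here is a self-contained sketch in plain Python. It does not use jax2onnx: ToyGraph, ToyLeafPlugin and RsqrtPlugin are hypothetical, and the real lower method takes a lowering context and a JaxprEqn rather than tensor names. The point it illustrates is that one JAX primitive can expand into several ONNX nodes; lowering rsqrt as Sqrt followed by Reciprocal is one possible choice, not necessarily what jax2onnx emits.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class ToyGraph:
    """Hypothetical stand-in for an ONNX graph: an ordered list of nodes."""

    nodes: List[Tuple[str, List[str], List[str]]] = field(default_factory=list)

    def add_node(self, op: str, inputs: List[str], outputs: List[str]) -> str:
        self.nodes.append((op, list(inputs), list(outputs)))
        return outputs[0]  # name of the first output, so nodes can be chained


class ToyLeafPlugin(ABC):
    """Hypothetical base class: one subclass per JAX primitive."""

    @abstractmethod
    def lower(self, graph: ToyGraph, x: str, out: str) -> None:
        """Emit the ONNX node(s) that compute `out` from `x`."""


class RsqrtPlugin(ToyLeafPlugin):
    """Lowers rsqrt(x) = 1 / sqrt(x) as a Sqrt node followed by a Reciprocal node."""

    def lower(self, graph: ToyGraph, x: str, out: str) -> None:
        root = graph.add_node("Sqrt", [x], [out + "_sqrt"])
        graph.add_node("Reciprocal", [root], [out])


g = ToyGraph()
RsqrtPlugin().lower(g, "x", "y")
print(g.nodes)
# [('Sqrt', ['x'], ['y_sqrt']), ('Reciprocal', ['y_sqrt'], ['y'])]
```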
@register_primitive
The @register_primitive decorator registers your plugin with the system and provides metadata like the JAX primitive name, documentation links, and test cases.
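A decorator-based registry of this kind can be sketched as follows. This is a hypothetical, simplified model, not the code in plugin_system.py: toy_register_primitive, PLUGIN_REGISTRY and lookup are illustrative names. It stores one plugin instance per primitive name together with its metadata, rejects duplicate registrations, and lets a converter dispatch on the primitive name of each equation. The key "abs" matches jax.lax.abs_p.name.

```python
from typing import Any, Callable, Dict

# Hypothetical registry: primitive name -> {"plugin": instance, "metadata": dict}
PLUGIN_REGISTRY: Dict[str, Dict[str, Any]] = {}


def toy_register_primitive(*, jaxpr_primitive: str, **metadata: Any) -> Callable[[type], type]:
    """Hypothetical decorator factory: records the decorated class under its primitive name."""

    def decorator(cls: type) -> type:
        if jaxpr_primitive in PLUGIN_REGISTRY:
            raise ValueError(f"duplicate plugin for primitive {jaxpr_primitive!r}")
        PLUGIN_REGISTRY[jaxpr_primitive] = {"plugin": cls(), "metadata": metadata}
        return cls  # the class itself is returned unchanged

    return decorator


def lookup(primitive_name: str) -> Any:
    """Dispatch step a converter might run for each equation in a jaxpr."""
    try:
        return PLUGIN_REGISTRY[primitive_name]["plugin"]
    except KeyError:
        raise NotImplementedError(f"no plugin registered for {primitive_name!r}") from None


@toy_register_primitive(jaxpr_primitive="abs", onnx=[{"component": "Abs"}], since="0.5.0")
class ToyAbsPlugin:
    def lower(self, *args: Any) -> None:
        pass  # lowering logic omitted in this sketch


print(sorted(PLUGIN_REGISTRY))  # ['abs']
print(type(lookup("abs")).__name__)  # ToyAbsPlugin
```

Keying the registry by primitive name keeps the converter core independent of individual plugins: adding support for a primitive means adding one decorated class.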
Creating a Plugin
Here is an example of how to implement a plugin for jax.lax.abs.
1. Import Dependencies
```python
from typing import Any

import jax
from jax import core

from jax2onnx.plugins.plugin_system import PrimitiveLeafPlugin, register_primitive
from jax2onnx.converter.typing_support import LoweringContextProtocol
from jax2onnx.plugins._post_check_onnx_graph import expect_graph as EG

# Fall back to Any if this JAX version does not expose core.JaxprEqn
JaxprEqn = getattr(core, "JaxprEqn", Any)
```
2. Implement the Plugin
Define a class that inherits from PrimitiveLeafPlugin and decorate it with @register_primitive.
```python
@register_primitive(
    jaxpr_primitive=jax.lax.abs_p.name,
    jax_doc="https://docs.jax.dev/en/latest/_autosummary/jax.lax.abs.html",
    onnx=[
        {
            "component": "Abs",
            "doc": "https://onnx.ai/onnx/operators/onnx__Abs.html",
        }
    ],
    since="0.5.0",
    context="primitives.lax",
    component="abs",
    testcases=[
        {
            "testcase": "abs",
            "callable": lambda x: jax.lax.abs(x),
            "input_shapes": [(3,)],
            "post_check_onnx_graph": EG(
                ["Abs:3"],
                no_unused_inputs=True,
            ),
        }
    ],
)
class AbsPlugin(PrimitiveLeafPlugin):
    def lower(self, ctx: LoweringContextProtocol, eqn: JaxprEqn) -> None:
        # 1. Access input and output variables from the equation
        x_var = eqn.invars[0]
        out_var = eqn.outvars[0]

        # 2. Get the ONNX value for the input
        x_val = ctx.get_value_for_var(x_var, name_hint=ctx.fresh_name("abs_in"))

        # 3. Prepare the output specification (name, type, shape)
        out_spec = ctx.get_value_for_var(out_var, name_hint=ctx.fresh_name("abs_out"))
        desired_name = getattr(out_spec, "name", None) or ctx.fresh_name("abs_out")

        # 4. Create the ONNX node via the builder
        #    ctx.builder.<OpName>(inputs..., _outputs=[names...])
        result = ctx.builder.Abs(x_val, _outputs=[desired_name])

        # 5. Set type/shape info if available
        if getattr(out_spec, "type", None) is not None:
            result.type = out_spec.type
        if getattr(out_spec, "shape", None) is not None:
            result.shape = out_spec.shape

        # 6. Bind the ONNX result back to the JAX output variable
        ctx.bind_value_for_var(out_var, result)
```
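To see how the six steps fit together without installing JAX or ONNX, the sketch below replays the same logic against a toy context. ToyVar, ToyValue, ToyBuilder, ToyContext and lower_abs are hypothetical stand-ins; only the call pattern (get_value_for_var with name_hint, builder.Abs with _outputs, bind_value_for_var) is taken from the plugin above. Returning a placeholder spec for an output variable that has not been bound yet is an assumption about how out_spec behaves.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple


@dataclass(frozen=True)
class ToyVar:
    """Hypothetical JAX variable: hashable, carries dtype/shape like an abstract value."""
    name: str
    dtype: str
    shape: Tuple[int, ...]


@dataclass
class ToyValue:
    """Hypothetical ONNX value: a tensor name plus optional type/shape info."""
    name: str
    type: Optional[str] = None
    shape: Optional[Tuple[int, ...]] = None


class ToyBuilder:
    """Records nodes; builder.<OpName>(...) returns a ToyValue for the first output."""

    def __init__(self) -> None:
        self.nodes: List[Tuple[str, List[str], List[str]]] = []

    def __getattr__(self, op_type: str):
        def make_node(*inputs: ToyValue, _outputs: List[str]) -> ToyValue:
            self.nodes.append((op_type, [v.name for v in inputs], list(_outputs)))
            return ToyValue(_outputs[0])
        return make_node


class ToyContext:
    """Hypothetical lowering context with the four members the plugin uses."""

    def __init__(self, known: Dict[ToyVar, ToyValue]) -> None:
        self.env: Dict[ToyVar, ToyValue] = dict(known)
        self.builder = ToyBuilder()
        self._counter = 0

    def fresh_name(self, hint: str) -> str:
        self._counter += 1
        return f"{hint}_{self._counter}"

    def get_value_for_var(self, var: ToyVar, name_hint: str) -> ToyValue:
        # Known vars return their bound value; unknown ones get a placeholder spec.
        if var not in self.env:
            self.env[var] = ToyValue(name_hint, var.dtype, var.shape)
        return self.env[var]

    def bind_value_for_var(self, var: ToyVar, value: ToyValue) -> None:
        self.env[var] = value


def lower_abs(ctx: ToyContext, eqn: Any) -> None:
    """Same six steps as AbsPlugin.lower, run against the toy context."""
    x_var, out_var = eqn.invars[0], eqn.outvars[0]
    x_val = ctx.get_value_for_var(x_var, name_hint=ctx.fresh_name("abs_in"))
    out_spec = ctx.get_value_for_var(out_var, name_hint=ctx.fresh_name("abs_out"))
    desired_name = out_spec.name or ctx.fresh_name("abs_out")
    result = ctx.builder.Abs(x_val, _outputs=[desired_name])
    if out_spec.type is not None:
        result.type = out_spec.type
    if out_spec.shape is not None:
        result.shape = out_spec.shape
    ctx.bind_value_for_var(out_var, result)


x = ToyVar("x", "float32", (3,))
y = ToyVar("y", "float32", (3,))
ctx = ToyContext({x: ToyValue("input_0", "float32", (3,))})
lower_abs(ctx, SimpleNamespace(invars=[x], outvars=[y]))
print(ctx.builder.nodes)  # [('Abs', ['input_0'], ['abs_out_2'])]
print(ctx.env[y])  # ToyValue(name='abs_out_2', type='float32', shape=(3,))
```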
Key Components
LoweringContextProtocol (ctx)
- ctx.get_value_for_var(var): Retrieves the ONNX value corresponding to a JAX variable. The example above also passes a name_hint keyword argument.
- ctx.bind_value_for_var(var, value): Associates an ONNX value with a JAX output variable.
- ctx.builder: An interface for creating ONNX nodes (e.g., ctx.builder.Abs, ctx.builder.Add).
- ctx.fresh_name(hint): Generates a unique name for an ONNX tensor.
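If LoweringContextProtocol is a typing.Protocol (an assumption based on its name), any object that provides these members satisfies it without inheriting from it. The sketch below models that with a hypothetical ToyLoweringContext protocol whose signatures are simplified (plain strings instead of JAX variables and ONNX values), and shows why binding matters: the second equation can only read the value that the first equation bound. MiniContext, MiniBuilder and MiniBuilder.unary are hypothetical stand-ins for the real context and for ctx.builder.<OpName>.

```python
import itertools
from typing import Any, Dict, List, Protocol, Tuple, runtime_checkable


@runtime_checkable
class ToyLoweringContext(Protocol):
    """Hypothetical structural interface; the real signatures may differ."""
    builder: Any

    def get_value_for_var(self, var: str) -> str: ...
    def bind_value_for_var(self, var: str, value: str) -> None: ...
    def fresh_name(self, hint: str) -> str: ...


class MiniBuilder:
    """Records unary nodes as (op, input, output) and names outputs via fresh_name."""

    def __init__(self, fresh_name) -> None:
        self.nodes: List[Tuple[str, str, str]] = []
        self._fresh_name = fresh_name

    def unary(self, op: str, inp: str) -> str:
        out = self._fresh_name(op.lower())
        self.nodes.append((op, inp, out))
        return out


class MiniContext:
    """Satisfies ToyLoweringContext without inheriting from it."""

    def __init__(self) -> None:
        self._ids = itertools.count()
        self.env: Dict[str, str] = {}  # JAX var name -> ONNX tensor name
        self.builder = MiniBuilder(self.fresh_name)

    def fresh_name(self, hint: str) -> str:
        return f"{hint}_{next(self._ids)}"

    def get_value_for_var(self, var: str) -> str:
        return self.env[var]  # KeyError means the var was never bound

    def bind_value_for_var(self, var: str, value: str) -> None:
        self.env[var] = value


ctx = MiniContext()
print(isinstance(ctx, ToyLoweringContext))  # True: only member presence is checked

# b = abs(a); c = neg(b): each equation reads its input, emits a node, binds its output
ctx.bind_value_for_var("a", "input_0")
for op, invar, outvar in [("Abs", "a", "b"), ("Neg", "b", "c")]:
    ctx.bind_value_for_var(outvar, ctx.builder.unary(op, ctx.get_value_for_var(invar)))
print(ctx.builder.nodes)
# [('Abs', 'input_0', 'abs_0'), ('Neg', 'abs_0', 'neg_1')]
```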
Testing
The testcases list in the decorator allows you to define inline tests. These are automatically picked up by the test runner to verify your plugin against the actual JAX behavior and check the generated ONNX graph structure.
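A numerical check of this kind can be sketched in plain Python. This is a hypothetical model, not the jax2onnx test runner: make_input, flatten and run_testcase are illustrative names, plain-Python lambdas stand in for the JAX callable and for the converted ONNX model, and only the "callable" and "input_shapes" keys come from the testcase dict shown earlier. Graph-structure checks such as post_check_onnx_graph are not modeled.

```python
import math
import random
from typing import Any, Callable, Dict, List, Sequence


def make_input(shape: Sequence[int], rng: random.Random) -> Any:
    """Nested lists of random floats with the requested shape."""
    if not shape:
        return rng.uniform(-5.0, 5.0)
    return [make_input(shape[1:], rng) for _ in range(shape[0])]


def flatten(x: Any) -> List[float]:
    if isinstance(x, list):
        return [v for item in x for v in flatten(item)]
    return [x]


def run_testcase(case: Dict[str, Any], converted: Callable[..., Any], seed: int = 0) -> bool:
    """Compare the reference callable with a converted model on the same random inputs."""
    rng = random.Random(seed)
    args = [make_input(shape, rng) for shape in case["input_shapes"]]
    expected = flatten(case["callable"](*args))
    actual = flatten(converted(*args))
    return len(expected) == len(actual) and all(
        math.isclose(e, a, rel_tol=1e-6, abs_tol=1e-9) for e, a in zip(expected, actual)
    )


# Mirrors the shape of the testcase dict above; plain Python stands in for jax.lax.abs.
case = {"testcase": "abs", "callable": lambda x: [abs(v) for v in x], "input_shapes": [(3,)]}
print(run_testcase(case, lambda x: [math.fabs(v) for v in x]))  # True
print(run_testcase(case, lambda x: list(x)))  # False (seed 0 draws a negative input)
```

Seeding the generator makes a failing comparison reproducible, which matters when the same testcase runs on every change.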
Higher-Level Functions
For higher-level functions (like jax.nn.softmax), jax2onnx supports function plugins via @onnx_function or FunctionPlugin. These allow you to map a Python function directly to an ONNX FunctionProto, encapsulating complex logic as a reusable component.
See jax2onnx/plugins/plugin_system.py for more details on function plugins.
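As an illustration of the kind of logic a function plugin can encapsulate, the sketch below writes softmax as one named body built from ReduceMax, Sub, Exp, ReduceSum and Div steps, using plain Python lists instead of ONNX tensors. It is not jax2onnx's decomposition and does not use @onnx_function or FunctionPlugin; the helper names are hypothetical. Packaging such a body behind a single name is what a FunctionProto provides in an ONNX graph.

```python
import math
from typing import List


# Hypothetical "primitive" ops on 1-D lists, named after the ONNX operators they mimic.
def reduce_max(x: List[float]) -> float:
    return max(x)


def sub(x: List[float], c: float) -> List[float]:
    return [v - c for v in x]


def exp(x: List[float]) -> List[float]:
    return [math.exp(v) for v in x]


def reduce_sum(x: List[float]) -> float:
    return sum(x)


def div(x: List[float], c: float) -> List[float]:
    return [v / c for v in x]


def softmax_function(x: List[float]) -> List[float]:
    """One reusable body: ReduceMax -> Sub -> Exp -> ReduceSum -> Div."""
    # softmax(x) == softmax(x - c) for any constant c; subtracting the max avoids overflow
    shifted = sub(x, reduce_max(x))
    e = exp(shifted)
    return div(e, reduce_sum(e))


print(softmax_function([0.0, 0.0]))  # [0.5, 0.5]
print(softmax_function([1000.0, 1000.0]))  # [0.5, 0.5]; math.exp(1000.0) alone would overflow
```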