jax2onnx

expect_graph checklist for plugins

expect_graph (from jax2onnx.plugins._post_check_onnx_graph) is the lightweight structural assertion helper used by plugin tests and examples. It lets a test express the operators, ordering, and shapes that should appear in a converted IR/ONNX graph without dumping the full model. This document captures the conventions we rely on when writing or reviewing post_check_onnx_graph expectations.

Test metadata reminder: when wiring new examples/tests, construct callables via construct_and_call(...).with_requested_dtype(...).with_rng_seed(...) so the harness can rebuild deterministic f32/f64 variants. See the builder guide for the full randomness and dtype rules.

Import

from jax2onnx.plugins._post_check_onnx_graph import expect_graph as EG

Alias it to EG inside tests to keep callsites short.

Builder reminder: structural tests assume plugins emitted nodes via ctx.builder. Review the ONNX IR Builder Guide if _outputs naming or initializer wiring looks suspicious; policy tests now enforce those contracts.

Basic usage

Pass a list of patterns to expect_graph. Each pattern is either a string or a (string, options) tuple. Nodes are written in evaluation order with -> separating them.

EG([
    "Transpose -> Conv -> Relu -> AveragePool",
])

The pattern above requires the graph to contain that operator chain, in that order. If the chain is not found, the check raises an assertion error with a summarized diff of the graph.
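To make the pattern syntax concrete, the sketch below shows one way a pattern string could be split into its node tokens. `split_pattern` is a hypothetical helper written for illustration only; it is not part of jax2onnx and does not claim to mirror how expect_graph parses patterns internally.

```python
from __future__ import annotations


def split_pattern(pattern: str) -> list[str]:
    """Split an 'A -> B -> C' pattern into its node tokens (illustrative only)."""
    # Tokens are separated by '->'; surrounding whitespace is not significant.
    return [token.strip() for token in pattern.split("->") if token.strip()]


tokens = split_pattern("Transpose -> Conv -> Relu -> AveragePool")
print(tokens)  # ['Transpose', 'Conv', 'Relu', 'AveragePool']
```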

Encoding shapes

Append :shape to a node name to assert the output shape of that node. Use x separators (e.g. Bx32x28x28). Leave dimensions symbolic by reusing the string symbol that the test harness passed as an input shape (for example "B").

EG([
    "Gemm:Bx256 -> Relu:Bx256 -> Gemm:Bx10",
])

Write concrete integers for known static sizes (3x1x28x28). Symbols and integers can be mixed within one shape (Bx256). Optional-dimension spellings such as B?x256 are not supported; prefer symbols={"B": None} if you need to unify a symbol across multiple pattern strings.
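As a rough illustration of the `Op:shape` notation, the following sketch separates a token into an operator name and a tuple of dimensions, keeping symbols as strings and static sizes as integers. `parse_node` is a hypothetical helper and an assumption about how such a token can be read; the real matcher may handle more cases. It also assumes symbol names do not themselves contain the letter x.

```python
from __future__ import annotations


def parse_node(token: str) -> tuple[str, tuple[int | str, ...] | None]:
    """Parse 'Gemm:Bx256' into ('Gemm', ('B', 256)); illustrative only."""
    op, sep, shape = token.strip().partition(":")
    if not sep:
        # No ':shape' suffix, so no shape is asserted for this node.
        return op, None
    # Digits become static sizes; anything else is kept as a symbolic name.
    dims = tuple(int(d) if d.isdigit() else d for d in shape.split("x"))
    return op, dims


print(parse_node("Gemm:Bx256"))        # ('Gemm', ('B', 256))
print(parse_node("Conv:3x32x28x28"))   # ('Conv', (3, 32, 28, 28))
print(parse_node("Relu"))              # ('Relu', None)
```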

Additional match options

Attach an options dictionary to require counts, forbid nodes, or tweak the search.

EG(
    [
        (
            "Transpose:3x1x28x28 -> Conv:3x32x28x28 -> Relu:3x32x28x28 -> Gemm:3x256",
            {
                "counts": {"Transpose": 1, "Conv": 1, "Relu": 1, "Gemm": 1},
            },
        ),
    ],
    no_unused_inputs=True,
    mode="all",
    must_absent=["Not"],
)
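The intent behind counts and must_absent in the call above can be pictured with a small stand-alone check over a flat list of op types. This is a sketch under stated assumptions: `check_counts` is a hypothetical function, and it assumes counts means an exact number of occurrences, which the helper itself may define differently.

```python
from collections import Counter


def check_counts(op_types, counts=None, must_absent=()):
    """Return a list of human-readable problems; empty means the check passed."""
    seen = Counter(op_types)
    problems = []
    # Assumption: each entry in `counts` demands an exact number of occurrences.
    for op, expected in (counts or {}).items():
        if seen[op] != expected:
            problems.append(f"{op}: expected {expected}, found {seen[op]}")
    # Ops listed in `must_absent` may not appear at all.
    for op in must_absent:
        if seen[op]:
            problems.append(f"{op}: must be absent, found {seen[op]}")
    return problems


ops = ["Transpose", "Conv", "Relu", "Gemm"]
print(check_counts(ops, counts={"Conv": 1, "Gemm": 1}, must_absent=["Not"]))  # []
```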

Common fields:

The matcher automatically walks through helper nodes that frequently sit on the main data edge (by default we skip Reshape, Identity, Cast, CastLike, Squeeze, Unsqueeze, Flatten, Shape, Gather, Concat, Add, and Where). This lets a single pattern cover sequential graphs where tensors fan out into shape-building side chains, such as the CNN dynamic example where the Transpose output feeds both Reshape and the shape-construction subgraph.
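The pass-through behaviour described above can be sketched on a simplified, linearized view of a graph. The code below is an illustration, not the matcher's implementation: `chain_matches` and `SKIPPABLE` are hypothetical names, and a real graph is a DAG with fan-out rather than a flat list of op types.

```python
# Helper ops listed in the paragraph above as skipped by default.
SKIPPABLE = {
    "Reshape", "Identity", "Cast", "CastLike", "Squeeze", "Unsqueeze",
    "Flatten", "Shape", "Gather", "Concat", "Add", "Where",
}


def chain_matches(op_sequence, expected, skippable=SKIPPABLE):
    """True if `expected` occurs in order, allowing skippable ops in between."""
    for start in range(len(op_sequence)):
        i, j = start, 0
        while i < len(op_sequence) and j < len(expected):
            op = op_sequence[i]
            if op == expected[j]:
                j += 1  # matched the next expected op
            elif op in skippable and j > 0:
                pass  # walk through a helper node sitting on the data edge
            else:
                break  # an unexpected op interrupts the chain
            i += 1
        if j == len(expected):
            return True
    return False


ops = ["Transpose", "Reshape", "Conv", "Cast", "Relu"]
print(chain_matches(ops, ["Transpose", "Conv", "Relu"]))  # True
print(chain_matches(ops, ["Transpose", "Relu"]))          # False (Conv is not skippable)
```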

Function naming compatibility

Function exports now keep the original callable name as the node op_type (TransformerBlock, MLPBlock, …) and move the numeric suffix into node.name/domain (TransformerBlock_2, custom.TransformerBlock_2, …). To keep older expectations valid, expect_graph automatically strips trailing _123 suffixes when comparing op_type and normalises graph filters such as fn:custom.TransformerBlock_2. Prefer matching on the base op_type unless a specific call-site needs to be distinguished by name.
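The suffix normalisation described above amounts to dropping a trailing underscore-plus-digits group before comparing names. A minimal sketch, assuming a simple regular expression is sufficient; `base_op_type` is a hypothetical helper and not the function expect_graph uses:

```python
import re

# Matches a trailing numeric suffix such as "_2" or "_123".
_NUMERIC_SUFFIX = re.compile(r"_\d+$")


def base_op_type(name: str) -> str:
    """Strip a trailing numeric suffix, leaving other names unchanged."""
    return _NUMERIC_SUFFIX.sub("", name)


print(base_op_type("TransformerBlock_2"))         # TransformerBlock
print(base_op_type("custom.TransformerBlock_2"))  # custom.TransformerBlock
print(base_op_type("MLPBlock"))                   # MLPBlock
```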

Practical tips

Where to use it

post_check_onnx_graph entries appear inside example/plugin test metadata (see jax2onnx/plugins/examples/nnx/cnn.py for a reference). The helper works with any object that produces an ONNX IR graph compatible with onnx_ir.GraphProto. The same API is shared by policy tests under tests/extra_tests.

When adding new metadata entries, seed them with a minimal structural check, run the example once to capture the intended op sequence, and then layer on shape assertions and counts to guard against regressions.