expect_graph checklist for plugins

expect_graph (from jax2onnx.plugins._post_check_onnx_graph) is the lightweight structural assertion helper used by plugin tests and examples. It lets a test express the operators, ordering, and shapes that should appear in a converted IR/ONNX graph without dumping the full model. This document captures the conventions we rely on when writing or reviewing post_check_onnx_graph expectations.
Test metadata reminder: when wiring new examples/tests, construct callables via
construct_and_call(...).with_requested_dtype(...).with_rng_seed(...)
so the harness can rebuild deterministic f32/f64 variants. See the builder guide for the full randomness and dtype rules.
from jax2onnx.plugins._post_check_onnx_graph import expect_graph as EG
Alias it to EG inside tests to keep call sites short.
Builder reminder: structural tests assume plugins emitted nodes via ctx.builder. Review the ONNX IR Builder Guide if _outputs naming or initializer wiring looks suspicious; policy tests now enforce those contracts.
Pass a list of patterns to expect_graph. Each pattern is either a string or a (string, options) tuple. Nodes are written in evaluation order, separated by ->.
EG([
    "Transpose -> Conv -> Relu -> AveragePool",
])
The pattern above requires the graph to contain that operator chain, in that order. If the chain cannot be found, the check fails with an assertion that includes a summarized diff of the graph.
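To make the chain semantics concrete, here is a minimal, hypothetical sketch (not jax2onnx code) that splits a pattern on -> and looks for the resulting op types as a consecutive run in a purely linear list of op_type strings. The names parse_chain and contains_chain are illustrative only, and the sketch deliberately ignores shape suffixes, options, and the helper-node skipping that the real matcher performs.

```python
from __future__ import annotations

# Hypothetical sketch: matching an "A -> B -> C" chain against a linear graph.
# Not the jax2onnx implementation; shape suffixes and helper skipping are ignored.


def parse_chain(pattern: str) -> list[str]:
    """Split 'Transpose -> Conv -> Relu' into ['Transpose', 'Conv', 'Relu']."""
    return [step.strip() for step in pattern.split("->")]


def contains_chain(op_types: list[str], pattern: str) -> bool:
    """True if the pattern's op types appear as a consecutive run in op_types."""
    chain = parse_chain(pattern)
    width = len(chain)
    return any(
        op_types[start:start + width] == chain
        for start in range(len(op_types) - width + 1)
    )


graph_ops = ["Transpose", "Conv", "Relu", "AveragePool", "Gemm"]
print(contains_chain(graph_ops, "Transpose -> Conv -> Relu -> AveragePool"))  # True
print(contains_chain(graph_ops, "Conv -> AveragePool"))  # False: Relu sits in between
```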
Append :shape to a node name to assert the output shape of that node. Use x as the dimension separator (e.g. Bx32x28x28). Leave dimensions symbolic by reusing the string symbol that the test harness passed as an input shape (for example "B").
EG([
    "Gemm:Bx256 -> Relu:Bx256 -> Gemm:Bx10",
])
Write concrete integers for known static sizes (3x1x28x28). Symbols and integers can be mixed within one shape, as in Bx256, but B?x256 is not supported; if several pattern strings need to agree on the same dimension, declare it with symbols={"B": None} instead.
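The shape rules can be modelled with a small standalone sketch. parse_step and shape_matches are hypothetical names, not part of jax2onnx. The sketch assumes that a symbol mapped to None binds to whatever dimension it first meets and must stay consistent afterwards, while a symbol mapped to an integer must match that size exactly.

```python
from __future__ import annotations


def parse_step(step: str) -> tuple[str, list[int | str] | None]:
    """'Gemm:Bx256' -> ('Gemm', ['B', 256]); 'Relu' -> ('Relu', None)."""
    op, _, shape = step.partition(":")
    if not shape:
        return op.strip(), None
    dims = [int(d) if d.isdigit() else d for d in shape.strip().split("x")]
    return op.strip(), dims


def shape_matches(expected: list[int | str], actual: list[int | str],
                  bindings: dict[str, int | str | None]) -> bool:
    """Compare one expected shape with an actual shape, updating symbol bindings.

    Integers must match exactly. A symbol bound to None (or not yet seen) binds
    to the first actual dim it meets; afterwards it must keep that value.
    """
    if len(expected) != len(actual):
        return False
    trial = dict(bindings)  # only commit new bindings if the whole shape matches
    for exp, act in zip(expected, actual):
        if isinstance(exp, int):
            if exp != act:
                return False
        elif trial.get(exp) is None:
            trial[exp] = act
        elif trial[exp] != act:
            return False
    bindings.update(trial)
    return True


op, dims = parse_step("Gemm:Bx256")
print(op, dims)  # Gemm ['B', 256]
symbols = {"B": None}  # mirrors symbols={"B": None}: any value, shared across patterns
print(shape_matches(dims, ["B", 256], symbols))      # True, and B is now bound to "B"
print(shape_matches(["B", 10], ["N", 10], symbols))  # False: B already stands for "B"
```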
Attach an options dictionary to require counts, forbid nodes, or tweak the search.
EG(
    [
        (
            "Transpose:3x1x28x28 -> Conv:3x32x28x28 -> Relu:3x32x28x28 -> Gemm:3x256",
            {
                "counts": {"Transpose": 1, "Conv": 1, "Relu": 1, "Gemm": 1},
            },
        ),
    ],
    no_unused_inputs=True,
    mode="all",
    must_absent=["Not"],
)
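The counts, must_absent, and no_unused_inputs checks in that call reduce to simple bookkeeping over node op types and tensor names. The sketch below is hypothetical: it uses a toy Node dataclass and a check_options function rather than the real onnx_ir types or expect_graph internals, and it returns failure messages instead of raising.

```python
from __future__ import annotations

from collections import Counter
from dataclasses import dataclass, field


@dataclass
class Node:
    """Toy stand-in for an IR node (hypothetical, not onnx_ir.Node)."""
    op_type: str
    inputs: list[str] = field(default_factory=list)
    outputs: list[str] = field(default_factory=list)


def check_options(nodes, graph_inputs, counts=None, must_absent=(),
                  no_unused_inputs=False):
    """Return failure messages; an empty list means every check passed."""
    failures = []
    seen = Counter(node.op_type for node in nodes)
    for op, expected in (counts or {}).items():
        if seen[op] != expected:
            failures.append(f"count[{op}]: expected {expected}, got {seen[op]}")
    for op in must_absent:
        if seen[op]:
            failures.append(f"{op} must be absent but appears {seen[op]} time(s)")
    if no_unused_inputs:
        # A graph input is dangling when no node lists it among its inputs.
        consumed = {name for node in nodes for name in node.inputs}
        failures.extend(
            f"graph input {name!r} is never used"
            for name in graph_inputs
            if name not in consumed
        )
    return failures


nodes = [
    Node("Transpose", ["x"], ["t"]),
    Node("Conv", ["t", "conv_w"], ["c"]),
    Node("Relu", ["c"], ["r"]),
    Node("Gemm", ["r", "gemm_w"], ["y"]),
]
print(check_options(
    nodes,
    graph_inputs=["x", "mask"],  # "mask" is deliberately left dangling
    counts={"Transpose": 1, "Conv": 1, "Relu": 1, "Gemm": 1},
    must_absent=["Not"],
    no_unused_inputs=True,
))  # ["graph input 'mask' is never used"]
```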
Common fields:

- counts: map of op type to the exact number of occurrences expected.
- must_absent: list of operator names that must not appear anywhere.
- symbols: dictionary mapping symbolic dim labels to None (any value) or an integer (specific size). Use it when multiple patterns should share the same symbolic dimension.
- mode: one of "all" (default; all patterns must match), "any" (at least one matches), or "exact" (the entire graph must equal the pattern).
- no_unused_inputs: when True, fail if the graph retains dangling inputs after conversion. Combine with no_unused_function_inputs=True to extend the check to every imported ONNX function body (requires search_functions=True).
- search_functions: include function bodies (control-flow subgraphs) in the search.

The matcher automatically walks through helper nodes that frequently sit on the main data edge (by default we skip Reshape, Identity, Cast, CastLike, Squeeze, Unsqueeze, Flatten, Shape, Gather, Concat, Add, and Where). This lets a single pattern cover sequential graphs where tensors fan out into shape-building side chains, such as the CNN dynamic example, where the Transpose output feeds both Reshape and the shape-construction subgraph.
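One way to picture the helper-node skipping is a walk along consumer edges that passes through any op in the skip set. This is a hypothetical sketch, not the jax2onnx matcher: it uses a toy Node dataclass in place of onnx_ir nodes, and it assumes a pattern may still name a helper op explicitly, so both Transpose -> Gemm and Transpose -> Reshape -> Gemm match the fan-out graph below.

```python
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass, field

# Default helper ops skipped on the data edge, as listed in this section.
SKIP = {"Reshape", "Identity", "Cast", "CastLike", "Squeeze", "Unsqueeze",
        "Flatten", "Shape", "Gather", "Concat", "Add", "Where"}


@dataclass
class Node:
    """Toy stand-in for an IR node (hypothetical, not onnx_ir.Node)."""
    op_type: str
    inputs: list[str] = field(default_factory=list)
    outputs: list[str] = field(default_factory=list)


def successors(node, consumers):
    """Yield every node reachable from `node` through consumer edges whose
    intermediate nodes are all helper ops in SKIP."""
    stack, seen = [node], {id(node)}
    while stack:
        current = stack.pop()
        for tensor in current.outputs:
            for nxt in consumers.get(tensor, []):
                if id(nxt) in seen:
                    continue
                seen.add(id(nxt))
                yield nxt              # a pattern may name the helper explicitly...
                if nxt.op_type in SKIP:
                    stack.append(nxt)  # ...or walk straight through it


def has_chain(nodes, pattern):
    """True if some path through the graph realises the '->' chain."""
    chain = [step.strip() for step in pattern.split("->")]
    consumers = defaultdict(list)
    for n in nodes:
        for tensor in n.inputs:
            consumers[tensor].append(n)

    def match_from(node, rest):
        if not rest:
            return True
        return any(nxt.op_type == rest[0] and match_from(nxt, rest[1:])
                   for nxt in successors(node, consumers))

    return any(n.op_type == chain[0] and match_from(n, chain[1:]) for n in nodes)


# CNN-style fan-out: Transpose feeds Reshape and a Shape -> Gather -> Concat side chain.
nodes = [
    Node("Transpose", ["x"], ["t"]),
    Node("Shape", ["t"], ["s"]),
    Node("Gather", ["s", "axis0"], ["batch"]),
    Node("Concat", ["batch", "minus_one"], ["new_shape"]),
    Node("Reshape", ["t", "new_shape"], ["flat"]),
    Node("Gemm", ["flat", "w"], ["y"]),
]
print(has_chain(nodes, "Transpose -> Gemm"))             # True: Reshape is skipped
print(has_chain(nodes, "Transpose -> Reshape -> Gemm"))  # True: helper named explicitly
print(has_chain(nodes, "Gemm -> Transpose"))             # False: wrong direction
```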
Function exports now keep the original callable name as the node op_type (TransformerBlock, MLPBlock, …) and move the numeric suffix into node.name/domain (TransformerBlock_2, custom.TransformerBlock_2, …). To keep older expectations valid, expect_graph automatically strips trailing numeric suffixes such as _123 when comparing op_type, and normalises graph filters such as fn:custom.TransformerBlock_2. Prefer matching on the base op_type unless a specific call site needs to be distinguished by name.
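The suffix stripping amounts to removing a trailing underscore-plus-digits run before comparing. The regex and the helper names base_op_type and same_op below are assumptions for illustration, not the expect_graph source; note that a sketch like this would also strip the suffix from any op type that legitimately ends in _<digits>.

```python
import re

# Trailing "_<digits>" suffix, e.g. the "_2" in "TransformerBlock_2".
_NUMERIC_SUFFIX = re.compile(r"_\d+$")


def base_op_type(name: str) -> str:
    """Strip one trailing numeric suffix: 'TransformerBlock_2' -> 'TransformerBlock'."""
    return _NUMERIC_SUFFIX.sub("", name)


def same_op(expected: str, actual: str) -> bool:
    """Compare op types while ignoring call-site suffixes on either side."""
    return base_op_type(expected) == base_op_type(actual)


print(base_op_type("TransformerBlock_2"))                 # TransformerBlock
print(base_op_type("fn:custom.TransformerBlock_2"))       # fn:custom.TransformerBlock
print(same_op("TransformerBlock_2", "TransformerBlock"))  # True
print(same_op("MLPBlock", "TransformerBlock"))            # False
```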
Tips:

- Add shape suffixes to catch missing _stamp_type_and_shape calls and layout errors quickly.
- Use mode="all" with multiple patterns to check disjoint subgraphs, or mode="exact" when the entire graph must be anchored (rare; harder to maintain).
- Add counts to constrain the totals.

post_check_onnx_graph entries appear inside example/plugin test metadata (see jax2onnx/plugins/examples/nnx/cnn.py for a reference). The helper works with any object that produces an ONNX IR graph compatible with onnx_ir.GraphProto. The same API is shared by policy tests under tests/extra_tests.
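The difference between mode="all", "any", and "exact" can be sketched by combining per-pattern results over a linear list of op types. This is a hypothetical illustration (evaluate and _contains are not jax2onnx names); it ignores shapes and helper-node skipping, and for "exact" it assumes every pattern must span the whole graph.

```python
from __future__ import annotations


def _contains(op_types: list[str], chain: list[str]) -> bool:
    """True if `chain` occurs as a consecutive run inside `op_types`."""
    width = len(chain)
    return any(op_types[i:i + width] == chain
               for i in range(len(op_types) - width + 1))


def evaluate(mode: str, op_types: list[str], patterns: list[str]) -> bool:
    """Combine per-pattern results the way the three modes are described."""
    chains = [[step.strip() for step in p.split("->")] for p in patterns]
    if mode == "all":
        return all(_contains(op_types, c) for c in chains)
    if mode == "any":
        return any(_contains(op_types, c) for c in chains)
    if mode == "exact":
        # Assumption: every pattern must cover the whole (linear) graph.
        return all(c == op_types for c in chains)
    raise ValueError(f"unknown mode: {mode!r}")


ops = ["Transpose", "Conv", "Relu", "Gemm"]
print(evaluate("all", ops, ["Transpose -> Conv", "Relu -> Gemm"]))    # True
print(evaluate("all", ops, ["Conv -> Gemm", "Relu -> Gemm"]))         # False
print(evaluate("any", ops, ["Conv -> Gemm", "Relu -> Gemm"]))         # True
print(evaluate("exact", ops, ["Transpose -> Conv -> Relu -> Gemm"]))  # True
print(evaluate("exact", ops, ["Transpose -> Conv -> Relu"]))          # False
```

In this sketch, any extra node anywhere in the graph flips the "exact" result, which illustrates why anchoring the whole graph is harder to maintain than checking disjoint subgraphs with "all".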
When adding new metadata entries, seed them with a minimal structural check, run the example once to capture the intended op sequence, and then layer on shape assertions and counts to guard against regressions.