Complex Numbers in jax2onnx
This guide explains how jax2onnx handles complex tensors while staying within the ONNX specification, and how plugin authors should interact with the shared helper utilities.
Why we need a strategy
ONNX has only limited native complex-tensor support across operators. Most
arithmetic, shape, and control-flow primitives expect real-valued tensors. The
main complex-aware operator we rely on is DFT, which represents complex inputs
and outputs as real tensors whose trailing dimension packs the real and imaginary
channels ([..., 2]).
To stay portable across runtimes, we represent every complex tensor as a real tensor with that trailing size-2 channel. Conversion never emits `Real`, `Imag`, or other custom operators; everything is expressed in terms of standard ONNX ops on real tensors.
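As a rough illustration of the packed layout, the following NumPy sketch converts between a native complex array and a `[..., 2]` real array. The function names `pack_complex` and `unpack_complex` are hypothetical and exist only for this example; they are not the jax2onnx helpers, which operate on ONNX IR values rather than NumPy arrays.

```python
import numpy as np


def pack_complex(z: np.ndarray) -> np.ndarray:
    """Pack a complex array into a real array with a trailing [real, imag] axis."""
    # complex64 maps to float32, complex128 maps to float64.
    real_dtype = np.float64 if z.dtype == np.complex128 else np.float32
    return np.stack([z.real, z.imag], axis=-1).astype(real_dtype)


def unpack_complex(packed: np.ndarray) -> np.ndarray:
    """Inverse of pack_complex: rebuild a native complex array."""
    out_dtype = np.complex128 if packed.dtype == np.float64 else np.complex64
    return (packed[..., 0] + 1j * packed[..., 1]).astype(out_dtype)


z = np.array([[1 + 2j, 3 - 4j]], dtype=np.complex64)
packed = pack_complex(z)
print(packed.shape, packed.dtype)  # (1, 2, 2) float32
```

The complex shape `(1, 2)` gains one trailing axis of size 2, and the element type becomes the matching real type.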
Helper surface (plugins/_complex_utils.py)
| Helper | Purpose |
|---|---|
| `ensure_complex_dtype(ctx, value, target_dtype)` | Cast native complex tensors between `COMPLEX64` and `COMPLEX128` when a lowering needs an exact complex dtype. |
| `pack_native_complex(ctx, tensor)` | Reinterpret a native `complex64`/`complex128` value as a packed real tensor (`[..., 2]`). Handles double-precision upgrades automatically when `enable_double_precision=True`. |
| `is_packed_complex_tensor(value)` | Detect whether a value already uses the packed representation. |
| `ensure_packed_real_pair(ctx, value, *, name_hint)` | Return `(packed_tensor, base_dtype)` for both native complex inputs and already-packed tensors. Raises if the value is neither. |
| `cast_real_tensor(ctx, value, target_dtype, *, name_hint)` | Insert a `Cast` when the packed tensor must move between `FLOAT` and `DOUBLE` representations. |
| `resolve_common_real_dtype(lhs, rhs)` | Pick the shared real dtype (`FLOAT` or `DOUBLE`) for binary complex operations. |
| `split_packed_real_imag(ctx, value, base_dtype, *, prefix)` | Gather the trailing real and imaginary channels from a packed tensor, returning two real tensors. |
| `pack_real_imag_pair(ctx, real, imag, base_dtype, *, name_hint)` | Unsqueeze matching real/imag tensors and concatenate them back into the packed `[..., 2]` representation. |
| `conjugate_packed_tensor(ctx, value, base_dtype, *, prefix)` | Flip the sign of the imaginary channel while preserving shape metadata, producing the complex conjugate of a packed tensor. |
| `coerce_dim_values(dims)` | Normalise shape metadata so `onnx_ir` can stamp symbolic dimensions and integers consistently. |
| `unpack_to_native_complex(...)` | Convert a packed tensor back to a native complex value (used rarely, e.g. when handing results back to JAX in test harnesses). |
These helpers take care of dtype metadata, IRBuilder stamping, and axis bookkeeping so individual plugins only need to express the real-valued arithmetic. New complex-aware plugins should rely on them instead of ad hoc Gather / Reshape sequences so every lowering shares the same representation.
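To make the split / pack / conjugate semantics concrete, here is a minimal NumPy sketch of what those three operations compute. The names `split_real_imag`, `pack_pair`, and `conjugate_packed` are hypothetical stand-ins, not the real helpers: the actual functions take a `ctx` and emit ONNX nodes, and the way conjugation is implemented internally is an assumption here.

```python
import numpy as np


def split_real_imag(packed: np.ndarray):
    # Mirrors gathering index 0 and index 1 along the trailing axis.
    return np.take(packed, 0, axis=-1), np.take(packed, 1, axis=-1)


def pack_pair(real: np.ndarray, imag: np.ndarray) -> np.ndarray:
    # Mirrors an Unsqueeze on a new trailing axis followed by a Concat.
    return np.concatenate(
        [np.expand_dims(real, -1), np.expand_dims(imag, -1)], axis=-1
    )


def conjugate_packed(packed: np.ndarray) -> np.ndarray:
    # Assumed approach: split, negate the imaginary channel, re-pack.
    real, imag = split_real_imag(packed)
    return pack_pair(real, -imag)


# Two packed complex values: 1+2j and 3-4j.
packed = np.array([[1.0, 2.0], [3.0, -4.0]], dtype=np.float32)
conj = conjugate_packed(packed)
print(conj.tolist())  # [[1.0, -2.0], [3.0, 4.0]]
```

Shape and dtype are unchanged by the round trip, which is the property the shared helpers are meant to guarantee for every lowering.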
Supported operations
- Construction and projection (`lax.complex`, `lax.real`, `lax.imag`):
  - `lax.complex` packs matching real and imaginary tensors into the shared trailing-channel representation.
  - `lax.real` and `lax.imag` gather the corresponding channel for packed or native complex inputs; real inputs pass through the real path.
- Elementwise arithmetic (`lax.add`, `lax.sub`, `lax.mul`, `lax.div`):
  - Detection logic looks at the JAX avals and value metadata. When a complex value is involved we normalise operands through `ensure_packed_real_pair`, align their base dtype (`FLOAT` ↔ `DOUBLE`) via `resolve_common_real_dtype` / `cast_real_tensor`, run the real-valued formulas, and use `pack_real_imag_pair` to rebuild the packed output.
  - Outputs inherit the packed representation and expose real metadata (`tensor(float)` / `tensor(double)` with trailing `2`).
- FFT pipeline (`lax.fft`, `jnp.fft` for FFT/IFFT/RFFT):
  - Complex inputs (FFT/IFFT) are packed, reshaped if needed, and lowered to ONNX `DFT` with `inverse` / `onesided` flags. Real inputs (RFFT) receive the trailing channel before invoking `DFT`.
  - `IRFFT` currently requires explicit `fft_lengths`. The implementation reconstructs the missing half of the spectrum, flips the imaginary channel, and runs a forward packed `DFT` before gathering the real component.
  - For `jnp.fft`, metadata-only primitives reuse the same `lax.fft` lowering for canonical 1-D forms (`axis=-1`, optional length). `jnp.fft.irfft` follows the same packed reconstruction path as `lax.fft` and currently needs static length information.
- MatMul / Einsum family (`jax.lax.dot_general`, `jnp.matmul`):
  - Operands are normalised via `ensure_packed_real_pair` and cast to a shared real dtype. The real/imag channels are split with `split_packed_real_imag`, the real-valued contraction (`Einsum` or `MatMul`) runs four times, and `pack_real_imag_pair` stitches the results back together.
  - For `dot_general`, both the batched MatMul fast-path and general `Einsum` lowering share the same helper plumbing so the trailing complex channel is never part of the contraction labels.
  - For `jnp.matmul`, the four-real flow lowers to four ONNX `MatMul` nodes before recombining; broadcasting and vector/matrix promotion match the real path.
- Convolutions (`jax.lax.conv_general_dilated`):
  - Inputs and kernels flow through `ensure_packed_real_pair`, are cast to a shared dtype, and have the complex channel split before any layout transposes.
  - Each of the four real-valued paths runs through the existing Conv lowering (after layout canonicalisation). Outputs are optionally transposed back to the requested layout and re-packed with `pack_real_imag_pair`.
- Conjugation (`jax.lax.conj`, `jnp.conj`):
  - Normalise packed/native complex inputs with `ensure_packed_real_pair`, call `conjugate_packed_tensor` to negate the imaginary channel, and return the packed output. Real inputs bypass through an `Identity`.
- Complex-valued outputs (`lax.linalg.eig`, `jnp.roots`):
  - Plugins that naturally produce complex results should emit the same packed representation and use `_complex_utils` helpers for dtype and shape metadata.
- Tests: regression coverage is generated from plugin metadata and exercised through `tests/primitives/test_lax.py`, `tests/primitives/test_jnp.py`, and focused helper tests under `tests/extra_tests/converter/test_complex_utils.py`.
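The four-real flow described for the MatMul / Einsum family can be sketched in NumPy as follows. This is an illustration of the arithmetic only, under the assumption that both operands are already packed and share a dtype; `packed_matmul` is a hypothetical name, and the real lowering emits ONNX `MatMul` or `Einsum` nodes instead of calling NumPy.

```python
import numpy as np


def packed_matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Complex matmul on packed [..., 2] operands using four real matmuls."""
    # Split the trailing channel so it never takes part in the contraction.
    a_r, a_i = a[..., 0], a[..., 1]
    b_r, b_i = b[..., 0], b[..., 1]
    # (a_r + i a_i)(b_r + i b_i) = (a_r b_r - a_i b_i) + i (a_r b_i + a_i b_r)
    real = a_r @ b_r - a_i @ b_i
    imag = a_r @ b_i + a_i @ b_r
    # Re-pack into the trailing-channel representation.
    return np.stack([real, imag], axis=-1)


rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 2))  # packed complex matrix of shape (2, 3)
b = rng.standard_normal((3, 4, 2))  # packed complex matrix of shape (3, 4)
out = packed_matmul(a, b)

# Cross-check against NumPy's native complex matmul.
expected = (a[..., 0] + 1j * a[..., 1]) @ (b[..., 0] + 1j * b[..., 1])
print(out.shape, np.allclose(out[..., 0], expected.real))
```

The same real/imaginary recombination applies to the convolution path, with the real-valued convolution taking the place of `@`.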
Authoring new plugins with complex inputs
- Detect complex flows early. Inspect JAX avals (`var.aval.dtype`) or existing value metadata. If the operand is complex, call `ensure_packed_real_pair(...)` to normalise it.
- Work in real space. Once packed, treat the tensors as real arrays. Use `resolve_common_real_dtype` and `cast_real_tensor` to reconcile dtypes before running arithmetic.
- Stamp shapes and metadata. Most helpers already stamp values, but if you build new tensors (e.g., concatenations) remember to call `_stamp_type_and_shape` with `coerce_dim_values(...)` so the ONNX graph carries explicit metadata.
- Return packed outputs. Results should remain in `[..., 2]` form. Do not attempt to reintroduce native complex ONNX tensors; runtimes will reject them.
- Tests + docs. Add `expect_graph` snippets alongside the plugin metadata and cover complex variants in the autogenerated test suites.
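The first, second, and fourth steps above (normalise, work in real space, return packed output) can be mimicked in NumPy for an elementwise complex multiply. Everything below is a sketch under stated assumptions: the `_np` functions are hypothetical stand-ins for the `ctx`-based helpers, and the shape-based packed check is a simplification, since the real detection relies on avals and value metadata.

```python
import numpy as np


def ensure_packed_np(value):
    """Return (packed, base_dtype) for native complex or already-packed input."""
    value = np.asarray(value)
    if np.iscomplexobj(value):
        packed = np.stack([value.real, value.imag], axis=-1)
        return packed, packed.dtype
    # Simplification: treat any float array with a trailing size-2 axis as packed.
    if value.ndim >= 1 and value.shape[-1] == 2 and value.dtype in (np.float32, np.float64):
        return value, value.dtype
    raise TypeError("expected a native complex or packed [..., 2] value")


def packed_mul_np(lhs, rhs) -> np.ndarray:
    # Step 1: normalise both operands to the packed representation.
    lhs_p, lhs_dt = ensure_packed_np(lhs)
    rhs_p, rhs_dt = ensure_packed_np(rhs)
    # Step 2: reconcile float32/float64 before doing arithmetic.
    common = np.result_type(lhs_dt, rhs_dt)
    lhs_p, rhs_p = lhs_p.astype(common), rhs_p.astype(common)
    a_r, a_i = lhs_p[..., 0], lhs_p[..., 1]
    b_r, b_i = rhs_p[..., 0], rhs_p[..., 1]
    real = a_r * b_r - a_i * b_i
    imag = a_r * b_i + a_i * b_r
    # Final step: return the result in packed [..., 2] form.
    return np.stack([real, imag], axis=-1)


lhs = np.array([1 + 2j], dtype=np.complex64)       # native complex, single precision
rhs = np.array([[3.0, -1.0]], dtype=np.float64)    # packed 3-1j, double precision
out = packed_mul_np(lhs, rhs)
print(out, out.dtype)  # [[5. 5.]] float64
```

Mixing a single-precision and a double-precision operand promotes the result to the wider real dtype, which is the behaviour the dtype-reconciliation step is meant to produce.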
Current limitations
- The packed representation is the only supported ONNX-side complex convention. Do not introduce native complex ONNX tensors for general arithmetic paths.
- Not every JAX primitive that accepts complex inputs has a complex-aware lowering. Several linear-algebra paths still reject complex inputs explicitly; add support case by case using the same packed-real recipe.
- `IRFFT` needs static one-sided length information. If a new frontend path hides that length, add metadata/tests before claiming coverage.
- Convolution transpose / deconvolution paths should reuse the same four-real structure if complex support is added there.
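The static-length requirement for `IRFFT` follows from how the inverse is rebuilt from a forward transform. The NumPy sketch below mirrors the described steps (reconstruct the missing half, flip the imaginary channel, run a forward DFT, keep the real component) for the 1-D case. `irfft_via_forward_dft` is a hypothetical name, and `np.fft.fft` stands in for the packed ONNX `DFT`; the 1/n scaling is written explicitly here, and where the real lowering applies it is not covered by this sketch.

```python
import numpy as np


def irfft_via_forward_dft(half: np.ndarray, n: int) -> np.ndarray:
    """half: packed one-sided spectrum of shape [n // 2 + 1, 2]; n: output length."""
    re, im = half[..., 0], half[..., 1]
    # Hermitian symmetry: bin n-k is the conjugate of bin k, so the missing
    # bins are the mirrored interior bins with the imaginary sign flipped.
    mirror = slice((n - 1) // 2, 0, -1)
    full_re = np.concatenate([re, re[mirror]])
    full_im = np.concatenate([im, -im[mirror]])
    # Conjugate the full spectrum (flip the imaginary channel), then run a
    # forward DFT: ifft(X) = conj(fft(conj(X))) / n, and only the real part
    # is needed for a real-valued result.
    forward = np.fft.fft(full_re - 1j * full_im)
    return forward.real / n


rng = np.random.default_rng(0)
x = rng.standard_normal(8)
spec = np.fft.rfft(x)
half = np.stack([spec.real, spec.imag], axis=-1)  # packed, shape (5, 2)
print(np.allclose(irfft_via_forward_dft(half, 8), x))  # True
```

A one-sided spectrum with 5 bins is consistent with both `n = 8` and `n = 9`, so the output length cannot be inferred from the input shape alone.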
Potential optimizations
- Gaussian 3-multiply strategy: at the moment we always lower via the straightforward four-real expansion. The `complex_strategy` knob isn't wired yet; it could be exposed as (`four_real` by default, `gauss` as an opt-in). The Gauss variant would replace the four real multiplies `[(a_r b_r), (a_i b_i), (a_r b_i), (a_i b_r)]` with three multiplies plus a few adds:
  - `p1 = a_r * b_r`
  - `p2 = a_i * b_i`
  - `p3 = (a_r + a_i) * (b_r + b_i)`

  then reconstruct `real = p1 - p2` and `imag = p3 - p1 - p2`.
We’ll evaluate this once we can guarantee backend support and have regression coverage for the numerical trade-offs.
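For reference, the two strategies can be compared directly in NumPy. This sketch uses matrix products because that is where saving one multiply matters most; the function names are hypothetical, and nothing here reflects an existing jax2onnx option.

```python
import numpy as np


def mul_four_real(a_r, a_i, b_r, b_i):
    # Straightforward expansion: four real matrix products.
    return a_r @ b_r - a_i @ b_i, a_r @ b_i + a_i @ b_r


def mul_gauss(a_r, a_i, b_r, b_i):
    # Gauss variant: three real matrix products plus extra additions.
    p1 = a_r @ b_r
    p2 = a_i @ b_i
    p3 = (a_r + a_i) @ (b_r + b_i)
    return p1 - p2, p3 - p1 - p2


rng = np.random.default_rng(0)
a_r, a_i = rng.standard_normal((2, 4, 3))
b_r, b_i = rng.standard_normal((2, 3, 5))

real4, imag4 = mul_four_real(a_r, a_i, b_r, b_i)
real3, imag3 = mul_gauss(a_r, a_i, b_r, b_i)
print(np.allclose(real4, real3), np.allclose(imag4, imag3))
```

The two results agree up to floating-point rounding. The imaginary part in the Gauss form is computed as a difference of larger intermediate values, which is the kind of numerical trade-off that regression coverage would need to quantify.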