Past Versions

  • 0.11.2: Replaced custom ir_clone logic with native onnx-ir Graph.clone(); unified the optimization pipeline around standard onnx-ir passes by removing redundant custom pass machinery; and hardened MaxDiffusion exports by fixing environment-dependent UnboundLocalError failures and correcting type annotations.
  • 0.11.1: Added a comprehensive MaxText model family example stack (DeepSeek, Gemma, GPT-3, Kimi, Llama, Mistral, Qwen) with stubbed dependencies and new primitives; stricter subgraph cleanup now produces cleaner, more minimal ONNX graphs.
  • 0.11.0: Added initial Flax Linen support (core layers, activation coverage, attention stack, recurrent stack, and Linen examples for MLP/CNN/Sequential); modernized the IR optimization pipeline by adopting the standard onnx_ir CSE pass, removing legacy helpers and getattr patterns, and simplifying tests with direct graph iteration.
  • 0.10.4: Fixed vmap batching for jax.numpy.reshape/transpose and other jax.numpy primitives; refactored the IR optimizer to use onnx-ir public APIs (value.consumers(), graph.remove()); added Common Subexpression Elimination (CSE) and Constant Lifting passes to ir_optimizations.py; added GitHub Actions CI for automated testing.
  • 0.10.3: Added a Flax/NNX DINOv3 VisionTransformer example stack (plugins/examples/nnx/dinov3.py) with deterministic rng helpers, rotary cache capture, and expect_graph coverage across ViT variants; introduced Equinox→NNX parity testing for DINOv3 (weight copy + forward check) to keep the Equinox and Flax paths aligned; example registry keys are now context-aware (context::component) with override warnings to avoid collisions between example stacks; documented NNX DINO exports (static/dynamic batch) and kept generated ONNX artifacts out of git via a dedicated .gitignore.
  • 0.10.2: Added the GPT-OSS export stack (Equinox + Flax/NNX reference modules, parity harnesses, exporter scripts, docs); new primitive coverage for lax.top_k, lax.rsqrt, and Equinox RMSNorm; masked softmax lowers where-masked calls to Softmax + Where while zeroing masked positions; scatter ops in cond/scan now preserve ONNX-compliant initializers/types via refreshed index helpers (fixing Issue #139); strengthened symbolic-dimension handling via DimExpr/shape-polynomial helpers to stabilize broadcast/loop/gather shapes; IR return-mode/input_param materialization fixed and legacy serde_onnx removed for deterministic IR-only outputs; tightened typing with shared typing_support protocols and helper scripts (check_typing.sh, report_rng_traces.py); dependency stack bumped to JAX 0.8.1 / Flax 0.12.1 with matching NNX plugin updates.
  • 0.10.1: Introduced complex-number support with a unified packed layout ([..., 2]) and broad plugin coverage (elementwise, conjugation, dot_general/matmul/conv, FFT via DFT); added stacktrace metadata toggles (pkg.jax2onnx.callsite / pkg.jax2onnx.plugin); enabled an Einsum fallback for lax.dot_general; tightened lax.broadcast_in_dim determinism (always emit Expand, preserve loop extent metadata); and rebuilt lax.reduce_window_sum using a Conv-based lowering covering strides, dilation, integer operands, and static base dilation.
  • 0.10.0: Expanded Equinox DINO exporter (new equinox/eqx/nn plugins and example), introduced shared lowering helpers (_axis0_utils, _loop_extent_meta, jax/lax/gather_compile, jax/lax/gather_helpers, jax/image/resize, jax/numpy/outer), refreshed control-flow and scatter/gather implementations, added @onnx_function(unique=True), refactored the IR builder (clone_graph, live proxies), and bumped dependencies to JAX 0.8.0 / onnx-ir 0.1.11.
  • 0.9.0 migrates from the ONNX proto builder to onnx_ir, adds a return_mode option (proto / ir / file), and updates dependencies to JAX 0.7.2, Flax 0.12.0, Equinox 0.13.2, onnx-ir 0.1.10, and onnx 1.19.1.
  • 0.8.1 adds N-D lax.dynamic_update_slice with negative-index handling, sharpens grad/VJP paths for jnp.cumsum, introduces lax.add_any and the lax.pow/jnp.power/jnp.pow family with improved lax.scan dtype propagation, and supports eqx.nn.Linear(use_bias=False).
  • 0.8.0 adds initial Equinox support (eqx.dropout, eqx.layer_norm, eqx.linear, plus an MlpExample), stabilizes SSA/shape handling across lax.scan and lax.fori_loop to prevent dtype leaks, improves dtype propagation in lax.gather and lax.concatenate, and adds plugin support for lax.pad.
  • 0.7.5 fixes tests for functions without arguments; adds support for lax.bitwise_not, lax.clamp, lax.ge, jnp.clip, and lax.rev; and enhances support for nnx.dot_product_attention, nnx.conv, nnx.batch_norm, lax.mul, lax.reduce_max, lax.scan, lax.slice, lax.while_loop, nn.gelu, jnp.arange, jnp.cumsum, jnp.select, jnp.where, and jnp.concatenate.
  • 0.7.4 adds support for lax.cumsum and jnp.cumsum, and improves lax.scatter.
  • 0.7.3 improves polymorphism handling for transformers.
  • 0.7.2 adds support for jnp.split, lax.split, and lax.logistic, includes an example for nnx.GRUCell, and improves lax.scatter and lax.while_loop.
  • 0.7.1 fixes a numeric equivalence bug in the test system and adds support for core.custom_jvp_generic, eqx.identity, jnp.select, jnp.stack, jnp.unstack, and lax.select, plus multiple nn.* activations (identity, celu, elu, gelu, relu, leaky_relu, mish, selu, sigmoid, soft_sign, softmax, truncated_normal).
  • 0.7.0 introduces a GPT-2 example based on nanoGPT with ONNX function support and attention masking, adds support for jnp.concatenate, jnp.take, and nnx.Embed, and starts hosting ONNX models on Hugging Face.
  • 0.6.5 improves support for nnx.batch_norm, nnx.group_norm, nnx.layer_norm, nnx.rms_norm, lax.broadcast_in_dim, lax.cond, lax.fori_loop, lax.integer_pow, lax.scan, lax.scatter, lax.scatter_add, lax.scatter_mul, and lax.while_loop; and adds support for lax.and, lax.rem, and lax.remat2.
  • 0.6.4 improves support for lax.scatter_mul.
  • 0.6.3 applies double-precision fixes for lax.fori_loop and lax.while_loop, and fixes bugs in lax.scan and jnp.where.
  • 0.6.2 fixes bugs in nnx.conv and lax.reshape, and adds the jnp.prod primitive.
  • 0.6.1 improves support for lax.cond and lax.select_n, introduces new primitives (lax.reduce_and, lax.reduce_or, lax.reduce_prod, lax.reduce_xor), and adds examples for jnp.select and jnp.sort.
  • 0.6.0 introduces the enable_double_precision parameter (default False) to support physics simulations and enhances lax.scatter handling.
  • 0.5.2 adds support for jnp.where, jnp.arange, and jnp.linspace.
  • 0.5.1 expands subgraph coverage for lax.while_loop, lax.cond, lax.fori_loop, and lax.scan.
  • 0.5.0 improves dynamic batch handling via shape polymorphism and adds jnp.sign, jnp.abs, and lax.iota.
  • 0.4.4 adds lax.cos, lax.cosh, lax.sin, lax.sinh, and lax.scatter support.
  • 0.4.3 fixes ONNX validation for JAX callable outputs and cleans up newly exposed tests.
  • 0.4.2 cleans up and fixes the initial ONNX function release.
  • 0.4.1 introduces ONNX functions via the @onnx_function decorator, creating function instances directly in the call graph.
  • 0.3.2 relaxes the minimum Python version to 3.10.
  • 0.3.0 streamlines plugin registration and custom primitive integration.
  • 0.2.0 (first PyPI release) rebases on jaxpr, improving usability and adding low-level lax components.
  • 0.1.0 (initial approach, not released to PyPI) exports early nnx components and examples including a vision transformer.