Past Versions
- 0.8.1 adds N-D `lax.dynamic_update_slice` with negative-index handling, sharpens grad/VJP paths for `jnp.cumsum`, introduces `lax.add_any` and the `lax.pow`/`jnp.power`/`jnp.pow` family with improved `lax.scan` dtype propagation, and supports `eqx.nn.Linear(use_bias=False)`.
- 0.8.0 adds initial Equinox support (`eqx.dropout`, `eqx.layer_norm`, `eqx.linear`, plus an `MlpExample`), stabilizes SSA/shape handling across `lax.scan` and `lax.fori_loop` to prevent dtype leaks, improves dtype propagation in `lax.gather` and `lax.concatenate`, and adds plugin support for `lax.pad`.
- 0.7.5 fixes tests for functions without arguments, adds support for `lax.bitwise_not`, `lax.clamp`, `lax.ge`, `jnp.clip`, and `lax.rev`, and enhances support for `nnx.dot_product_attention`, `nnx.conv`, `nnx.batch_norm`, `lax.mul`, `lax.reduce_max`, `lax.scan`, `lax.slice`, `lax.while_loop`, `nn.gelu`, `jnp.arange`, `jnp.cumsum`, `jnp.select`, `jnp.where`, and `jnp.concatenate`.
- 0.7.4 adds support for `lax.cumsum` and `jnp.cumsum`, and improves `lax.scatter`.
- 0.7.3 improves polymorphism handling for transformers.
- 0.7.2 adds support for `jnp.split`, `lax.split`, and `lax.logistic`, includes an example for `nnx.GRUCell`, and improves `lax.scatter` and `lax.while_loop`.
- 0.7.1 fixes a numeric equivalence bug in the test system and adds support for `core.custom_jvp_generic`, `eqx.identity`, `jnp.select`, `jnp.stack`, `jnp.unstack`, and `lax.select`, plus multiple `nn.*` activations (`identity`, `celu`, `elu`, `gelu`, `relu`, `leaky_relu`, `mish`, `selu`, `sigmoid`, `soft_sign`, `softmax`, `truncated_normal`).
- 0.7.0 introduces a GPT-2 example based on nanoGPT with ONNX function support and attention masking, adds support for `jnp.concatenate`, `jnp.take`, and `nnx.Embed`, and starts hosting ONNX models on Hugging Face.
- 0.6.5 improves support for `nnx.batch_norm`, `nnx.group_norm`, `nnx.layer_norm`, `nnx.rms_norm`, `lax.broadcast_in_dim`, `lax.cond`, `lax.fori_loop`, `lax.integer_pow`, `lax.scan`, `lax.scatter`, `lax.scatter_add`, `lax.scatter_mul`, and `lax.while_loop`; and adds support for `lax.and`, `lax.rem`, and `lax.remat2`.
- 0.6.4 improves support for `lax.scatter_mul`.
- 0.6.3 applies double-precision fixes for `lax.fori_loop` and `lax.while_loop`, and fixes bugs in `lax.scan` and `jnp.where`.
- 0.6.2 fixes bugs in `nnx.conv` and `lax.reshape`, and adds the `jnp.prod` primitive.
- 0.6.1 improves support for `lax.cond` and `lax.select_n`, introduces new primitives (`lax.reduce_and`, `lax.reduce_or`, `lax.reduce_prod`, `lax.reduce_xor`), and adds examples for `jnp.select` and `jnp.sort`.
- 0.6.0 introduces the `enable_double_precision` parameter (default `False`) to support physics simulations and enhances `lax.scatter` handling (see the double-precision sketch after this list).
- 0.5.2 adds support for `jnp.where`, `jnp.arange`, and `jnp.linspace`.
- 0.5.1 expands subgraph coverage for `lax.while_loop`, `lax.cond`, `lax.fori_loop`, and `lax.scan`.
- 0.5.0 improves dynamic batch handling via shape polymorphism and adds `jnp.sign`, `jnp.abs`, and `jnp.iota`.
- 0.4.4 adds `lax.cos`, `lax.cosh`, `lax.sin`, `lax.sinh`, and `lax.scatter` support.
- 0.4.3 fixes ONNX validation for JAX callable outputs and cleans up newly exposed tests.
- 0.4.2 cleans up and fixes the initial ONNX function release.
- 0.4.1 introduces ONNX functions via the `@onnx_function` decorator, creating function instances directly in the call graph (see the decorator sketch after this list).
- 0.3.2 relaxes the minimum Python version to 3.10.
- 0.3.0 streamlines plugin registration and custom primitive integration.
- 0.2.0 (first PyPI release) rebases on `jaxpr`, improving usability and adding low-level `lax` components.
- 0.1.0 (initial approach, not released to PyPI) exports early `nnx` components and examples including a vision transformer.
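
For context on the 0.4.1 entry, here is a minimal sketch of how the `@onnx_function` decorator can be applied. Only the decorator name comes from the changelog; the `to_onnx` entry point, its import path, and the "list of input shapes" argument layout are assumptions and may differ from the library's actual API.

```python
# Minimal sketch of the @onnx_function decorator mentioned for 0.4.1.
# Assumptions: the to_onnx entry point, its import path, and the
# "list of input shapes" argument layout are illustrative only.
import jax.numpy as jnp
from jax2onnx import onnx_function, to_onnx  # assumed import paths


@onnx_function  # the decorated callable becomes an ONNX function instance in the call graph
def mlp_block(x):
    # Small block body; reusing it twice below is intended to produce two
    # calls to the same ONNX function rather than inlined copies.
    return jnp.tanh(x @ jnp.ones((4, 4)) + 1.0)


def model(x):
    return mlp_block(mlp_block(x))


# Assumed conversion call: export the model for a fixed (3, 4) input.
onnx_model = to_onnx(model, [(3, 4)])
```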
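
Similarly, for the 0.6.0 entry, a hedged sketch of the `enable_double_precision` flag. The parameter name and its `False` default come from the changelog; passing it to `to_onnx` is an assumption about where the flag lives.

```python
# Sketch of enable_double_precision (introduced in 0.6.0, default False).
# Assumption: the flag is passed to the to_onnx conversion entry point.
import jax
import jax.numpy as jnp
from jax2onnx import to_onnx  # assumed import path

# Let JAX itself trace in float64 so there is double precision to preserve.
jax.config.update("jax_enable_x64", True)


def step(x):
    # A physics-style update where small increments would vanish in float32.
    return x + 1e-12 * jnp.sin(x)


# Request float64 tensors in the exported graph instead of the float32 default.
onnx_model = to_onnx(step, [(100,)], enable_double_precision=True)
```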