Releases: jax-ml/jax

JAX release v0.2.27

18 Jan 19:31
  • GitHub commits.

  • Breaking changes:

    • Support for NumPy 1.18 has been dropped, per the [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). Please upgrade to a supported NumPy version.
    • The host_callback primitives have been simplified to drop the special autodiff handling for hcb.id_tap and id_print. From now on, only the primals are tapped. The old behavior can be obtained (for a limited time) by setting the JAX_HOST_CALLBACK_AD_TRANSFORMS environment variable, or the --jax_host_callback_ad_transforms flag. Documentation has also been added on how to implement the old behavior using JAX custom AD APIs (#8678).
    • Sorting now matches the behavior of NumPy for 0.0 and NaN regardless of the bit representation. In particular, 0.0 and -0.0 are now treated as equivalent, where previously -0.0 was treated as less than 0.0. Additionally, all NaN representations are now treated as equivalent and sorted to the end of the array. Previously, negative NaN values were sorted to the front of the array, and NaN values with different internal bit representations were not treated as equivalent and were sorted according to those bit patterns (#9178).
    • jax.numpy.unique now treats NaN values in the same way as np.unique in NumPy versions 1.21 and newer: at most one NaN value will appear in the uniquified output (#9184).
  • Bug fixes:

    • host_callback now supports ad_checkpoint.checkpoint (#8907).
  • New features:

    • Added jax.block_until_ready (#8941).
    • Added a new debugging flag/environment variable JAX_DUMP_IR_TO=/path. If set, JAX dumps the MHLO/HLO IR it generates for each computation to a file under the given path.
    • Added jax.ensure_compile_time_eval to the public API (#7987).
    • jax2tf now supports a flag jax2tf_associative_scan_reductions to change the lowering for associative reductions, e.g., jnp.cumsum, to behave like JAX on CPU and GPU (to use an associative scan). See the jax2tf README for more details (#9189).
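
The sorting and jax.numpy.unique changes listed under the breaking changes for this release can be checked with a small sketch like the one below. The input values are made up for illustration, and the behavior noted in the comments is what the release notes describe, not a guarantee for every backend or version.

```python
import numpy as np
import jax.numpy as jnp

# Illustrative input: signed zeros, ordinary values, and NaNs with both sign bits.
x = jnp.array([3.0, np.nan, -0.0, 0.0, -np.nan, 1.0])

# Per the notes above, every NaN is placed at the end of the sorted result,
# and 0.0 and -0.0 compare as equivalent.
sorted_x = jnp.sort(x)
nan_mask = np.isnan(np.asarray(sorted_x))

# jnp.unique keeps at most one NaN, matching np.unique in NumPy >= 1.21.
unique_x = jnp.unique(jnp.array([1.0, np.nan, 2.0, np.nan, 1.0]))

print(sorted_x)
print(unique_x)
```

Code that relied on negative NaN values appearing at the front of a sorted array would need to be revisited after this change.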

JAX release v0.2.26

08 Dec 19:20
  • Bug fixes:

    • Out-of-bounds indices to jax.ops.segment_sum will now be handled with FILL_OR_DROP semantics, as documented. This primarily affects the reverse-mode derivative, where gradients corresponding to out-of-bounds indices will now be returned as 0 (#8634).
    • jax2tf will force the converted code to use XLA for the code fragments under jax.jit, e.g., most jax.numpy functions (#7839).
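
As a rough illustration of the segment_sum fix, the sketch below uses made-up data in which one segment id (5) falls outside num_segments=2. The helper name total is hypothetical and exists only so there is a scalar to differentiate.

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0, 4.0])
# The id 5 is out of bounds for num_segments=2, so that entry is dropped.
segment_ids = jnp.array([0, 1, 5, 1])

def total(d):
    # Hypothetical helper: reduce the segment sums to a scalar for jax.grad.
    return jax.ops.segment_sum(d, segment_ids, num_segments=2).sum()

sums = jax.ops.segment_sum(data, segment_ids, num_segments=2)
grads = jax.grad(total)(data)

print(sums)   # segment sums, ignoring the out-of-bounds entry
print(grads)  # gradient is 0 at the out-of-bounds position
```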

Jaxlib release v0.1.75

07 Dec 15:13
  • New features:
    • Support for Python 3.10.

Jaxlib release v0.1.74

16 Nov 20:13
Jaxlib v0.1.74

JAX release v0.2.25

10 Nov 22:26
  • New features:

    • (Experimental) jax.distributed.initialize exposes multi-host GPU backend.
    • jax.random.permutation supports a new independent keyword argument
      (#8430).
  • Breaking changes:

    • Moved jax.experimental.stax to jax.example_libraries.stax
    • Moved jax.experimental.optimizers to jax.example_libraries.optimizers
  • New features:

    • Added jax.lax.linalg.qdwh.
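
A minimal sketch of two items from this release: the relocated example libraries and the independent keyword of jax.random.permutation. The array and key are illustrative; the comments describe the expected behavior rather than specific output values.

```python
import jax
import jax.numpy as jnp

# The example libraries now live under jax.example_libraries.
from jax.example_libraries import optimizers, stax

key = jax.random.PRNGKey(0)
x = jnp.arange(12).reshape(3, 4)

# Default (independent=False): one permutation of the columns, shared by all rows.
shared = jax.random.permutation(key, x, axis=1)

# independent=True: each row is shuffled along axis 1 on its own.
per_row = jax.random.permutation(key, x, axis=1, independent=True)

print(shared)
print(per_row)
```

In both cases each row still contains the same values as before; only their order along the chosen axis changes.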

JAX release v0.2.24

19 Oct 15:06
  • New features:
    • jax.random.choice and jax.random.permutation now support
      multidimensional arrays and an optional axis argument (#8158)
  • Breaking changes:
    • jax.numpy.take and jax.numpy.take_along_axis now require array-like inputs
      (see #7737)
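
The two changes in this release can be tried with a short sketch such as the following. The inputs are illustrative, and the list_rejected flag is a hypothetical name used only to record whether the list input raised an error.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jnp.arange(12).reshape(3, 4)

# Draw two columns (with replacement) from a 2-D array along axis 1.
cols = jax.random.choice(key, x, shape=(2,), axis=1)

# Shuffle the rows of a 2-D array.
rows = jax.random.permutation(key, x, axis=0)

# jnp.take now requires array-like inputs: a plain Python list is rejected.
try:
    jnp.take([10, 20, 30], jnp.array([0, 2]))
    list_rejected = False
except TypeError:
    list_rejected = True

# Converting the list to an array first is the straightforward fix.
taken = jnp.take(jnp.array([10, 20, 30]), jnp.array([0, 2]))
```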

Jaxlib release v0.1.73

18 Oct 22:40
Update the workspace file

PiperOrigin-RevId: 404076864

Jaxlib release v0.1.72

12 Oct 19:53
Merge pull request #8181 from skye:workspace

PiperOrigin-RevId: 402632543

JAX release v0.2.21

23 Sep 18:09
  • New features:

    • Added jax.numpy.insert implementation (#7936).
  • Breaking changes:

    • jax.api has been removed. Functions that were available as jax.api.*
      were aliases for functions in jax.*; please use the functions in
      jax.* instead.
    • jax.partial, jax.lax.partial, and jax.util.partial were accidental
      exports that have now been removed. Use functools.partial from the Python
      standard library instead.
    • Boolean scalar indices now raise a TypeError; previously this silently
      returned wrong results (#7925).
    • Many more jax.numpy functions now require array-like inputs, and will error
      if passed a list (#7747, #7802, #7907).
      See #7737 for a discussion of the rationale behind this change.
    • When inside a transformation such as jax.jit, jax.numpy.array always
      stages the array it produces into the traced computation. Previously
      jax.numpy.array would sometimes produce an on-device array, even under
      a jax.jit decorator. This change may break code that used JAX arrays to
      perform shape or index computations that must be known statically; the
      workaround is to perform such computations using classic NumPy arrays
      instead.
    • jnp.ndarray is now a true base-class for JAX arrays. In particular, this
      means that for a standard numpy array x, isinstance(x, jnp.ndarray) will
      now return False (#7927).
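
Several of the items above can be sketched together: jax.numpy.insert, functools.partial in place of the removed jax.partial, static shape computations done with classic NumPy under jax.jit, and the new isinstance behavior. The function to_columns and its arguments are hypothetical names chosen for this illustration.

```python
from functools import partial  # replaces the removed jax.partial

import numpy as np
import jax
import jax.numpy as jnp

# jax.numpy.insert follows the NumPy semantics: insert 99 before index 1.
inserted = jnp.insert(jnp.array([1, 2, 3]), 1, 99)

@partial(jax.jit, static_argnums=1)
def to_columns(x, n_cols):
    # x.shape is static under jit. Computing with classic NumPy keeps the
    # result a concrete value, so it can be used as a shape.
    n_rows = int(np.prod(np.array(x.shape))) // n_cols
    return x.reshape((n_rows, n_cols))

reshaped = to_columns(jnp.arange(12), 4)

# jnp.ndarray is a base class for JAX arrays only, not for NumPy arrays.
jax_is_ndarray = isinstance(jnp.zeros(3), jnp.ndarray)
numpy_is_ndarray = isinstance(np.zeros(3), jnp.ndarray)
```

Had the row count been computed with jax.numpy inside the jitted function, it would be staged into the traced computation and could no longer serve as a static shape.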

JAX release v0.2.20

03 Sep 16:19
Merge pull request #7793 from yashk2810:update_pypi

PiperOrigin-RevId: 394697075