
Releases: jax-ml/jax

Jaxlib release v0.3.14

21 Jun 21:51
jaxlib-v0.3.14

JAX release v0.3.14

21 Jun 21:51
jax-v0.3.14

JAX release v0.3.13

16 May 19:13
jax-v0.3.13

jax version 0.3.13

JAX release v0.3.12

16 May 00:52

JAX release v0.3.11

15 May 18:39
  • Changes
    • {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument
      that allows users to opt out of eigenvalue sorting on TPU.
  • Deprecations
    • Non-array arguments to functions in {mod}`jax.lax.linalg` are now marked
      keyword-only. As a backward-compatibility step, passing keyword-only
      arguments positionally yields a warning, but in a future JAX release passing
      keyword-only arguments positionally will fail.
      Most users, however, should use {mod}`jax.numpy.linalg` rather than {mod}`jax.lax.linalg`.
    • {func}`jax.scipy.linalg.polar_unitary`, which was a JAX extension to the
      SciPy API, is deprecated. Use {func}`jax.scipy.linalg.polar` instead.
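A minimal sketch of the 0.3.11 changes, assuming a recent JAX install: the non-array arguments of `eigh` are passed by keyword, `sort_eigenvalues=False` opts out of sorting, and `polar` stands in for the deprecated `polar_unitary`. The sketch calls `eigh` through the `jax.lax.linalg` module; the matrix values are arbitrary illustrations.

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg
from jax.scipy.linalg import polar

# Arbitrary real symmetric (and positive definite) input for illustration.
x = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])

# Non-array arguments are passed by keyword, as the deprecation requires.
# sort_eigenvalues=False opts out of sorting; the option targets TPU, so other
# backends may still return the eigenvalues in ascending order.
v, w = lax_linalg.eigh(x, lower=True, sort_eigenvalues=False)

# Each column of v is an eigenvector: x @ v == v * w (w broadcasts over columns).
residual = jnp.max(jnp.abs(x @ v - v * w))

# polar() replaces the deprecated polar_unitary(); it returns (unitary, positive part).
u, p = polar(x)
```

Because the eigenpairs are checked through `x @ v == v * w` rather than by position, the check holds whether or not the eigenvalues come back sorted.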

Jaxlib release v0.3.10

04 May 21:52
Update TF commit for release

PiperOrigin-RevId: 446555288

JAX release v0.3.10

04 May 21:52
Update TF commit for release

PiperOrigin-RevId: 446555288

JAX release v0.3.9

03 May 02:41
  • Changes
    • Added support for fully asynchronous checkpointing for `GlobalDeviceArray`.

JAX release v0.3.8

30 Apr 03:09
  • GitHub commits.
  • Changes
    • {func}`jax.numpy.linalg.svd` on TPUs uses a QDWH-SVD solver.
    • {func}`jax.numpy.linalg.cond` on TPUs now accepts complex input.
    • {func}`jax.numpy.linalg.pinv` on TPUs now accepts complex input.
    • {func}`jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input.
    • {func}`jax.scipy.cluster.vq.vq` has been added.
    • `jax.experimental.maps.mesh` has been deleted; use `jax.experimental.maps.Mesh` instead.
      See https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
      for more information.
    • {func}`jax.scipy.linalg.qr` now returns a length-1 tuple rather than the raw array when `mode='r'`, to match the behavior of `scipy.linalg.qr` ({jax-issue}`#10452`).
    • {func}`jax.numpy.take_along_axis` now takes an optional `mode` parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) are returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing `mode="clip"`.
    • {func}`jax.numpy.take` now defaults to `mode="fill"`, which returns invalid values (e.g., NaN) for out-of-bounds indices.
    • Scatter operations, such as `x.at[...].set(...)`, now have "drop" semantics. This has no effect on the scatter operation itself, but it means that, when differentiated, the gradient of a scatter yields zero cotangents for out-of-bounds indices. Previously, out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
    • {func}`jax.numpy.take_along_axis` now raises a `TypeError` if its indices are not of an integer type, matching the behavior of
      {func}`numpy.take_along_axis`. Previously, non-integer indices were silently cast to integers.
    • {func}`jax.numpy.ravel_multi_index` now raises a `TypeError` if its `dims` argument is not of an integer type, matching the behavior of {func}`numpy.ravel_multi_index`. Previously, non-integer `dims` were silently cast to integers.
    • {func}`jax.numpy.split` now raises a `TypeError` if its `axis` argument is not of an integer type, matching the behavior of {func}`numpy.split`. Previously, a non-integer `axis` was silently cast to an integer.
    • {func}`jax.numpy.indices` now raises a `TypeError` if its dimensions are not of an integer type, matching the behavior of {func}`numpy.indices`. Previously, non-integer dimensions were silently cast to integers.
    • {func}`jax.numpy.diag` now raises a `TypeError` if its `k` argument is not of an integer type, matching the behavior of {func}`numpy.diag`. Previously, a non-integer `k` was silently cast to an integer.
    • Added {func}`jax.random.orthogonal`.
  • Deprecations
    • Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`, `format_shape_dtype_string`, `rand_uniform`, `skip_on_devices`, `with_config`, `xla_bridge`, and `_default_tolerance` ({jax-issue}`#10389`). These, along with the previously deprecated `JaxTestCase`, `JaxTestLoader`, and `BufferDonationTestCase`, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard Python and NumPy testing utilities found in, e.g., {mod}`unittest`, {mod}`absl.testing`, and {mod}`numpy.testing`. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as {func}`jax.devices`. Many of the deprecated utilities will still exist in {mod}`jax._src.test_util`, but these are not public APIs and may be changed or removed without notice in future releases.
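A short sketch of the new out-of-bounds behavior of `jnp.take` and `jnp.take_along_axis`, assuming a JAX release where `mode="fill"` is the default for both. The arrays are illustrative; index `5` and column `7` are deliberately out of bounds.

```python
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 2, 5])  # 5 is out of bounds for a length-3 array

filled = jnp.take(x, idx)                # default "fill": NaN where the index is out of bounds
clipped = jnp.take(x, idx, mode="clip")  # pre-0.3.8 behavior: 5 is clamped to the last index

a = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
cols = jnp.array([[1], [7]])  # column 7 does not exist

along_fill = jnp.take_along_axis(a, cols, axis=1)               # NaN for the bad column
along_clip = jnp.take_along_axis(a, cols, axis=1, mode="clip")  # bad column clamped into range
```

Passing `mode="clip"` is the opt-in route back to the old clamping behavior the notes describe.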
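The gradient consequence of the scatter change can be checked directly. This sketch passes `mode="drop"` explicitly rather than relying on the default, and the function `f` is a hypothetical example: one update lands in bounds, the other at index 5, which is out of bounds for a length-3 array.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # Hypothetical example: write y[0] to x[1] and y[1] to x[5] (out of bounds).
    # With "drop" semantics the second update is ignored.
    return x.at[jnp.array([1, 5])].set(y, mode="drop").sum()

x = jnp.zeros(3)
y = jnp.array([2.0, 3.0])

value = f(x, y)
gx, gy = jax.grad(f, argnums=(0, 1))(x, y)
# gy: the in-bounds update gets cotangent 1, the dropped update gets 0.
# gx: x[1] was overwritten, so its gradient is 0; the other entries pass through.
```

A clamped gradient would instead credit `y[1]` as if it had been written to the last element, which is the mathematically incorrect behavior the release note removes.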
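The stricter integer checks can be exercised with a small helper. `raises_type_error` is a hypothetical helper written for this sketch, and only two of the affected functions are shown; each bad call passes a float where the notes say an integer is now required.

```python
import jax.numpy as jnp

def raises_type_error(fn):
    # Hypothetical helper: True if calling fn() raises TypeError.
    try:
        fn()
    except TypeError:
        return True
    return False

a = jnp.arange(4.0).reshape(2, 2)

# Float indices are rejected instead of being silently cast.
bad_take = raises_type_error(
    lambda: jnp.take_along_axis(a, jnp.array([[0.0], [1.0]]), axis=1))
# A float diagonal offset k is rejected as well.
bad_diag = raises_type_error(lambda: jnp.diag(jnp.arange(3.0), k=1.0))

# Integer arguments keep working as before.
ok_take = jnp.take_along_axis(a, jnp.array([[0], [1]]), axis=1)
ok_diag = jnp.diag(jnp.arange(3.0), k=1)
```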
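A sketch of three of the additions and API adjustments above, assuming current import paths: `jax.scipy.cluster.vq.vq` assigns observations to their nearest code, `jax.scipy.linalg.qr` with `mode='r'` returns a length-1 tuple, and `jax.random.orthogonal` samples an orthogonal matrix. The data values and PRNG seed are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax.scipy.cluster.vq import vq
from jax.scipy.linalg import qr

# vq: index of the nearest code-book entry for each observation, plus its distance.
obs = jnp.array([[0.0, 0.0], [10.0, 10.0], [1.0, 0.0]])
code_book = jnp.array([[0.0, 0.0], [10.0, 10.0]])
codes, dists = vq(obs, code_book)

# qr with mode='r' returns a 1-tuple, matching scipy.linalg.qr, so unpack it.
m = jnp.array([[3.0, 1.0], [4.0, 2.0]])
(r,) = qr(m, mode="r")

# random.orthogonal: an n x n orthogonal matrix, so q @ q.T is the identity.
key = jax.random.PRNGKey(0)
q = jax.random.orthogonal(key, 3)
```

Since `m = QR` with orthogonal `Q`, the factor `r` satisfies `r.T @ r == m.T @ m`, which is a convenient sign-independent way to sanity-check it.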
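One possible migration away from the deprecated `jax.test_util` helpers, using only public APIs: `unittest.TestCase` in place of `JaxTestCase`, `numpy.testing.assert_allclose` in place of a closeness check, and `jax.devices()` / `jax.default_backend()` for device queries. `SoftplusTest` is a hypothetical test case, and the tolerance is an illustrative assumption, not the value the removed `_default_tolerance` used.

```python
import unittest

import jax
import jax.numpy as jnp
import numpy as np

class SoftplusTest(unittest.TestCase):
    # Hypothetical test case built on unittest instead of jax.test_util.JaxTestCase.

    def test_matches_numpy(self):
        x = np.linspace(-3.0, 3.0, 7, dtype=np.float32)
        expected = np.log1p(np.exp(x))
        actual = jax.nn.softplus(jnp.asarray(x))
        # numpy.testing replaces check_close; the rtol here is an arbitrary choice.
        np.testing.assert_allclose(np.asarray(actual), expected, rtol=1e-5)

    # unittest.skipIf plays the role of skip_on_devices; the TPU condition is illustrative.
    @unittest.skipIf(jax.default_backend() == "tpu", "illustrative TPU skip")
    def test_reports_platform(self):
        # jax.devices() exposes the device platform, covering typical device_under_test uses.
        self.assertIsInstance(jax.devices()[0].platform, str)
```

Such a file can be run with `python -m unittest` or under an `absl.testing` runner; nothing in it depends on private JAX modules.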

Jaxlib v0.3.7

29 Apr 18:18
  • Linux wheels are now built to conform to the manylinux2014 standard instead of manylinux2010.