
Releases: jax-ml/jax

jaxlib release v0.4.2

26 Jan 01:06
jaxlib-v0.4.2

jaxlib version 0.4.2

JAX release v0.4.2

26 Jan 01:07
jax-v0.4.2

jax version 0.4.2

jaxlib release v0.4.1

13 Dec 18:24
  • Changes
    • Support for Python 3.7 has been dropped, in accordance with JAX's
      version support policy.
    • Setting XLA_PYTHON_CLIENT_MEM_FRACTION=.XX now preallocates XX% of the
      total GPU memory. Previously, the preallocation was calculated from the
      currently available GPU memory. See the GPU memory allocation
      documentation for more details.
    • The deprecated method .block_host_until_ready() has been removed. Use
      .block_until_ready() instead.
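A minimal sketch of adapting a script to both changes, assuming a single-process program. The memory fraction only takes effect if it is in the environment before JAX initializes its backends, and setting it before importing `jax` is the simplest way to guarantee that. The value `.50` is an illustrative choice; on a CPU-only machine the variable is ignored.

```python
import os

# Must be set before JAX initializes its GPU backend; doing it before the
# import is the simplest way to ensure that. Since jaxlib 0.4.1, ".50" means
# 50% of the *total* GPU memory, not 50% of the currently free memory.
# On a CPU-only machine this variable has no effect.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax.numpy as jnp  # noqa: E402  (imported after setting the variable)

x = jnp.arange(10, dtype=jnp.float32)
y = (x * 2.0).sum()  # dispatched asynchronously

# Removed in jaxlib 0.4.1: y.block_host_until_ready()
# Replacement: wait until the result has actually been computed.
y = y.block_until_ready()
print(float(y))  # prints 90.0
```

`.block_until_ready()` is mainly useful for benchmarking, because JAX dispatches work asynchronously and a timer could otherwise stop before the computation finishes.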

JAX release v0.4.1

13 Dec 18:23
  • Changes
    • Support for Python 3.7 has been dropped, in accordance with JAX's
      version support policy.
    • We introduce jax.Array, a unified array type that subsumes the
      DeviceArray, ShardedDeviceArray, and GlobalDeviceArray types in JAX.
      The jax.Array type helps make parallelism a core feature of JAX,
      simplifies and unifies JAX internals, and allows us to unify jit and
      pjit. jax.Array is enabled by default in JAX 0.4 and introduces some
      breaking changes to the pjit API. The jax.Array migration guide can
      help you migrate your codebase to jax.Array. You can also look at the
      Distributed arrays and automatic parallelization tutorial to understand
      the new concepts.
    • PartitionSpec and Mesh are no longer experimental. The new API endpoints
      are jax.sharding.PartitionSpec and jax.sharding.Mesh.
      jax.experimental.maps.Mesh and jax.experimental.PartitionSpec are
      deprecated and will be removed in 3 months.
    • The new public endpoint for with_sharding_constraint is
      jax.lax.with_sharding_constraint.
    • If using ABSL flags together with jax.config, the ABSL flag values are no
      longer read or written after the JAX configuration options are initially
      populated from the ABSL flags. This change improves performance of reading
      jax.config options, which are used pervasively in JAX.
    • For TF lowering, the jax2tf.call_tf function now uses the first TF
      device of the same platform as the embedding JAX computation.
      Previously, it used the 0th device of the JAX default backend.
    • A number of jax.numpy functions now have their arguments marked as
      positional-only, matching NumPy.
    • jnp.msort is now deprecated, following the deprecation of np.msort in NumPy 1.24.
      It will be removed in a future release, in accordance with the API compatibility
      policy. It can be replaced with jnp.sort(a, axis=0).
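The following is a minimal sketch of the jax.Array and jax.sharding changes, assuming a recent JAX release. `NamedSharding` and `jax.device_put` with a sharding argument are not mentioned in these notes and are used here as assumptions, and the one-device mesh and axis name `"data"` are arbitrary illustrative choices that keep the example runnable on any host.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec  # new, non-experimental home

# Arrays produced by JAX are jax.Array instances; there is no separate
# DeviceArray / ShardedDeviceArray / GlobalDeviceArray type to check for.
x = jnp.arange(8.0)
assert isinstance(x, jax.Array)

# A 1-D mesh over a single device so the sketch runs anywhere.
# "data" is an arbitrary axis name chosen for illustration.
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("data",))
spec = NamedSharding(mesh, PartitionSpec("data"))

# Place the array according to the sharding; the result is still a jax.Array.
x_sharded = jax.device_put(x, spec)


@jax.jit
def double(a):
    # The sharding constraint via its public endpoint, jax.lax.
    a = jax.lax.with_sharding_constraint(a, spec)
    return a * 2


result = double(x_sharded)
print(result.sharding)
print(np.asarray(result))
```

With more devices, passing all of `jax.devices()` to `Mesh` would split the array across them, provided the array length is divisible by the device count.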
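The suggested jnp.msort replacement can be sketched as follows. The input array is made up for illustration, and the check against NumPy confirms that `jnp.sort(a, axis=0)` sorts each column independently along the first axis, which is what msort did.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.array([[3, 1],
               [1, 2],
               [2, 0]])

# Deprecated: jnp.msort(a), which sorted along the first axis.
# Replacement given in the release note: sort explicitly along axis 0.
sorted_a = jnp.sort(a, axis=0)

# Matches NumPy's np.sort(..., axis=0): each column is sorted independently.
assert np.array_equal(np.asarray(sorted_a), np.sort(np.asarray(a), axis=0))
print(np.asarray(sorted_a))
```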

jaxlib release v0.3.25

15 Nov 15:32
jaxlib-v0.3.25

jaxlib version 0.3.25

JAX release v0.3.25

15 Nov 15:32
jax-v0.3.25

jax version 0.3.25

jaxlib release v0.3.24

04 Nov 14:51
jaxlib-v0.3.24

jaxlib version 0.3.24

JAX release v0.3.24

04 Nov 14:50
jax-v0.3.24

jax version 0.3.24

jax-v0.3.23

12 Oct 18:07
  • Changes
    • Update Colab TPU driver version for new jaxlib release.

jaxlib-v0.3.22

11 Oct 21:34
jaxlib version 0.3.22