Releases: jax-ml/jax
jaxlib release v0.4.2
jaxlib-v0.4.2 jaxlib version 0.4.2
JAX release v0.4.2
jax-v0.4.2 jax version 0.4.2
Jaxlib release v0.4.1
- Changes
- Support for Python 3.7 has been dropped, in accordance with JAX's version support policy.
- The behavior of `XLA_PYTHON_CLIENT_MEM_FRACTION=.XX` has been changed to allocate XX% of the total GPU memory instead of the previous behavior of using currently available GPU memory to calculate preallocation. Please refer to the GPU memory allocation documentation for more details.
- The deprecated method `.block_host_until_ready()` has been removed. Use `.block_until_ready()` instead.
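The following minimal sketch makes the new `XLA_PYTHON_CLIENT_MEM_FRACTION` semantics concrete. It sets the variable from Python and compares a simplified model of the old and new preallocation sizes. The helper `preallocated_bytes` and the 16 GiB / 10 GiB figures are hypothetical; they only illustrate the arithmetic and are not part of JAX. The variable has to be set before JAX initializes its GPU backend, so setting it before `import jax` is the safe choice.

```python
import os

GIB = 1024 ** 3

# Set before JAX initializes its GPU backend; doing it before `import jax`
# is the safest option. ".50" requests 50%.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"


def preallocated_bytes(total_bytes: int, available_bytes: int,
                       new_behavior: bool = True) -> int:
    """Hypothetical helper, a simplified model of GPU preallocation size.

    new_behavior=True:  fraction of the *total* GPU memory (jaxlib 0.4.1).
    new_behavior=False: fraction of the *currently available* GPU memory
                        (the previous behavior).
    """
    fraction = float(os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"])
    base = total_bytes if new_behavior else available_bytes
    return int(fraction * base)


# Hypothetical 16 GiB GPU where another process already holds 6 GiB.
total, available = 16 * GIB, 10 * GIB
print(preallocated_bytes(total, available) / GIB)                      # 8.0
print(preallocated_bytes(total, available, new_behavior=False) / GIB)  # 5.0
```

In this simplified model, the new preallocation no longer depends on how much memory other processes already hold. That is why the two results differ.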
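Migrating off the removed method is a rename. JAX dispatches work asynchronously, so `block_until_ready()` waits for the computation to finish and then returns the array itself. A minimal sketch, assuming a recent JAX installation (the array values are arbitrary):

```python
import jax.numpy as jnp

x = jnp.arange(4.0)

# Before (removed in jaxlib 0.4.1):
#   y = (x * 2).block_host_until_ready()
# After: block_until_ready() waits until the computation has finished
# and returns the same array.
y = (x * 2).block_until_ready()
print(y)  # [0. 2. 4. 6.]
```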
Jax release v0.4.1
- Changes
- Support for Python 3.7 has been dropped, in accordance with JAX's version support policy.
- We introduce `jax.Array`, a unified array type that subsumes the `DeviceArray`, `ShardedDeviceArray`, and `GlobalDeviceArray` types in JAX. The `jax.Array` type helps make parallelism a core feature of JAX, simplifies and unifies JAX internals, and allows us to unify `jit` and `pjit`. `jax.Array` has been enabled by default in JAX 0.4 and makes some breaking changes to the `pjit` API. The jax.Array migration guide can help you migrate your codebase to `jax.Array`. You can also look at the Distributed arrays and automatic parallelization tutorial to understand the new concepts.
- `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`. `jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are deprecated and will be removed in 3 months.
- The new public endpoint for `with_sharding_constraint` is `jax.lax.with_sharding_constraint`.
- If using ABSL flags together with `jax.config`, the ABSL flag values are no longer read or written after the JAX configuration options are initially populated from the ABSL flags. This change improves the performance of reading `jax.config` options, which are used pervasively in JAX.
- The `jax2tf.call_tf` function now uses, for TF lowering, the first TF device of the same platform as the embedding JAX computation. Previously, it used the 0th device of the JAX default backend.
- A number of `jax.numpy` functions now have their arguments marked as positional-only, matching NumPy.
- `jnp.msort` is now deprecated, following the deprecation of `np.msort` in NumPy 1.24. It will be removed in a future release, in accordance with the API compatibility policy. It can be replaced with `jnp.sort(a, axis=0)`.
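The sketch below gives a rough illustration of the `jax.Array` and `jax.sharding` changes. It does three things:

- builds a one-axis `Mesh` over the available devices;
- constrains a jitted result with `jax.lax.with_sharding_constraint`;
- checks that both the plain and the sharded arrays are `jax.Array` instances.

It assumes a recent JAX release. `NamedSharding` is not mentioned in these notes; it is used here as the sharding type that pairs a `Mesh` with a `PartitionSpec`. The axis name `"x"` and the function `double` are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-axis mesh over all available devices (a single CPU device also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))


@jax.jit
def double(x):
    # Ask the compiler to shard the result along mesh axis "x".
    return jax.lax.with_sharding_constraint(x * 2, NamedSharding(mesh, P("x")))


# Length is a multiple of the device count so the "x" axis divides evenly.
x = jnp.arange(8 * devices.size, dtype=jnp.float32)
y = double(x)

# Both the unsharded input and the sharded output are jax.Array instances;
# there is no separate DeviceArray / ShardedDeviceArray / GlobalDeviceArray.
print(isinstance(x, jax.Array), isinstance(y, jax.Array))  # True True
print(y.sharding)
```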
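Positional-only parameters cannot be passed by keyword. They are written with the `/` marker from PEP 570, available since Python 3.8. The function `scale` below is a hypothetical stand-in, not a `jax.numpy` API. It only shows what callers of an affected function would observe:

```python
def scale(a, /, factor=1.0):
    """Hypothetical function: `a` is positional-only, `factor` is not."""
    return [factor * v for v in a]


print(scale([1.0, 2.0], factor=3.0))  # [3.0, 6.0]

try:
    scale(a=[1.0, 2.0])  # passing the positional-only `a` by keyword
    keyword_accepted = True
except TypeError as err:
    keyword_accepted = False
    print("rejected:", err)
```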
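The suggested replacement for `jnp.msort` sorts along the first axis, so each column is sorted independently. A small sketch, assuming a recent JAX installation and an arbitrary 3-by-2 integer input:

```python
import jax.numpy as jnp

a = jnp.array([[3, 1],
               [1, 2],
               [2, 0]])

# Replacement for the deprecated jnp.msort(a): sort along axis 0,
# which sorts each column independently.
s = jnp.sort(a, axis=0)
print(s)
# [[1 0]
#  [2 1]
#  [3 2]]
```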
Jaxlib release v0.3.25
jaxlib-v0.3.25 jaxlib version 0.3.25
Jax release v0.3.25
jax-v0.3.25 jax version 0.3.25
Jaxlib release v0.3.24
jaxlib-v0.3.24 jaxlib version 0.3.24
Jax release v0.3.24
jax-v0.3.24 jax version 0.3.24
jax-v0.3.23
- Changes
- Update Colab TPU driver version for new jaxlib release.
jaxlib-v0.3.22
jaxlib version 0.3.22