Releases: jax-ml/jax
Releases · jax-ml/jax
Jaxlib release v0.3.14
jaxlib-v0.3.14 Jaxlib release v0.3.14
JAX release v0.3.14
jax-v0.3.14 JAX release v0.3.14
JAX release v0.3.13
jax-v0.3.13 jax version 0.3.13
Jax release v0.3.12
- Changes
- Fixes #10717
Jax release v0.3.11
- Changes
- {func}
jax.lax.eigh
now accepts an optionalsort_eigenvalues
argument
that allows users to opt out of eigenvalue sorting on TPU.
- {func}
- Deprecations
- Non-array arguments to functions in {mod}
jax.lax.linalg
are now marked
keyword-only. As a backward-compatibility step passing keyword-only
arguments positionally yields a warning, but in a future JAX release passing
keyword-only arguments positionally will fail.
However, most users should prefer to use {mod}jax.numpy.linalg
instead. - {func}
jax.scipy.linalg.polar_unitary
, which was a JAX extension to the
scipy API, is deprecated. Use {func}jax.scipy.linalg.polar
instead.
- Non-array arguments to functions in {mod}
Jaxlib release v0.3.10
Update TF commit for release PiperOrigin-RevId: 446555288
Jax release v0.3.10
Update TF commit for release PiperOrigin-RevId: 446555288
Jax release 0.3.9
- Changes
- Added support for fully asynchronous checkpointing for GlobalDeviceArray.
JAX release v0.3.8
- GitHub commits.
- Changes
- {func}
jax.numpy.linalg.svd
on TPUs uses a qdwh-svd solver. - {func}
jax.numpy.linalg.cond
on TPUs now accepts complex input. - {func}
jax.numpy.linalg.pinv
on TPUs now accepts complex input. - {func}
jax.numpy.linalg.matrix_rank
on TPUs now accepts complex input. - {func}
jax.scipy.cluster.vq.vq
has been added. jax.experimental.maps.mesh
has been deleted.
Please usejax.experimental.maps.Mesh
. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
for more information.- {func}
jax.scipy.linalg.qr
now returns a length-1 tuple rather than the raw array whenmode='r'
, in order to match the behavior ofscipy.linalg.qr
({jax-issue}#10452
) - {func}
jax.numpy.take_along_axis
now takes an optionalmode
parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passingmode="clip"
. - {func}
jax.numpy.take
now defaults tomode="fill"
, which returns invalid values (e.g., NaN) for out-of-bounds indices. - Scatter operations, such as
x.at[...].set(...)
, now have"drop"
semantics. This has no effect on the scatter operation itself, but it means that when differentiated the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct. - {func}
jax.numpy.take_along_axis
now raises aTypeError
if its indices are not of an integer type, matching the behavior of
{func}numpy.take_along_axis
. Previously non-integer indices were silently cast to integers. - {func}
jax.numpy.ravel_multi_index
now raises aTypeError
if itsdims
argument is not of an integer type, matching the behavior of {func}numpy.ravel_multi_index
. Previously non-integerdims
was silently cast to integers. - {func}
jax.numpy.split
now raises aTypeError
if itsaxis
argument is not of an integer type, matching the behavior of {func}numpy.split
. Previously non-integeraxis
was silently cast to integers. - {func}
jax.numpy.indices
now raises aTypeError
if its dimensions are not of an integer type, matching the behavior of {func}numpy.indices
. Previously non-integer dimensions were silently cast to integers. - {func}
jax.numpy.diag
now raises aTypeError
if itsk
argument is not of an integer type, matching the behavior of {func}numpy.diag
. Previously non-integerk
was silently cast to integers. - Added {func}
jax.random.orthogonal
.
- {func}
- Deprecations
- Many functions and objects available in {mod}
jax.test_util
are now deprecated and will raise a warning on import. This includescases_from_list
,check_close
,check_eq
,device_under_test
,format_shape_dtype_string
,rand_uniform
,skip_on_devices
,with_config
,xla_bridge
, and_default_tolerance
({jax-issue}#10389
). These, along with previously-deprecatedJaxTestCase
,JaxTestLoader
, andBufferDonationTestCase
, will be removed in a future JAX release. Most of these utilites can be replaced by calls to standard python & numpy testing utilities found in e.g. {mod}unittest
, {mod}absl.testing
, {mod}numpy.testing
, etc. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as {func}jax.devices
. Many of the deprecated utilities will still exist in {mod}jax._src.test_util
, but these are not public APIs and as such may be changed or removed without notice in future releases.
- Many functions and objects available in {mod}
Jaxlib v0.3.7
- Linux wheels are now built conforming to the
manylinux2014
standard, instead ofmanylinux2010
.