JAX v0.4.34

@hawkinsp released this 04 Oct 14:51
  • New Functionality

    • This release includes wheels for Python 3.13. Free-threading mode is not yet
      supported.
    • jax.errors.JaxRuntimeError has been added as a public alias for the
      formerly private XlaRuntimeError type.
  • Breaking changes

    • The jax_pmap_no_rank_reduction flag is now set to True by default.
      • array[0] on a pmap result now introduces a reshape (use array[0:1]
        instead).
      • The per-shard shape (accessible via jax_array.addressable_shards or
        jax_array.addressable_data(0)) now has a leading dimension of size 1,
        i.e. (1, ...). Update code that directly accesses shards accordingly.
        The rank of the per-shard shape now matches that of the global shape,
        which is the same behavior as jit. This avoids costly reshapes when
        passing results from pmap into jit.
    • jax.experimental.host_callback has been deprecated since March 2024, with
      JAX version 0.4.26. Now we set the default value of the
      --jax_host_callback_legacy configuration option to False, which means that
      if your code uses jax.experimental.host_callback APIs, those API calls
      will be implemented in terms of the new jax.experimental.io_callback API.
      If this breaks your code, for a very limited time, you can set
      --jax_host_callback_legacy to True. Soon we will remove that
      configuration option, so you should instead transition to using the
      new JAX callback APIs. See #20385 for a discussion.
  • Deprecations

    • In jax.numpy.trim_zeros, non-arraylike arguments or arraylike
      arguments with ndim != 1 are now deprecated, and in the future will result
      in an error.
    • Internal pretty-printing tools jax.core.pp_* have been removed, after
      being deprecated in JAX v0.4.30.
    • jax.lib.xla_client.Device is deprecated; use jax.Device instead.
    • jax.lib.xla_client.XlaRuntimeError has been deprecated. Use
      jax.errors.JaxRuntimeError instead.
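A short sketch of two of the migrations above, assuming the trim_zeros input is one-dimensional data that was previously passed as a plain Python list; describe is a hypothetical helper showing jax.Device used as a type annotation:

```python
import jax
import jax.numpy as jnp

# trim_zeros: pass a 1-D array rather than a Python list.
data = [0, 0, 1, 2, 0]
trimmed = jnp.trim_zeros(jnp.asarray(data))
print(trimmed)


# Annotate with jax.Device instead of jax.lib.xla_client.Device.
def describe(device: jax.Device) -> str:
    """Hypothetical helper: format a device as 'platform:id'."""
    return f"{device.platform}:{device.id}"


dev = jax.devices()[0]
print(describe(dev))
```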
  • Deletions

    • jax.xla_computation has been deleted, 3 months after its deprecation in
      the JAX 0.4.30 release.
      Please use the AOT APIs to get the same functionality as jax.xla_computation:
      • jax.xla_computation(fn)(*args, **kwargs) can be replaced with
        jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo').
      • You can also use the .out_info property of jax.stages.Lowered to get the
        output information (such as tree structure, shape and dtype).
      • For cross-backend lowering, you can replace
        jax.xla_computation(fn, backend='tpu')(*args, **kwargs) with
        jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo').
    • jax.ShapeDtypeStruct no longer accepts the named_shape argument.
      The argument was only used by xmap which was removed in 0.4.31.
    • jax.tree.map(f, None, non-None), which previously emitted a
      DeprecationWarning, now raises an error, because None is only a
      tree-prefix of itself. To preserve the previous behavior, you can
      ask jax.tree.map to treat None as a leaf value by writing:
      jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None).
    • jax.sharding.XLACompatibleSharding has been removed. Please use
      jax.sharding.Sharding.
  • Bug fixes

    • Fixed a bug where jax.numpy.cumsum would produce incorrect outputs
      if a non-boolean input was provided and dtype=bool was specified.
    • Fixed the implementation of jax.numpy.ldexp so that it computes the
      correct gradient.