diff --git a/CHANGELOG.md b/CHANGELOG.md index d670e43b6137..204df6a83e52 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. declared inline via {func}`dataclasses.field`. See the function documentation for examples. * Added {func}`jax.numpy.put_along_axis`. + * {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions + ({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now + supported on GPU. See {jax-issue}`#24663` for more details. * Bug fixes * Fixed a bug where the GPU implementations of LU and QR decomposition would diff --git a/jax/_src/config.py b/jax/_src/config.py index 72f394dba76f..1c62f7125ee7 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1963,3 +1963,14 @@ def _update_garbage_collection_guard(state, key, val): ), include_in_jit_key=True, ) + +gpu_use_magma = enum_state( + name='jax_use_magma', + enum_values=['off', 'on', 'auto'], + default='auto', + help=( + 'Enable experimental support for MAGMA-backed lax.linalg.eig on GPU. ' + 'See the documentation for lax.linalg.eig for more details about how ' + 'to use this feature.' + ), +) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 0e0390abc78f..62cb72c69fd7 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -121,16 +121,46 @@ def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array: def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True, - compute_right_eigenvectors: bool = True) -> list[Array]: + compute_right_eigenvectors: bool = True, + use_magma: bool | None = None) -> list[Array]: """Eigendecomposition of a general matrix. - Nonsymmetric eigendecomposition is at present only implemented on CPU. + Nonsymmetric eigendecomposition is only implemented on CPU and GPU. On GPU, + the default implementation calls LAPACK directly on the host CPU, but an + experimental GPU implementation using `MAGMA <https://icl.utk.edu/magma/>`_ + is also available. The MAGMA implementation is typically slower than the + equivalent LAPACK implementation for small matrices (smaller than about + 2048x2048), but it may perform better for larger matrices. + + To enable the MAGMA implementation, you must install MAGMA yourself (there + are Debian and conda-forge packages, or you can build from source). Then set + the ``use_magma`` argument to ``True``, or set the ``jax_use_magma`` + configuration variable to ``"on"`` or ``"auto"``: + + .. code-block:: python + + jax.config.update('jax_use_magma', 'on') + + JAX will try to ``dlopen`` the installed MAGMA shared library, raising an + error if it is not found. To explicitly specify the path to the MAGMA + library, set the environment variable ``JAX_GPU_MAGMA_PATH`` to the full + installation path. + + If ``jax_use_magma`` is set to ``"auto"``, the MAGMA implementation will + be used if the library can be found and the input matrix is sufficiently + large (at least 2048x2048). Args: x: A batch of square matrices with shape ``[..., n, n]``. compute_left_eigenvectors: If true, the left eigenvectors will be computed. compute_right_eigenvectors: If true, the right eigenvectors will be computed. + use_magma: Locally override the ``jax_use_magma`` flag. If ``True``, the + eigendecomposition is computed using MAGMA. If ``False``, the computation + is done using LAPACK on the host CPU. If ``None`` (default), the + behavior is controlled by the ``jax_use_magma`` flag. This argument + is only used on GPU.
+ Returns: The eigendecomposition of ``x``, which is a tuple of the form ``(w, vl, vr)`` where ``w`` are the eigenvalues, ``vl`` are the left @@ -142,7 +172,8 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True, for that batch element. """ return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors, - compute_right_eigenvectors=compute_right_eigenvectors) + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=use_magma) def eigh( @@ -678,12 +709,14 @@ def _symmetric_product_jax_fn(a, c, *, alpha, beta): # Asymmetric eigendecomposition -def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors): +def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors, + use_magma): return dispatch.apply_primitive( eig_p, operand, compute_left_eigenvectors=compute_left_eigenvectors, compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=use_magma, ) def eig_lower(*args, **kw): @@ -692,7 +725,8 @@ def eig_lower(*args, **kw): "If your matrix is symmetric or Hermitian, you should use eigh instead.") def eig_abstract_eval(operand, *, compute_left_eigenvectors, - compute_right_eigenvectors): + compute_right_eigenvectors, use_magma): + del use_magma # unused if isinstance(operand, ShapedArray): if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]: raise ValueError("Argument to nonsymmetric eigendecomposition must have " @@ -716,7 +750,8 @@ def eig_abstract_eval(operand, *, compute_left_eigenvectors, return tuple(output) def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, - compute_right_eigenvectors): + compute_right_eigenvectors, use_magma): + del use_magma # unused operand_aval, = ctx.avals_in out_aval = ctx.avals_out[0] batch_dims = operand_aval.shape[:-2] @@ -763,18 +798,94 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, return output +def _eig_gpu_impl(target_name_prefix, x, *, compute_left_eigenvectors, + compute_right_eigenvectors, use_magma): + gpu_solver.initialize_hybrid_kernels() + dtype = x.dtype + is_real = dtype == np.float32 or dtype == np.float64 + if is_real: + target_name = f"{target_name_prefix}hybrid_eig_real" + complex_dtype = np.complex64 if dtype == np.float32 else np.complex128 + else: + target_name = f"{target_name_prefix}hybrid_eig_comp" + assert dtype == np.complex64 or dtype == np.complex128 + complex_dtype = dtype + + batch_dims = x.shape[:-2] + n, m = x.shape[-2:] + assert n == m + num_batch_dims = len(batch_dims) + + layout = tuple(range(num_batch_dims)) + (num_batch_dims + 1, num_batch_dims) + out_types = [ + api.ShapeDtypeStruct(batch_dims + (n,), dtype), + api.ShapeDtypeStruct(batch_dims + (n, n), complex_dtype), + api.ShapeDtypeStruct(batch_dims + (n, n), complex_dtype), + api.ShapeDtypeStruct(batch_dims, np.int32), + ] + out_layouts = [None, layout, layout, None] + if is_real: + out_types = [api.ShapeDtypeStruct(batch_dims + (n,), dtype)] + out_types + out_layouts = [None] + out_layouts + + magma = config.gpu_use_magma.value + if use_magma is not None: + magma = "on" if use_magma else "off" + fun = ffi.ffi_call(target_name, out_types, input_layouts=[layout], + output_layouts=out_layouts) + *w, vl, vr, info = fun(x, magma=magma, left=compute_left_eigenvectors, + right=compute_right_eigenvectors) + if is_real: + assert len(w) == 2 + w = lax.complex(*w) + else: + assert len(w) == 1 + w = w[0] + ok = lax.eq(info, lax.zeros_like_array(info)) + ok = _broadcast_to(ok[..., None], w.shape) + w = lax.select(ok, w, lax.full_like(w, np.nan + np.nan * 
1j)) + ok = _broadcast_to(ok[..., None], x.shape) + output = [w] + if compute_left_eigenvectors: + vl = lax.select(ok, vl, lax.full_like(vl, np.nan + np.nan * 1j)) + output.append(vl) + if compute_right_eigenvectors: + vr = lax.select(ok, vr, lax.full_like(vr, np.nan + np.nan * 1j)) + output.append(vr) + return output + + +def _eig_gpu_lowering(target_name_prefix, ctx, operand, *, + compute_left_eigenvectors, compute_right_eigenvectors, + use_magma): + if ctx.is_forward_compat(): + raise NotImplementedError( + "Export of nonsymmetric eigendecomposition on GPU is not supported " + "because of forward compatibility. The " + "'jax_export_ignore_forward_compatibility' configuration option can be " + "used to disable this check.") + rule = mlir.lower_fun(partial( + _eig_gpu_impl, target_name_prefix, + compute_left_eigenvectors=compute_left_eigenvectors, + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=use_magma), multiple_results=True) + return rule(ctx, operand) + + def eig_batching_rule(batched_args, batch_dims, *, compute_left_eigenvectors, - compute_right_eigenvectors): + compute_right_eigenvectors, use_magma): x, = batched_args bd, = batch_dims x = batching.moveaxis(x, bd, 0) return (eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors, - compute_right_eigenvectors=compute_right_eigenvectors), + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=use_magma), (0,) * (1 + compute_left_eigenvectors + compute_right_eigenvectors)) def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, - compute_right_eigenvectors): + compute_right_eigenvectors, use_magma): + del use_magma # unused if compute_left_eigenvectors or compute_right_eigenvectors: raise NotImplementedError( 'The derivatives of eigenvectors are not implemented, only ' @@ -793,6 +904,10 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, eig_p.def_abstract_eval(eig_abstract_eval) mlir.register_lowering(eig_p, eig_lower) mlir.register_lowering(eig_p, _eig_cpu_lowering, platform='cpu') +mlir.register_lowering(eig_p, partial(_eig_gpu_lowering, 'cu'), + platform='cuda') +mlir.register_lowering(eig_p, partial(_eig_gpu_lowering, 'hip'), + platform='rocm') batching.primitive_batchers[eig_p] = eig_batching_rule ad.primitive_jvps[eig_p] = eig_jvp_rule diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 03f864919887..76a4abff48ad 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -731,7 +731,9 @@ def eig(a: ArrayLike) -> tuple[Array, Array]: - This differs from :func:`numpy.linalg.eig` in that the return type of :func:`jax.numpy.linalg.eig` is always complex64 for 32-bit input, and complex128 for 64-bit input. - - At present, non-symmetric eigendecomposition is only implemented on the CPU backend. + - At present, non-symmetric eigendecomposition is only implemented on the CPU and + GPU backends. For more details about the GPU implementation, see the + documentation for :func:`jax.lax.linalg.eig`. See also: - :func:`jax.numpy.linalg.eigh`: eigenvectors and eigenvalues of a Hermitian matrix. 
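Taken together, the Python-side changes above amount to one new keyword argument on `lax.linalg.eig` and one new configuration flag. A minimal usage sketch follows (not part of the diff; it assumes a CUDA or ROCm backend and a jaxlib built with these kernels, and MAGMA itself is only needed when the `'on'`/`'auto'` paths actually select it):

```python
import jax
import jax.numpy as jnp

# Global switch added in jax/_src/config.py: 'off' (LAPACK on the host),
# 'on' (MAGMA), or 'auto' (MAGMA when it can be loaded and n >= 2048).
jax.config.update('jax_use_magma', 'auto')

a = jnp.array([[0.0, 1.0], [-2.0, 0.0]], dtype=jnp.float32)

# lax.linalg.eig computes both eigenvector sets by default and returns
# [w, vl, vr]; eigenvalues are complex64 for 32-bit inputs.
w, vl, vr = jax.lax.linalg.eig(a)

# Per-call override of the flag via the new use_magma keyword argument.
w, vl, vr = jax.lax.linalg.eig(a, use_magma=False)

# The jax.numpy wrappers route through the same primitive.
w, v = jnp.linalg.eig(a)
```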
diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 19b82a5ce149..ed815e1b1bd2 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -1094,34 +1094,6 @@ template struct EigenvalueDecompositionSymmetric<ffi::DataType::F64>; template struct EigenvalueDecompositionHermitian<ffi::DataType::C64>; template struct EigenvalueDecompositionHermitian<ffi::DataType::C128>; -// LAPACK uses a packed representation to represent a mixture of real -// eigenvectors and complex conjugate pairs. This helper unpacks the -// representation into regular complex matrices. -template <typename T> -static void UnpackEigenvectors(lapack_int n, const T* eigenvals_imag, - const T* packed, std::complex<T>* unpacked) { - for (int j = 0; j < n;) { - if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) { - // Real values in each row without imaginary part - // Second row of the imaginary part is not provided - for (int i = 0; i < n; ++i) { - unpacked[j * n + i] = {packed[j * n + i], 0.}; - } - ++j; - } else { - // Complex values where the real part is in the jth row - // and the imaginary part is in the next row (j + 1) - for (int i = 0; i < n; ++i) { - const T real_part = packed[j * n + i]; - const T imag_part = packed[(j + 1) * n + i]; - unpacked[j * n + i] = {real_part, imag_part}; - unpacked[(j + 1) * n + i] = {real_part, -imag_part}; - } - j += 2; - } - } -} - // lapack geev template diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 7d15e494fffc..cddcb1162120 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef JAXLIB_CPU_LAPACK_KERNELS_H_ #define JAXLIB_CPU_LAPACK_KERNELS_H_ +#include <cmath> #include <complex> #include <cstdint> #include <optional> @@ -462,6 +463,34 @@ struct EigenvalueDecompositionHermitian { // lapack geev +// LAPACK uses a packed representation to represent a mixture of real +// eigenvectors and complex conjugate pairs. This helper unpacks the +// representation into regular complex matrices. +template <typename T, typename Int> +static void UnpackEigenvectors(Int n, const T* eigenvals_imag, + const T* packed, std::complex<T>* unpacked) { + for (int j = 0; j < n;) { + if (eigenvals_imag[j] == 0. 
|| std::isnan(eigenvals_imag[j])) { + // Real values in each row without imaginary part + // Second row of the imaginary part is not provided + for (int i = 0; i < n; ++i) { + unpacked[j * n + i] = {packed[j * n + i], 0.}; + } + ++j; + } else { + // Complex values where the real part is in the jth row + // and the imaginary part is in the next row (j + 1) + for (int i = 0; i < n; ++i) { + const T real_part = packed[j * n + i]; + const T imag_part = packed[(j + 1) * n + i]; + unpacked[j * n + i] = {real_part, imag_part}; + unpacked[(j + 1) * n + i] = {real_part, -imag_part}; + } + j += 2; + } + } +} + template <typename T> struct RealGeev { using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a, diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 34e40d12d5be..afce2c000ecc 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -476,6 +476,55 @@ pybind_extension( ], ) +cc_library( + name = "cuda_hybrid_kernels", + srcs = ["//jaxlib/gpu:hybrid_kernels.cc"], + hdrs = ["//jaxlib/gpu:hybrid_kernels.h"], + deps = [ + ":cuda_gpu_kernel_helpers", + ":cuda_vendor", + "//jaxlib:ffi_helpers", + "//jaxlib/cpu:lapack_kernels", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@xla//xla/ffi/api:ffi", + ], +) + +pybind_extension( + name = "_hybrid", + srcs = ["//jaxlib/gpu:hybrid.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + linkopts = select({ + "@xla//xla/python:use_jax_cuda_pip_rpaths": [ + "-Wl,-rpath,$$ORIGIN/../../nvidia/cuda_runtime/lib", + ], + "//conditions:default": [], + }), + module_name = "_hybrid", + deps = [ + ":cuda_gpu_kernel_helpers", + ":cuda_hybrid_kernels", + ":cuda_vendor", + "//jaxlib:kernel_nanobind_helpers", + "//jaxlib/cpu:lapack_kernels", + "@local_config_cuda//cuda:cuda_headers", + "@nanobind", + "@xla//xla/ffi/api:ffi", + "@xla//xla/tsl/cuda:cudart", + ], +) + cc_library( name = "cuda_gpu_kernels", srcs = ["//jaxlib/gpu:gpu_kernels.cc"], @@ -633,6 +682,7 @@ py_library( name = "cuda_gpu_support", deps = [ ":_blas", + ":_hybrid", ":_linalg", ":_prng", ":_rnn", diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 7d50a91cfcda..e888f6a42a9b 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -37,6 +37,9 @@ exports_files(srcs = [ "gpu_kernel_helpers.cc", "gpu_kernel_helpers.h", "gpu_kernels.cc", + "hybrid.cc", + "hybrid_kernels.cc", + "hybrid_kernels.h", "linalg.cc", "linalg_kernels.cc", "linalg_kernels.cu.cc", diff --git a/jaxlib/gpu/hybrid.cc b/jaxlib/gpu/hybrid.cc new file mode 100644 index 000000000000..afe95a650d29 --- /dev/null +++ b/jaxlib/gpu/hybrid.cc @@ -0,0 +1,60 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "nanobind/nanobind.h" +#include "jaxlib/cpu/lapack_kernels.h" +#include "jaxlib/gpu/hybrid_kernels.h" +#include "jaxlib/gpu/vendor.h" +#include "jaxlib/kernel_nanobind_helpers.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { +namespace { +namespace ffi = xla::ffi; +namespace nb = nanobind; + +void GetLapackKernelsFromScipy() { + static bool initialized = false; // Protected by GIL + if (initialized) return; + nb::module_ cython_blas = nb::module_::import_("scipy.linalg.cython_blas"); + nb::module_ cython_lapack = + nb::module_::import_("scipy.linalg.cython_lapack"); + nb::dict lapack_capi = cython_lapack.attr("__pyx_capi__"); + auto lapack_ptr = [&](const char* name) { + return nb::cast<nb::capsule>(lapack_capi[name]).data(); + }; + + AssignKernelFn<EigenvalueDecomposition<ffi::DataType::F32>>(lapack_ptr("sgeev")); + AssignKernelFn<EigenvalueDecomposition<ffi::DataType::F64>>(lapack_ptr("dgeev")); + AssignKernelFn<EigenvalueDecompositionComplex<ffi::DataType::C64>>(lapack_ptr("cgeev")); + AssignKernelFn<EigenvalueDecompositionComplex<ffi::DataType::C128>>( + lapack_ptr("zgeev")); + initialized = true; +} + +NB_MODULE(_hybrid, m) { + m.def("initialize", GetLapackKernelsFromScipy); + m.def("has_magma", []() { return MagmaLookup().FindMagmaInit().ok(); }); + m.def("registrations", []() { + nb::dict dict; + dict[JAX_GPU_PREFIX "hybrid_eig_real"] = EncapsulateFfiHandler(kEigReal); + dict[JAX_GPU_PREFIX "hybrid_eig_comp"] = EncapsulateFfiHandler(kEigComp); + return dict; + }); +} + +} // namespace +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/hybrid_kernels.cc b/jaxlib/gpu/hybrid_kernels.cc new file mode 100644 index 000000000000..1ce2e547b11f --- /dev/null +++ b/jaxlib/gpu/hybrid_kernels.cc @@ -0,0 +1,631 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/gpu/hybrid_kernels.h" + +#include <dlfcn.h> + +#include <cmath> +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <memory> +#include <string> +#include <string_view> +#include <tuple> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "jaxlib/cpu/lapack_kernels.h" +#include "jaxlib/ffi_helpers.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +namespace ffi = ::xla::ffi; + +// This helper class is used to define a host buffer that can be copied to and +// from a device buffer. 
+template <typename T> +class HostBuffer { + public: + HostBuffer(std::size_t size) : size_(size) { + data_ = std::unique_ptr<T[]>(new T[size]); + } + + absl::Status CopyFromDevice(gpuStream_t stream, const T* buffer) { + return JAX_AS_STATUS(gpuMemcpyAsync(data_.get(), buffer, size_ * sizeof(T), + gpuMemcpyDeviceToHost, stream)); + } + + absl::Status CopyToDevice(gpuStream_t stream, T* buffer) { + return JAX_AS_STATUS(gpuMemcpyAsync(buffer, data_.get(), size_ * sizeof(T), + gpuMemcpyHostToDevice, stream)); + } + + T* get() const { return data_.get(); } + + private: + std::unique_ptr<T[]> data_; + size_t size_; +}; + +// Forwarded from MAGMA for use as an input parameter. +typedef enum { + MagmaNoVec = 301, + MagmaVec = 302, +} magma_vec_t; + +// Compile-time lookup of MAGMA function names depending on the data type. +template <typename T> +struct always_false : std::false_type {}; + +template <typename T> +struct MagmaGeev { + static_assert(always_false<T>::value, "unsupported data type"); +}; +template <> +struct MagmaGeev<float> { + static constexpr char name[] = "magma_sgeev"; +}; +template <> +struct MagmaGeev<double> { + static constexpr char name[] = "magma_dgeev"; +}; +template <> +struct MagmaGeev<std::complex<float>> { + static constexpr char name[] = "magma_cgeev"; +}; +template <> +struct MagmaGeev<std::complex<double>> { + static constexpr char name[] = "magma_zgeev"; +}; + +MagmaLookup::~MagmaLookup() { + if (initialized_) { + void* magma_finalize = dlsym(handle_, "magma_finalize"); + if (magma_finalize != nullptr) { + reinterpret_cast<void (*)()>(magma_finalize)(); + } + } + if (handle_ != nullptr) { + dlclose(handle_); + } +} + +absl::StatusOr<void*> MagmaLookup::FindMagmaInit() { + void* magma_init = nullptr; + std::vector<const char*> paths; + const char* magma_lib_path = std::getenv("JAX_GPU_MAGMA_PATH"); + if (magma_lib_path != nullptr) { + paths.push_back(magma_lib_path); + } else { + paths.push_back("libmagma.so.2"); + paths.push_back("libmagma.so"); + paths.push_back(nullptr); + } + for (const auto& path : paths) { + handle_ = dlopen(path, RTLD_LAZY); + if (handle_ != nullptr) { + magma_init = dlsym(handle_, "magma_init"); + if (magma_init != nullptr) { + if (path != nullptr) { + lib_path_ = std::string(path); + } + break; + } + } + } + if (handle_ == nullptr || magma_init == nullptr) { + return absl::InternalError( + "Unable to dlopen a MAGMA shared library that defines a magma_init " + "symbol. Use the JAX_GPU_MAGMA_PATH environment variable to " + "specify an explicit path to the library."); + } + return magma_init; +} + +absl::Status MagmaLookup::Initialize() { + if (failed_) { + return absl::InternalError("MAGMA initialization was unsuccessful."); + } + if (!initialized_) { + auto maybe_magma_init = FindMagmaInit(); + if (!maybe_magma_init.ok()) { + failed_ = true; + return maybe_magma_init.status(); + } + reinterpret_cast<void (*)()>(maybe_magma_init.value())(); + initialized_ = true; + } + return absl::OkStatus(); +} + +absl::StatusOr<void*> MagmaLookup::Find(const char name[]) { + if (!initialized_) { + return absl::InternalError("MAGMA support has not been initialized."); + } + + auto it = symbols_.find(name); + if (it != symbols_.end()) return it->second; + + void* symbol = dlsym(handle_, name); + if (symbol == nullptr) { + if (lib_path_.has_value()) { + return absl::InternalError(absl::StrFormat( + "Unable to load the symbol '%s' from the MAGMA library at '%s'.", + name, lib_path_.value())); + + } else { + return absl::InternalError(absl::StrFormat( + "Unable to load a globally defined symbol called '%s'. 
Use the " + "JAX_GPU_MAGMA_PATH environment variable to specify an explicit " + "path to the library.", + name)); + } + } + + symbols_.insert({name, symbol}); + return symbol; +} + +// Lookup the MAGMA symbol for the given function name. This function only +// dlopen the MAGMA library once per process. +absl::StatusOr FindMagmaSymbol(const char name[]) { + static absl::Mutex mu; + static MagmaLookup& lookup = *new MagmaLookup ABSL_GUARDED_BY(mu); + absl::MutexLock lock(&mu); + auto status = lookup.Initialize(); + if (!status.ok()) { + return status; + } + return lookup.Find(name); +} + +// Real-valued eigendecomposition + +template +class EigRealHost { + using Real = ffi::NativeType; + + public: + explicit EigRealHost() = default; + EigRealHost(EigRealHost&&) = default; + + absl::StatusOr lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? 'V' : 'N'; + jobvr_ = right ? 'V' : 'N'; + int64_t lwork = EigenvalueDecomposition::GetWorkspaceSize( + n, static_cast(jobvl_), + static_cast(jobvr_)); + return MaybeCastNoOverflow(lwork); + } + + void compute(Real* x, Real* wr, Real* wi, Real* vl, Real* vr, Real* work, + int lwork, int* info) { + EigenvalueDecomposition::fn(&jobvl_, &jobvr_, &n_, x, &n_, wr, wi, + vl, &n_, vr, &n_, work, &lwork, info); + } + + private: + int n_; + char jobvl_, jobvr_; +}; + +template +class EigRealMagma { + using Real = ffi::NativeType; + using Fn = int(magma_vec_t, magma_vec_t, int, Real*, int, Real*, Real*, Real*, + int, Real*, int, Real*, int, int*); + + public: + explicit EigRealMagma() = default; + EigRealMagma(EigRealMagma&&) = default; + + absl::StatusOr lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? MagmaVec : MagmaNoVec; + jobvr_ = right ? MagmaVec : MagmaNoVec; + + auto maybe_ptr = FindMagmaSymbol(MagmaGeev::name); + if (!maybe_ptr.ok()) return maybe_ptr.status(); + fn_ = reinterpret_cast(*maybe_ptr); + + int query_info; + Real query_host; + fn_(jobvl_, jobvr_, n, nullptr, n, nullptr, nullptr, nullptr, n, nullptr, n, + &query_host, -1, &query_info); + return static_cast(query_host); + } + + void compute(Real* x, Real* wr, Real* wi, Real* vl, Real* vr, Real* work, + int lwork, int* info) { + fn_(jobvl_, jobvr_, n_, x, n_, wr, wi, vl, n_, vr, n_, work, lwork, info); + } + + private: + int n_; + magma_vec_t jobvl_, jobvr_; + Fn* fn_ = nullptr; +}; + +template +ffi::Error EigReal(Impl impl, int64_t batch, int64_t cols, gpuStream_t stream, + bool left, bool right, ffi::AnyBuffer x, + ffi::Result wr, + ffi::Result wi, + ffi::Result vl, + ffi::Result vr, + ffi::Result> info) { + using Real = ffi::NativeType; + using Complex = ffi::NativeType; + + auto x_host = HostBuffer(x.element_count()); + FFI_RETURN_IF_ERROR_STATUS( + x_host.CopyFromDevice(stream, x.typed_data())); + + auto wr_host = HostBuffer(batch * cols); + auto wi_host = HostBuffer(batch * cols); + auto vl_host = HostBuffer(batch * cols * cols); + auto vr_host = HostBuffer(batch * cols * cols); + auto info_host = HostBuffer(batch); + + FFI_ASSIGN_OR_RETURN(int n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(int lwork, impl.lwork(n, left, right)); + auto work_host = AllocateScratchMemory(lwork); + auto work_left = AllocateScratchMemory(cols * cols); + auto work_right = AllocateScratchMemory(cols * cols); + + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + const auto is_finite = [](auto* data, int64_t size) { + return absl::c_all_of(absl::MakeSpan(data, size), + [](auto value) { return std::isfinite(value); }); + }; + + for (int64_t i = 0; i < 
batch; ++i) { + if (is_finite(x_host.get() + i * cols * cols, cols * cols)) { + impl.compute(x_host.get() + i * cols * cols, wr_host.get() + i * cols, + wi_host.get() + i * cols, work_left.get(), work_right.get(), + work_host.get(), lwork, info_host.get() + i); + if (info_host.get()[i] == 0) { + if (left) { + UnpackEigenvectors(n, wi_host.get() + i * cols, work_left.get(), + vl_host.get() + i * cols * cols); + } + if (right) { + UnpackEigenvectors(n, wi_host.get() + i * cols, work_right.get(), + vr_host.get() + i * cols * cols); + } + } + } else { + info_host.get()[i] = -4; + } + } + + FFI_RETURN_IF_ERROR_STATUS( + wr_host.CopyToDevice(stream, wr->typed_data<Real>())); + FFI_RETURN_IF_ERROR_STATUS( + wi_host.CopyToDevice(stream, wi->typed_data<Real>())); + if (left) { + FFI_RETURN_IF_ERROR_STATUS( + vl_host.CopyToDevice(stream, vl->typed_data<Complex>())); + } + if (right) { + FFI_RETURN_IF_ERROR_STATUS( + vr_host.CopyToDevice(stream, vr->typed_data<Complex>())); + } + FFI_RETURN_IF_ERROR_STATUS( + info_host.CopyToDevice(stream, info->typed_data())); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + return ffi::Error::Success(); +} + +ffi::Error EigRealDispatch(gpuStream_t stream, std::string_view magma, + bool left, bool right, ffi::AnyBuffer x, + ffi::Result<ffi::AnyBuffer> wr, + ffi::Result<ffi::AnyBuffer> wi, + ffi::Result<ffi::AnyBuffer> vl, + ffi::Result<ffi::AnyBuffer> vr, + ffi::Result<ffi::Buffer<ffi::S32>> info) { + auto dataType = x.element_type(); + if (dataType != wr->element_type() || dataType != wi->element_type() || + ffi::ToComplex(dataType) != vl->element_type() || + ffi::ToComplex(dataType) != vr->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to eig must have the same element type"); + } + + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(x.dimensions())); + if (rows != cols) { + return ffi::Error::InvalidArgument( + "The input matrix to eig must be square"); + } + FFI_RETURN_IF_ERROR(CheckShape(wr->dimensions(), {batch, cols}, "wr", "eig")); + FFI_RETURN_IF_ERROR(CheckShape(wi->dimensions(), {batch, cols}, "wi", "eig")); + if (left) { + FFI_RETURN_IF_ERROR( + CheckShape(vl->dimensions(), {batch, rows, cols}, "vl", "eig")); + } + if (right) { + FFI_RETURN_IF_ERROR( + CheckShape(vr->dimensions(), {batch, rows, cols}, "vr", "eig")); + } + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "eig")); + + bool use_magma = magma == "on"; + if (magma == "auto" && cols >= 2048) { + use_magma = FindMagmaSymbol("magma_init").ok(); + } + + switch (dataType) { + case ffi::F32: + if (use_magma) { + return EigReal<ffi::F32>(EigRealMagma<ffi::F32>(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } else { + return EigReal<ffi::F32>(EigRealHost<ffi::F32>(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } + case ffi::F64: + if (use_magma) { + return EigReal<ffi::F64>(EigRealMagma<ffi::F64>(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } else { + return EigReal<ffi::F64>(EigRealHost<ffi::F64>(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } + default: + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in eig_real", absl::FormatStreamed(dataType))); + } +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kEigReal, EigRealDispatch, + ffi::Ffi::Bind() + .Ctx<ffi::PlatformStream<gpuStream_t>>() + .Attr<std::string_view>("magma") + .Attr<bool>("left") + .Attr<bool>("right") + .Arg<ffi::AnyBuffer>() // x + .Ret<ffi::AnyBuffer>() // wr + .Ret<ffi::AnyBuffer>() // wi + .Ret<ffi::AnyBuffer>() // vl + .Ret<ffi::AnyBuffer>() // vr + .Ret<ffi::Buffer<ffi::S32>>() // info +); + +// Complex-valued eigendecomposition + +template <ffi::DataType DataType> +class EigCompHost { + using Real = ffi::NativeType<ffi::ToReal(DataType)>; + using Complex = ffi::NativeType<DataType>; + + public: + explicit EigCompHost() = default; + 
EigCompHost(EigCompHost&&) = default; + + absl::StatusOr<int> lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? 'V' : 'N'; + jobvr_ = right ? 'V' : 'N'; + int64_t lwork = EigenvalueDecompositionComplex<DataType>::GetWorkspaceSize( + n, static_cast<eig::ComputationMode>(jobvl_), + static_cast<eig::ComputationMode>(jobvr_)); + return MaybeCastNoOverflow<int>(lwork); + } + + void compute(Complex* x, Complex* w, Complex* vl, Complex* vr, Complex* work, + int lwork, Real* rwork, int* info) { + EigenvalueDecompositionComplex<DataType>::fn(&jobvl_, &jobvr_, &n_, x, &n_, + w, vl, &n_, vr, &n_, work, + &lwork, rwork, info); + } + + private: + int n_; + char jobvl_, jobvr_; +}; + +template <ffi::DataType DataType> +class EigCompMagma { + using Real = ffi::NativeType<ffi::ToReal(DataType)>; + using Complex = ffi::NativeType<DataType>; + using Fn = int(magma_vec_t, magma_vec_t, int, Complex*, int, Complex*, + Complex*, int, Complex*, int, Complex*, int, Real*, int*); + + public: + explicit EigCompMagma() = default; + EigCompMagma(EigCompMagma&&) = default; + + absl::StatusOr<int> lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? MagmaVec : MagmaNoVec; + jobvr_ = right ? MagmaVec : MagmaNoVec; + lda_ = std::max(n_, 1); + ldvl_ = left ? n_ : 1; + ldvr_ = right ? n_ : 1; + + auto maybe_ptr = FindMagmaSymbol(MagmaGeev<Complex>::name); + if (!maybe_ptr.ok()) return maybe_ptr.status(); + fn_ = reinterpret_cast<Fn*>(*maybe_ptr); + + int query_info; + Complex query_host; + fn_(jobvl_, jobvr_, n_, nullptr, lda_, nullptr, nullptr, ldvl_, nullptr, + ldvr_, &query_host, -1, nullptr, &query_info); + return static_cast<int>(query_host.real()); + } + + void compute(Complex* x, Complex* w, Complex* vl, Complex* vr, Complex* work, + int lwork, Real* rwork, int* info) { + fn_(jobvl_, jobvr_, n_, x, lda_, w, vl, ldvl_, vr, ldvr_, work, lwork, + rwork, info); + } + + private: + int n_, lda_, ldvl_, ldvr_; + magma_vec_t jobvl_, jobvr_; + Fn* fn_ = nullptr; +}; + +template <ffi::DataType DataType, typename Impl> +ffi::Error EigComp(Impl impl, int64_t batch, int64_t cols, gpuStream_t stream, + bool left, bool right, ffi::AnyBuffer x, + ffi::Result<ffi::AnyBuffer> w, + ffi::Result<ffi::AnyBuffer> vl, + ffi::Result<ffi::AnyBuffer> vr, + ffi::Result<ffi::Buffer<ffi::S32>> info) { + using Complex = ffi::NativeType<DataType>; + + auto x_host = HostBuffer<Complex>(x.element_count()); + FFI_RETURN_IF_ERROR_STATUS( + x_host.CopyFromDevice(stream, x.typed_data<Complex>())); + + auto w_host = HostBuffer<Complex>(batch * cols); + auto vl_host = HostBuffer<Complex>(batch * cols * cols); + auto vr_host = HostBuffer<Complex>(batch * cols * cols); + auto info_host = HostBuffer<int>(batch); + + FFI_ASSIGN_OR_RETURN(int n, MaybeCastNoOverflow<int>(cols)); + FFI_ASSIGN_OR_RETURN(int lwork, impl.lwork(n, left, right)); + auto work_host = AllocateScratchMemory<DataType>(lwork); + auto rwork_host = + AllocateScratchMemory<ffi::ToReal(DataType)>(2 * cols * cols); + + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + const auto is_finite = [](auto* data, int64_t size) { + return absl::c_all_of(absl::MakeSpan(data, size), [](const auto& z) { + return std::isfinite(z.real()) && std::isfinite(z.imag()); + }); + }; + + for (int64_t i = 0; i < batch; ++i) { + if (is_finite(x_host.get() + i * cols * cols, cols * cols)) { + impl.compute(x_host.get() + i * cols * cols, w_host.get() + i * cols, + vl_host.get() + i * cols * cols, + vr_host.get() + i * cols * cols, work_host.get(), lwork, + rwork_host.get(), info_host.get() + i); + } else { + info_host.get()[i] = -4; + } + } + + FFI_RETURN_IF_ERROR_STATUS( + w_host.CopyToDevice(stream, w->typed_data<Complex>())); + if (left) { + FFI_RETURN_IF_ERROR_STATUS( + vl_host.CopyToDevice(stream, vl->typed_data<Complex>())); + } + if (right) { + FFI_RETURN_IF_ERROR_STATUS( + vr_host.CopyToDevice(stream, vr->typed_data<Complex>())); + } + 
FFI_RETURN_IF_ERROR_STATUS( + info_host.CopyToDevice(stream, info->typed_data())); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + return ffi::Error::Success(); +} + +ffi::Error EigCompDispatch(gpuStream_t stream, std::string_view magma, + bool left, bool right, ffi::AnyBuffer x, + ffi::Result<ffi::AnyBuffer> w, + ffi::Result<ffi::AnyBuffer> vl, + ffi::Result<ffi::AnyBuffer> vr, + ffi::Result<ffi::Buffer<ffi::S32>> info) { + auto dataType = x.element_type(); + if (dataType != w->element_type() || dataType != vl->element_type() || + dataType != vr->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to eig must have the same element type"); + } + + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(x.dimensions())); + if (rows != cols) { + return ffi::Error::InvalidArgument( + "The input matrix to eig must be square"); + } + FFI_RETURN_IF_ERROR(CheckShape(w->dimensions(), {batch, cols}, "w", "eig")); + if (left) { + FFI_RETURN_IF_ERROR( + CheckShape(vl->dimensions(), {batch, rows, cols}, "vl", "eig")); + } + if (right) { + FFI_RETURN_IF_ERROR( + CheckShape(vr->dimensions(), {batch, rows, cols}, "vr", "eig")); + } + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "eig")); + + bool use_magma = magma == "on"; + if (magma == "auto" && cols >= 2048) { + use_magma = FindMagmaSymbol("magma_init").ok(); + } + + switch (dataType) { + case ffi::C64: + if (use_magma) { + return EigComp<ffi::C64>(EigCompMagma<ffi::C64>(), batch, cols, stream, + left, right, x, w, vl, vr, info); + } else { + return EigComp<ffi::C64>(EigCompHost<ffi::C64>(), batch, cols, stream, + left, right, x, w, vl, vr, info); + } + case ffi::C128: + if (use_magma) { + return EigComp<ffi::C128>(EigCompMagma<ffi::C128>(), batch, cols, + stream, left, right, x, w, vl, vr, info); + } else { + return EigComp<ffi::C128>(EigCompHost<ffi::C128>(), batch, cols, stream, + left, right, x, w, vl, vr, info); + } + default: + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in eig_comp", absl::FormatStreamed(dataType))); + } +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kEigComp, EigCompDispatch, + ffi::Ffi::Bind() + .Ctx<ffi::PlatformStream<gpuStream_t>>() + .Attr<std::string_view>("magma") + .Attr<bool>("left") + .Attr<bool>("right") + .Arg<ffi::AnyBuffer>() // x + .Ret<ffi::AnyBuffer>() // w + .Ret<ffi::AnyBuffer>() // vl + .Ret<ffi::AnyBuffer>() // vr + .Ret<ffi::Buffer<ffi::S32>>() // info +); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/hybrid_kernels.h b/jaxlib/gpu/hybrid_kernels.h new file mode 100644 index 000000000000..2890837a2bd5 --- /dev/null +++ b/jaxlib/gpu/hybrid_kernels.h @@ -0,0 +1,55 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_GPU_HYBRID_KERNELS_H_ +#define JAXLIB_GPU_HYBRID_KERNELS_H_ + +#include <optional> +#include <string> + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +// The MagmaLookup class is used for dlopening the MAGMA shared library, +// initializing it, and looking up MAGMA symbols. 
+class MagmaLookup { + public: + explicit MagmaLookup() = default; + ~MagmaLookup(); + absl::StatusOr<void*> FindMagmaInit(); + absl::Status Initialize(); + absl::StatusOr<void*> Find(const char name[]); + + private: + bool initialized_ = false; + bool failed_ = false; + void* handle_ = nullptr; + std::optional<std::string> lib_path_ = std::nullopt; + absl::flat_hash_map<std::string, void*> symbols_; +}; + +XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigReal); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigComp); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax + +#endif // JAXLIB_GPU_HYBRID_KERNELS_H_ diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index 03fd43e9ef89..59819f1fc914 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -56,6 +56,21 @@ xla_client.register_custom_call_target(_name, _value, platform="CUDA", api_version=api_version) +for cuda_module_name in [".cuda", "jax_cuda12_plugin"]: + try: + _cuhybrid = importlib.import_module( + f"{cuda_module_name}._hybrid", package="jaxlib" + ) + except ImportError: + _cuhybrid = None + else: + break + +if _cuhybrid: + for _name, _value in _cuhybrid.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="CUDA", + api_version=1) + try: from .rocm import _blas as _hipblas # pytype: disable=import-error except ImportError: @@ -88,6 +103,34 @@ xla_client.register_custom_call_target(_name, _value, platform="ROCM", api_version=api_version) +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hiphybrid = importlib.import_module( + f"{rocm_module_name}._hybrid", package="jaxlib" + ) + except ImportError: + _hiphybrid = None + else: + break + +if _hiphybrid: + for _name, _value in _hiphybrid.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="ROCM", + api_version=1) + +def initialize_hybrid_kernels(): + if _cuhybrid: + _cuhybrid.initialize() + if _hiphybrid: + _hiphybrid.initialize() + +def has_magma(): + if _cuhybrid: + return _cuhybrid.has_magma() + if _hiphybrid: + return _hiphybrid.has_magma() + return False + def _real_type(dtype): """Returns the real equivalent of 'dtype'.""" return np.finfo(dtype).dtype diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index b5bfe733b992..2bae7ab2a203 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -66,6 +66,7 @@ _py_deps = { "filelock": ["@pypi_filelock//:pkg"], "flatbuffers": ["@pypi_flatbuffers//:pkg"], "hypothesis": ["@pypi_hypothesis//:pkg"], + "magma": [], "matplotlib": ["@pypi_matplotlib//:pkg"], "mpmath": [], "opt_einsum": ["@pypi_opt_einsum//:pkg"], diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index c9b73a5785f1..1076f9a77bf8 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -389,6 +389,48 @@ pybind_extension( ], ) +cc_library( + name = "hip_hybrid_kernels", + srcs = ["//jaxlib/gpu:hybrid_kernels.cc"], + hdrs = ["//jaxlib/gpu:hybrid_kernels.h"], + deps = [ + ":hip_gpu_kernel_helpers", + ":hip_vendor", + "//jaxlib:ffi_helpers", + "//jaxlib/cpu:lapack_kernels", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@xla//xla/ffi/api:ffi", + ], +) + +pybind_extension( + name = "_hybrid", + srcs = ["//jaxlib/gpu:hybrid.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + module_name = 
"_hybrid", + deps = [ + ":hip_gpu_kernel_helpers", + ":hip_hybrid_kernels", + ":hip_vendor", + "//jaxlib:kernel_nanobind_helpers", + "//jaxlib/cpu:lapack_kernels", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", + "@xla//xla/ffi/api:ffi", + ], +) + cc_library( name = "triton_kernels", srcs = ["//jaxlib/gpu:triton_kernels.cc"], @@ -456,6 +498,7 @@ py_library( name = "rocm_gpu_support", deps = [ ":_blas", + ":_hybrid", ":_linalg", ":_prng", ":_solver", diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 5b3ac636303a..9a47c6ad5409 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -108,6 +108,7 @@ def prepare_wheel_cuda( f"__main__/jaxlib/cuda/_rnn.{pyext}", f"__main__/jaxlib/cuda/_sparse.{pyext}", f"__main__/jaxlib/cuda/_triton.{pyext}", + f"__main__/jaxlib/cuda/_hybrid.{pyext}", f"__main__/jaxlib/cuda/_versions.{pyext}", f"__main__/jaxlib/cuda_plugin_extension.{pyext}", f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}", @@ -144,6 +145,7 @@ def prepare_wheel_rocm( f"__main__/jaxlib/rocm/_linalg.{pyext}", f"__main__/jaxlib/rocm/_prng.{pyext}", f"__main__/jaxlib/rocm/_sparse.{pyext}", + f"__main__/jaxlib/cuda/_hybrid.{pyext}", f"__main__/jaxlib/rocm/_triton.{pyext}", f"__main__/jaxlib/rocm_plugin_extension.{pyext}", "__main__/jaxlib/version.py", diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 438cebca2b06..4db36fa0ea97 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -231,6 +231,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): f"__main__/jaxlib/cuda/_rnn.{pyext}", f"__main__/jaxlib/cuda/_sparse.{pyext}", f"__main__/jaxlib/cuda/_triton.{pyext}", + f"__main__/jaxlib/cuda/_hybrid.{pyext}", f"__main__/jaxlib/cuda/_versions.{pyext}", ], ) @@ -244,6 +245,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): f"__main__/jaxlib/rocm/_prng.{pyext}", f"__main__/jaxlib/rocm/_sparse.{pyext}", f"__main__/jaxlib/rocm/_triton.{pyext}", + f"__main__/jaxlib/rocm/_hybrid.{pyext}", ], ) diff --git a/tests/BUILD b/tests/BUILD index c80f63e6d7d6..bd4312e4aa24 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -664,6 +664,13 @@ jax_multiplatform_test( }, ) +jax_multiplatform_test( + name = "magma_linalg_test", + srcs = ["magma_linalg_test.py"], + enable_backends = ["gpu"], + deps = py_deps("magma"), +) + jax_multiplatform_test( name = "cholesky_update_test", srcs = ["cholesky_update_test.py"], diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index a1817f528f27..7aad5634775d 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1492,8 +1492,8 @@ def testTrimZerosNotOneDArray(self): def testPoly(self, a_shape, dtype, rank): if dtype in (np.float16, jnp.bfloat16, np.int16): self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.") - elif rank == 2 and not jtu.test_device_matches(["cpu"]): - self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU backend.") + elif rank == 2 and not jtu.test_device_matches(["cpu", "gpu"]): + self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU and GPU backends.") rng = jtu.rand_default(self.rng()) tol = { np.int8: 2e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 } if jtu.test_device_matches(["tpu"]): diff --git a/tests/linalg_test.py b/tests/linalg_test.py index d3fe8f476722..d0b109dda07e 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -34,6 +34,7 
@@ from jax._src.lax import linalg as lax_linalg from jax._src import test_util as jtu from jax._src import xla_bridge +from jax._src.lib import version as jaxlib_version from jax._src.numpy.util import promote_dtypes_inexact config.parse_flags_with_absl() @@ -250,11 +251,11 @@ def testIssue1213(self): compute_left_eigenvectors=[False, True], compute_right_eigenvectors=[False, True], ) - # TODO(phawkins): enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEig(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) n = shape[-1] args_maker = lambda: [rng(shape, dtype)] @@ -293,12 +294,12 @@ def check_left_eigenvectors(a, w, vl): compute_left_eigenvectors=[False, True], compute_right_eigenvectors=[False, True], ) - # TODO(phawkins): enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): """Verifies that `eig` fails gracefully if given non-finite inputs.""" + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") a = jnp.full(shape, jnp.nan, dtype) results = lax.linalg.eig( a, compute_left_eigenvectors=compute_left_eigenvectors, @@ -309,15 +310,15 @@ def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, @jtu.sample_product( shape=[(4, 4), (5, 5), (8, 8), (7, 6, 6)], dtype=float_types + complex_types, - ) - # TODO(phawkins): enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + ) + @jtu.run_on_devices("cpu", "gpu") def testEigvalsGrad(self, shape, dtype): # This test sometimes fails for large matrices. I (@j-towns) suspect, but # haven't checked, that might be because of perturbations causing the # ordering of eigenvalues to change, which will trip up check_grads. So we # just test on small-ish matrices. + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] a, = args_maker() @@ -329,10 +330,10 @@ def testEigvalsGrad(self, shape, dtype): shape=[(4, 4), (5, 5), (50, 50)], dtype=float_types + complex_types, ) - # TODO: enable when there is an eigendecomposition implementation - # for GPU/TPU. 
- @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEigvals(self, shape, dtype): + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] a, = args_maker() @@ -340,9 +341,11 @@ def testEigvals(self, shape, dtype): w1 = jnp.linalg.eigvals(a) w2 = jnp.linalg.eigvals(a) self.assertAllClose(w1, w2, rtol={np.complex64: 1e-5, np.complex128: 2e-14}) - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEigvalsInf(self): # https://github.com/jax-ml/jax/issues/2661 + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") x = jnp.array([[jnp.inf]]) self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x)))) @@ -350,8 +353,10 @@ def testEigvalsInf(self): shape=[(1, 1), (4, 4), (5, 5)], dtype=float_types + complex_types, ) - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEigBatching(self, shape, dtype): + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) shape = (10,) + shape args = rng(shape, dtype) diff --git a/tests/magma_linalg_test.py b/tests/magma_linalg_test.py new file mode 100644 index 000000000000..d2abb9fe3a0b --- /dev/null +++ b/tests/magma_linalg_test.py @@ -0,0 +1,125 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import numpy as np + +from absl.testing import absltest + +import jax +from jax import numpy as jnp +from jax._src import config +from jax._src import test_util as jtu +from jax._src.lax import linalg as lax_linalg +from jax._src.lib import gpu_solver +from jax._src.lib import version as jaxlib_version + +config.parse_flags_with_absl() + +float_types = jtu.dtypes.floating +complex_types = jtu.dtypes.complex + + +class MagmaLinalgTest(jtu.JaxTestCase): + + @jtu.sample_product( + shape=[(0, 0), (4, 4), (5, 5), (50, 50), (2, 6, 6)], + dtype=float_types + complex_types, + compute_left_eigenvectors=[False, True], + compute_right_eigenvectors=[False, True], + ) + @jtu.run_on_devices("gpu") + def testEig(self, shape, dtype, compute_left_eigenvectors, + compute_right_eigenvectors): + if jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") + if not gpu_solver.has_magma(): + self.skipTest("MAGMA is not installed or can't be loaded.") + # TODO(b/377907938), TODO(danfm): Debug issues with MAGMA support for + # complex128 in some configurations. + if dtype == np.complex128: + self.skipTest("MAGMA support for complex128 types is flaky.") + rng = jtu.rand_default(self.rng()) + n = shape[-1] + args_maker = lambda: [rng(shape, dtype)] + + # Norm, adjusted for dimension and type. 
+ def norm(x): + norm = np.linalg.norm(x, axis=(-2, -1)) + return norm / ((n + 1) * jnp.finfo(dtype).eps) + + def check_right_eigenvectors(a, w, vr): + self.assertTrue( + np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100)) + + def check_left_eigenvectors(a, w, vl): + rank = len(a.shape) + aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2])) + wC = jnp.conj(w) + check_right_eigenvectors(aH, wC, vl) + + a, = args_maker() + results = lax_linalg.eig( + a, compute_left_eigenvectors=compute_left_eigenvectors, + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=True) + w = results[0] + + if compute_left_eigenvectors: + check_left_eigenvectors(a, w, results[1]) + if compute_right_eigenvectors: + check_right_eigenvectors(a, w, results[1 + compute_left_eigenvectors]) + + self._CompileAndCheck(jnp.linalg.eig, args_maker, rtol=1e-3) + + @jtu.sample_product( + shape=[(4, 4), (5, 5), (50, 50), (2, 6, 6)], + dtype=float_types + complex_types, + compute_left_eigenvectors=[False, True], + compute_right_eigenvectors=[False, True], + ) + @jtu.run_on_devices("gpu") + def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, + compute_right_eigenvectors): + """Verifies that `eig` fails gracefully if given non-finite inputs.""" + if jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") + if not gpu_solver.has_magma(): + self.skipTest("MAGMA is not installed or can't be loaded.") + # TODO(b/377907938), TODO(danfm): Debug issues with MAGMA support for + # complex128 in some configurations. + if dtype == np.complex128: + self.skipTest("MAGMA support for complex128 types is flaky.") + a = jnp.full(shape, jnp.nan, dtype) + results = lax_linalg.eig( + a, compute_left_eigenvectors=compute_left_eigenvectors, + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=True) + for result in results: + self.assertTrue(np.all(np.isnan(result))) + + def testEigMagmaConfig(self): + if jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") + if not gpu_solver.has_magma(): + self.skipTest("MAGMA is not installed or can't be loaded.") + rng = jtu.rand_default(self.rng()) + a = rng((5, 5), np.float32) + with config.gpu_use_magma("on"): + hlo = jax.jit(partial(lax_linalg.eig, use_magma=True)).lower(a).as_text() + self.assertIn('magma = "on"', hlo) + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())
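As a closing illustration, here is a hypothetical downstream smoke check mirroring the tests above. It is not part of the diff: `gpu_solver.has_magma()` and the `config.gpu_use_magma` scope are both introduced by this change, they live under `jax._src` (internal, subject to change), and the sketch assumes a GPU backend.

```python
import numpy as np
import jax
from jax._src import config
from jax._src.lax import linalg as lax_linalg
from jax._src.lib import gpu_solver

x = np.random.default_rng(0).standard_normal((4, 4), dtype=np.float32)

# Only force the MAGMA path when the shared library can actually be loaded;
# otherwise fall back to the LAPACK-on-host implementation.
use_magma = gpu_solver.has_magma()
w, vl, vr = lax_linalg.eig(x, use_magma=use_magma)

# The chosen mode is threaded through to the FFI call as a string attribute,
# which is visible in the lowered module (as testEigMagmaConfig verifies).
with config.gpu_use_magma("on"):
    hlo = jax.jit(lax_linalg.eig).lower(x).as_text()
    assert 'magma = "on"' in hlo
```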