From 320a692e866fc48c2e25d6189a57dfc0598af7f1 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 19:26:39 +0200 Subject: [PATCH] Add searching functions --- xarray/namedarray/_array_api/__init__.py | 23 ++--- .../_array_api/_creation_functions.py | 12 +-- .../_array_api/_searching_functions.py | 87 ++++++++++++++++++- xarray/namedarray/_array_api/_utils.py | 26 +++++- 4 files changed, 125 insertions(+), 23 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index b0c32cf225c..873bff3c5b0 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -270,17 +270,20 @@ # "stack", ] -# from xarray.namedarray._array_api._searching_functions import ( -# argmax, -# argmin, -# where, -# ) +from xarray.namedarray._array_api._searching_functions import ( + argmax, + argmin, + nonzero, + where, +) -# __all__ += [ -# "argmax", -# "argmin", -# "where", -# ] +__all__ += [ + "argmax", + "argmin", + "nonzero", + "searchsorted", + "where", +] from xarray.namedarray._array_api._statistical_functions import ( max, diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index 196c5c9b9f0..31a690d424b 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -6,8 +6,8 @@ from xarray.namedarray._array_api._utils import ( _get_data_namespace, - # _maybe_default_namespace, _get_namespace_dtype, + _infer_dims, ) from xarray.namedarray._typing import ( Default, @@ -35,16 +35,6 @@ def _like_args(x, dtype=None, device: _Device | None = None): return dict(shape=x.shape, dtype=dtype, device=device) -def _infer_dims( - shape: _Shape, - dims: _DimsLike | Default = _default, -) -> _DimsLike: - if dims is _default: - return tuple(f"dim_{n}" for n in range(len(shape))) - else: - return dims - - def arange( start: int | float, /, diff --git a/xarray/namedarray/_array_api/_searching_functions.py b/xarray/namedarray/_array_api/_searching_functions.py index b62d0a8393b..2ee30632b2f 100644 --- a/xarray/namedarray/_array_api/_searching_functions.py +++ b/xarray/namedarray/_array_api/_searching_functions.py @@ -1,3 +1,88 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +from xarray.namedarray._array_api._utils import _get_data_namespace, _infer_dims +from xarray.namedarray._typing import ( + Default, + _arrayfunction_or_api, + _ArrayLike, + _default, + _Device, + _DimsLike, + _DType, + _Dims, + _Shape, + _ShapeType, + duckarray, +) +from xarray.namedarray.core import ( + NamedArray, + _dims_to_axis, + _get_remaining_dims, +) + +if TYPE_CHECKING: + from typing import Literal, Optional, Tuple + from xarray.namedarray._array_api._utils import _get_data_namespace -sdf = _get_data_namespace() + +def argmax( + x: NamedArray, + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: int | None = None, +) -> NamedArray: + xp = _get_data_namespace(x) + _axis = _dims_to_axis(x, dims, axis) + _data = xp.argmax(x._data, axis=_axis, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + _dims, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) + return x._new(dims=_dims, data=data_) + + +def argmin( + x: NamedArray, + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: int | None = None, +) -> NamedArray: + xp = _get_data_namespace(x) + _axis = _dims_to_axis(x, dims, axis) + _data = xp.argmin(x._data, axis=_axis, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + _dims, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) + return x._new(dims=_dims, data=data_) + + +def nonzero(x: NamedArray, /) -> tuple[NamedArray, ...]: + xp = _get_data_namespace(x) + _datas = xp.nonzero(x._data) + # TODO: Verify that dims and axis matches here: + return tuple(x._new(dim, i) for dim, i in zip(x.dims, _datas)) + + +def searchsorted( + x1: NamedArray, + x2: NamedArray, + /, + *, + side: Literal["left", "right"] = "left", + sorter: NamedArray | None = None, +) -> NamedArray: + xp = _get_data_namespace(x1) + _data = xp.searchsorted(x1._data, x2._data, side=side, sorter=sorter) + # TODO: Check dims, probably can do it smarter: + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) + + +def where(condition: NamedArray, x1: NamedArray, x2: NamedArray, /) -> NamedArray: + xp = _get_data_namespace(x1) + _data = xp.where(condition._data, x1._data, x2._data) + return x1._new(x1.dims, _data) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 018b84ca54c..f01ec05b146 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -3,7 +3,21 @@ from types import ModuleType from typing import TYPE_CHECKING, Any -from xarray.namedarray._typing import _arrayapi, _dtype +from xarray.namedarray._typing import ( + Default, + _arrayfunction_or_api, + _ArrayLike, + _default, + _arrayapi, + _Device, + _DimsLike, + _DType, + _Dims, + _Shape, + _ShapeType, + duckarray, + _dtype, +) if TYPE_CHECKING: from xarray.namedarray.core import NamedArray @@ -34,3 +48,13 @@ def _get_namespace_dtype(dtype: _dtype | None = None) -> ModuleType: xp = __import__(dtype.__module__) return xp + + +def _infer_dims( + shape: _Shape, + dims: _DimsLike | Default = _default, +) -> _DimsLike: + if dims is _default: + return tuple(f"dim_{n}" for n in range(len(shape))) + else: + return dims