Skip to content

Commit

Permalink
Add searching functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Aug 20, 2024
1 parent d617713 commit 320a692
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 23 deletions.
23 changes: 13 additions & 10 deletions xarray/namedarray/_array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,17 +270,20 @@
# "stack",
]

# from xarray.namedarray._array_api._searching_functions import (
# argmax,
# argmin,
# where,
# )
from xarray.namedarray._array_api._searching_functions import (
argmax,
argmin,
nonzero,
where,
)

# __all__ += [
# "argmax",
# "argmin",
# "where",
# ]
__all__ += [
"argmax",
"argmin",
"nonzero",
"searchsorted",
"where",
]

from xarray.namedarray._array_api._statistical_functions import (
max,
Expand Down
12 changes: 1 addition & 11 deletions xarray/namedarray/_array_api/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from xarray.namedarray._array_api._utils import (
_get_data_namespace,
# _maybe_default_namespace,
_get_namespace_dtype,
_infer_dims,
)
from xarray.namedarray._typing import (
Default,
Expand Down Expand Up @@ -35,16 +35,6 @@ def _like_args(x, dtype=None, device: _Device | None = None):
return dict(shape=x.shape, dtype=dtype, device=device)


def _infer_dims(
shape: _Shape,
dims: _DimsLike | Default = _default,
) -> _DimsLike:
if dims is _default:
return tuple(f"dim_{n}" for n in range(len(shape)))
else:
return dims


def arange(
start: int | float,
/,
Expand Down
87 changes: 86 additions & 1 deletion xarray/namedarray/_array_api/_searching_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,88 @@
from __future__ import annotations

from typing import Any, TYPE_CHECKING

from xarray.namedarray._array_api._utils import _get_data_namespace, _infer_dims
from xarray.namedarray._typing import (
Default,
_arrayfunction_or_api,
_ArrayLike,
_default,
_Device,
_DimsLike,
_DType,
_Dims,
_Shape,
_ShapeType,
duckarray,
)
from xarray.namedarray.core import (
NamedArray,
_dims_to_axis,
_get_remaining_dims,
)

if TYPE_CHECKING:
from typing import Literal, Optional, Tuple

from xarray.namedarray._array_api._utils import _get_data_namespace

sdf = _get_data_namespace()

def argmax(
x: NamedArray,
/,
*,
dims: _Dims | Default = _default,
keepdims: bool = False,
axis: int | None = None,
) -> NamedArray:
xp = _get_data_namespace(x)
_axis = _dims_to_axis(x, dims, axis)
_data = xp.argmax(x._data, axis=_axis, keepdims=False) # We fix keepdims later
# TODO: Why do we need to do the keepdims ourselves?
_dims, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims)
return x._new(dims=_dims, data=data_)


def argmin(
x: NamedArray,
/,
*,
dims: _Dims | Default = _default,
keepdims: bool = False,
axis: int | None = None,
) -> NamedArray:
xp = _get_data_namespace(x)
_axis = _dims_to_axis(x, dims, axis)
_data = xp.argmin(x._data, axis=_axis, keepdims=False) # We fix keepdims later
# TODO: Why do we need to do the keepdims ourselves?
_dims, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims)
return x._new(dims=_dims, data=data_)


def nonzero(x: NamedArray, /) -> tuple[NamedArray, ...]:
xp = _get_data_namespace(x)
_datas = xp.nonzero(x._data)
# TODO: Verify that dims and axis matches here:
return tuple(x._new(dim, i) for dim, i in zip(x.dims, _datas))


def searchsorted(
x1: NamedArray,
x2: NamedArray,
/,
*,
side: Literal["left", "right"] = "left",
sorter: NamedArray | None = None,
) -> NamedArray:
xp = _get_data_namespace(x1)
_data = xp.searchsorted(x1._data, x2._data, side=side, sorter=sorter)
# TODO: Check dims, probably can do it smarter:
_dims = _infer_dims(_data.shape)
return NamedArray(_dims, _data)


def where(condition: NamedArray, x1: NamedArray, x2: NamedArray, /) -> NamedArray:
xp = _get_data_namespace(x1)
_data = xp.where(condition._data, x1._data, x2._data)
return x1._new(x1.dims, _data)
26 changes: 25 additions & 1 deletion xarray/namedarray/_array_api/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,21 @@
from types import ModuleType
from typing import TYPE_CHECKING, Any

from xarray.namedarray._typing import _arrayapi, _dtype
from xarray.namedarray._typing import (
Default,
_arrayfunction_or_api,
_ArrayLike,
_default,
_arrayapi,
_Device,
_DimsLike,
_DType,
_Dims,
_Shape,
_ShapeType,
duckarray,
_dtype,
)

if TYPE_CHECKING:
from xarray.namedarray.core import NamedArray
Expand Down Expand Up @@ -34,3 +48,13 @@ def _get_namespace_dtype(dtype: _dtype | None = None) -> ModuleType:

xp = __import__(dtype.__module__)
return xp


def _infer_dims(
shape: _Shape,
dims: _DimsLike | Default = _default,
) -> _DimsLike:
if dims is _default:
return tuple(f"dim_{n}" for n in range(len(shape)))
else:
return dims

0 comments on commit 320a692

Please sign in to comment.