diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 042b01f31fd..9cfc94f5356 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -413,100 +413,22 @@ def __len__(self) -> _IntOrUnknown: except Exception as exc: raise TypeError("len() of unsized object") from exc - # Array API: - # Attributes: + # < Array api > - @property - def device(self) -> _Device: - """ - Device of the array’s elements. - - See Also - -------- - ndarray.device - """ - if isinstance(self._data, _arrayapi): - return self._data.device - else: - raise NotImplementedError("self._data missing device") - - @property - def dtype(self) -> _DType_co: - """ - Data-type of the array’s elements. - - See Also - -------- - ndarray.dtype - numpy.dtype - """ - return self._data.dtype - - @property - def mT(self): - raise NotImplementedError("Todo: ") - - @property - def ndim(self) -> int: - """ - Number of array dimensions. - - See Also - -------- - numpy.ndarray.ndim - """ - return len(self.shape) - - @property - def shape(self) -> _Shape: - """ - Get the shape of the array. - - Returns - ------- - shape : tuple of ints - Tuple of array dimensions. - - See Also - -------- - numpy.ndarray.shape - """ - return self._data.shape - - @property - def size(self) -> _IntOrUnknown: - """ - Number of elements in the array. - - Equal to ``np.prod(a.shape)``, i.e., the product of the array’s dimensions. - - See Also - -------- - numpy.ndarray.size - """ - return math.prod(self.shape) + def __abs__(self, /) -> Self: + from xarray.namedarray._array_api import abs - def to_device(self, device: _Device, /, stream: None = None) -> Self: - if isinstance(self._data, _arrayapi): - return self._replace(data=self._data.to_device(device, stream=stream)) - else: - raise NotImplementedError("Only array api are valid.") + return abs(self) - @property - def T(self) -> NamedArray[Any, _DType_co]: - """Return a new object with transposed dimensions.""" - if self.ndim != 2: - raise ValueError( - f"x.T requires x to have 2 dimensions, got {self.ndim}. Use x.permute_dims() to permute dimensions." - ) + def __add__(self, other: int | float | NamedArray, /) -> NamedArray: + from xarray.namedarray._array_api import add, asarray - return self.permute_dims() + return add(self, asarray(other)) - # methods - def __abs__(self, /): - from xarray.namedarray._array_api import abs + def __and__(self, other, /): + from xarray.namedarray._array_api import bitwise_and - return abs(self) + return bitwise_and(self, asarray(other)) # def __array_namespace__(self, /, *, api_version=None): # if api_version is not None and api_version not in ( @@ -525,16 +447,6 @@ def __bool__(self, /) -> bool: def __complex__(self, /) -> complex: return self._data.__complex__() - def __float__(self, /) -> float: - return self._data.__float__() - - def __index__(self, /) -> int: - return self._data.__index__() - - def __int__(self, /) -> int: - return self._data.__int__() - - # dlpack def __dlpack__( self, /, @@ -551,208 +463,348 @@ def __dlpack__( def __dlpack_device__(self, /) -> tuple[IntEnum, int]: return self._data.__dlpack_device__() - # Arithmetic Operators + def __eq__(self, other: int | float | bool | NamedArray, /) -> NamedArray: + from xarray.namedarray._array_api import asarray, equal - def __neg__(self, /): - from xarray.namedarray._array_api import negative + return equal(self, asarray(other)) - return negative(self) + def __float__(self, /) -> float: + return self._data.__float__() - def __pos__(self, /): - from xarray.namedarray._array_api import positive + def __floordiv__(self, other, /): + from xarray.namedarray._array_api import floor_divide - return positive(self) + return floor_divide(self, asarray(other)) - def __add__(self, other: int | float | NamedArray, /) -> NamedArray: - from xarray.namedarray._array_api import add, asarray + def __ge__(self, other, /): + from xarray.namedarray._array_api import greater_equal - return add(self, asarray(other)) + return greater_equal(self, asarray(other)) - def __sub__(self, other, /): - from xarray.namedarray._array_api import subtract + def __getitem__(self, key: _IndexKeyLike | NamedArray): + if isinstance(key, int | slice | tuple): + _data = self._data[key] + return self._new((), _data) + elif isinstance(key, NamedArray): + _key = self._data # TODO: Transpose, unordered dims shouldn't matter. + _data = self._data[_key] + return self._new(key._dims, _data) + else: + raise NotImplementedError("{k=} is not supported") - return subtract(self, other) + def __gt__(self, other, /): + from xarray.namedarray._array_api import greater - def __mul__(self, other, /): - from xarray.namedarray._array_api import multiply + return greater(self, asarray(other)) - return multiply(self, other) + def __index__(self, /) -> int: + return self._data.__index__() - def __truediv__(self, other, /): - from xarray.namedarray._array_api import divide + def __int__(self, /) -> int: + return self._data.__int__() - return divide(self, other) + def __invert__(self, /): + from xarray.namedarray._array_api import bitwise_invert - def __floordiv__(self, other, /): - from xarray.namedarray._array_api import floor_divide + return bitwise_invert(self) - return floor_divide(self, other) + def __iter__(self: NamedArray, /): + from xarray.namedarray._array_api import asarray - def __mod__(self, other, /): - from xarray.namedarray._array_api import remainder + # TODO: smarter way to retain dims, xarray? + return (asarray(i) for i in self._data) - return remainder(self, other) + def __le__(self, other, /): + from xarray.namedarray._array_api import less_equal - def __pow__(self, other, /): - from xarray.namedarray._array_api import pow + return less_equal(self, asarray(other)) - return pow(self, other) + def __lshift__(self, other, /): + from xarray.namedarray._array_api import bitwise_left_shift - # Array Operators + return bitwise_left_shift(self) + + def __lt__(self, other, /): + from xarray.namedarray._array_api import less + + return less(self, asarray(other)) def __matmul__(self, other, /): from xarray.namedarray._array_api import matmul - return matmul(self, other) + return matmul(self, asarray(other)) - # Bitwise Operators + def __mod__(self, other, /): + from xarray.namedarray._array_api import remainder - def __invert__(self, /): - from xarray.namedarray._array_api import bitwise_invert + return remainder(self, asarray(other)) - return bitwise_invert(self) + def __mul__(self, other, /): + from xarray.namedarray._array_api import multiply - def __and__(self, other, /): - from xarray.namedarray._array_api import bitwise_and + return multiply(self, asarray(other)) + + def __ne__(self, other, /): + from xarray.namedarray._array_api import not_equal + + return not_equal(self, asarray(other)) + + def __neg__(self, /): + from xarray.namedarray._array_api import negative - return bitwise_and(self) + return negative(self) def __or__(self, other, /): from xarray.namedarray._array_api import bitwise_or return bitwise_or(self) - def __xor__(self, other, /): - from xarray.namedarray._array_api import bitwise_xor + def __pos__(self, /): + from xarray.namedarray._array_api import positive - return bitwise_xor(self) + return positive(self) - def __lshift__(self, other, /): - from xarray.namedarray._array_api import bitwise_left_shift + def __pow__(self, other, /): + from xarray.namedarray._array_api import pow - return bitwise_left_shift(self) + return pow(self, asarray(other)) def __rshift__(self, other, /): from xarray.namedarray._array_api import bitwise_right_shift return bitwise_right_shift(self) - # Comparison Operators - def __eq__(self, other: int | float | bool | NamedArray, /) -> NamedArray: - from xarray.namedarray._array_api import asarray, equal + def __setitem__( + self, + key: _IndexKeyLike, + value: int | float | bool | NamedArray, + /, + ) -> None: + from xarray.namedarray._array_api import asarray - return equal(self, asarray(other)) + if isinstance(key, NamedArray): + key = key._data + self._array.__setitem__(key, asarray(value)._data) - def __ge__(self, other, /): - from xarray.namedarray._array_api import greater_equal + def __sub__(self, other, /): + from xarray.namedarray._array_api import subtract - return greater_equal(self, other) + return subtract(self, asarray(other)) - def __gt__(self, other, /): - from xarray.namedarray._array_api import greater + def __truediv__(self, other, /): + from xarray.namedarray._array_api import divide - return greater(self, other) + return divide(self, asarray(other)) - def __le__(self, other, /): - from xarray.namedarray._array_api import less_equal + def __xor__(self, other, /): + from xarray.namedarray._array_api import bitwise_xor - return less_equal(self, other) + return bitwise_xor(self) - def __lt__(self, other, /): - from xarray.namedarray._array_api import less + def __iadd__(self, other, /): + self._data.__iadd__(other._data) + return self - return less(self, other) + def __radd__(self, other, /): + from xarray.namedarray._array_api import add - def __ne__(self, other, /): - from xarray.namedarray._array_api import not_equal + return add(asarray(other), self) - return not_equal(self, other) + def __iand__(self, other, /): + self._data.__iand__(other._data) + return self - # Reflected Operators + def __rand__(self, other, /): + from xarray.namedarray._array_api import bitwise_and - # (Reflected) Arithmetic Operators + return bitwise_and(asarray(other), self) - def __radd__(self, other, /): - from xarray.namedarray._array_api import add + def __ifloordiv__(self, other, /): + self._data.__ifloordiv__(other._data) + return self - return add(other, self) + def __rfloordiv__(self, other, /): + from xarray.namedarray._array_api import floor_divide - def __rsub__(self, other, /): - from xarray.namedarray._array_api import subtract + return floor_divide(asarray(other), self) - return subtract(other, self) + def __ilshift__(self, other, /): + self._data.__ilshift__(other._data) + return self - def __rmul__(self, other, /): - from xarray.namedarray._array_api import multiply + def __rlshift__(self, other, /): + from xarray.namedarray._array_api import bitwise_left_shift - return multiply(other, self) + return bitwise_left_shift(asarray(other), self) - def __rtruediv__(self, other, /): - from xarray.namedarray._array_api import divide + def __imatmul__(self, other, /): + self._data.__imatmul__(other._data) + return self - return divide(other, self) + def __rmatmul__(self, other, /): + from xarray.namedarray._array_api import matmul - def __rfloordiv__(self, other, /): - from xarray.namedarray._array_api import floor_divide + return matmul(asarray(other), self) - return floor_divide(other, self) + def __imod__(self, other, /): + self._data.__imod__(other._data) + return self def __rmod__(self, other, /): from xarray.namedarray._array_api import remainder - return remainder(other, self) + return remainder(asarray(other), self) + + def __imul__(self, other, /): + self._data.__imul__(other._data) + return self + + def __rmul__(self, other, /): + from xarray.namedarray._array_api import multiply + + return multiply(asarray(other), self) + + def __ior__(self, other, /): + self._data.__ior__(other._data) + return self + + def __ror__(self, other, /): + from xarray.namedarray._array_api import bitwise_or + + return bitwise_or(asarray(other), self) + + def __ipow__(self, other, /): + self._data.__ipow__(other._data) + return self def __rpow__(self, other, /): from xarray.namedarray._array_api import pow - return pow(other, self) + return pow(asarray(other), self) - # (Reflected) Array Operators + def __irshift__(self, other, /): + self._data.__irshift__(other._data) + return self - def __rmatmul__(self, other, /): - from xarray.namedarray._array_api import matmul + def __rrshift__(self, other, /): + from xarray.namedarray._array_api import bitwise_right_shift - return matmul(other, self) + return bitwise_right_shift(asarray(other), self) - # (Reflected) Bitwise Operators + def __isub__(self, other, /): + self._data.__isub__(other._data) + return self - def __rand__(self, other, /): - from xarray.namedarray._array_api import bitwise_and + def __rsub__(self, other, /): + from xarray.namedarray._array_api import subtract - return bitwise_and(other, self) + return subtract(asarray(other), self) - def __ror__(self, other, /): - from xarray.namedarray._array_api import bitwise_or + def __itruediv__(self, other, /): + self._data.__itruediv__(asarray(other)._data) + return self - return bitwise_or(other, self) + def __rtruediv__(self, other, /): + from xarray.namedarray._array_api import divide + + return divide(asarray(other), self) + + def __ixor__(self, other, /): + self._data.__ixor__(other._data) + return self def __rxor__(self, other, /): from xarray.namedarray._array_api import bitwise_xor - return bitwise_xor(other, self) - - def __rlshift__(self, other, /): - from xarray.namedarray._array_api import bitwise_left_shift + return bitwise_xor(asarray(other), self) - return bitwise_left_shift(other, self) + def to_device(self, device: _Device, /, stream: None = None) -> Self: + if isinstance(self._data, _arrayapi): + return self._replace(data=self._data.to_device(device, stream=stream)) + else: + raise NotImplementedError("Only array api are valid.") - def __rrshift__(self, other, /): - from xarray.namedarray._array_api import bitwise_right_shift + @property + def dtype(self) -> _DType_co: + """ + Data-type of the array’s elements. - return bitwise_right_shift(other, self) + See Also + -------- + ndarray.dtype + numpy.dtype + """ + return self._data.dtype - # Indexing + @property + def device(self) -> _Device: + """ + Device of the array’s elements. - def __getitem__(self, key: _IndexKeyLike | NamedArray): - if isinstance(key, int | slice | tuple): - _data = self._data[key] - return self._new((), _data) - elif isinstance(key, NamedArray): - _key = self._data # TODO: Transpose, unordered dims shouldn't matter. - _data = self._data[_key] - return self._new(key._dims, _data) + See Also + -------- + ndarray.device + """ + if isinstance(self._data, _arrayapi): + return self._data.device else: - raise NotImplementedError("{k=} is not supported") + raise NotImplementedError("self._data missing device") + + @property + def mT(self): + raise NotImplementedError("Todo: ") + + @property + def ndim(self) -> int: + """ + Number of array dimensions. + + See Also + -------- + numpy.ndarray.ndim + """ + return len(self.shape) + + @property + def shape(self) -> _Shape: + """ + Get the shape of the array. + + Returns + ------- + shape : tuple of ints + Tuple of array dimensions. + + See Also + -------- + numpy.ndarray.shape + """ + return self._data.shape + + @property + def size(self) -> _IntOrUnknown: + """ + Number of elements in the array. + + Equal to ``np.prod(a.shape)``, i.e., the product of the array’s dimensions. + + See Also + -------- + numpy.ndarray.size + """ + return math.prod(self.shape) + + @property + def T(self) -> NamedArray[Any, _DType_co]: + """Return a new object with transposed dimensions.""" + if self.ndim != 2: + raise ValueError( + f"x.T requires x to have 2 dimensions, got {self.ndim}. Use x.permute_dims() to permute dimensions." + ) + + return self.permute_dims() + + # @property def nbytes(self) -> _IntOrUnknown: