Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Aug 19, 2024
1 parent 42d0293 commit 5a3778c
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 55 deletions.
137 changes: 85 additions & 52 deletions xarray/namedarray/_array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

__all__ += ["__array_api_version__"]

from xarray.namedarray._array_api.array_object import Array
from xarray.namedarray.core import NamedArray as Array

__all__ += ["Array"]

Expand All @@ -15,37 +15,37 @@
from xarray.namedarray._array_api.creation_functions import (
arange,
asarray,
empty,
empty_like,
eye,
# empty,
# empty_like,
# eye,
full,
full_like,
# full_like,
linspace,
meshgrid,
# meshgrid,
ones,
ones_like,
tril,
triu,
# ones_like,
# tril,
# triu,
zeros,
zeros_like,
# zeros_like,
)

__all__ += [
"arange",
"asarray",
"empty",
"empty_like",
"eye",
# "empty",
# "empty_like",
# "eye",
"full",
"full_like",
# "full_like",
"linspace",
"meshgrid",
# "meshgrid",
"ones",
"ones_like",
"tril",
"triu",
# "ones_like",
# "tril",
# "triu",
"zeros",
"zeros_like",
# "zeros_like",
]

from xarray.namedarray._array_api.data_type_functions import (
Expand All @@ -57,7 +57,14 @@
result_type,
)

__all__ += ["astype", "can_cast", "finfo", "iinfo", "isdtype", "result_type"]
__all__ += [
"astype",
"can_cast",
"finfo",
"iinfo",
"isdtype",
"result_type",
]

from xarray.namedarray._array_api.dtypes import (
bool,
Expand Down Expand Up @@ -215,56 +222,82 @@
"trunc",
]

from xarray.namedarray._array_api.indexing_functions import take
# from xarray.namedarray._array_api.indexing_functions import take

__all__ += ["take"]
# __all__ += ["take"]

from xarray.namedarray._array_api.linear_algebra_functions import (
matmul,
matrix_transpose,
outer,
tensordot,
vecdot,
)
# from xarray.namedarray._array_api.linear_algebra_functions import (
# matmul,
# matrix_transpose,
# outer,
# tensordot,
# vecdot,
# )

__all__ += ["matmul", "matrix_transpose", "outer", "tensordot", "vecdot"]
# __all__ += ["matmul", "matrix_transpose", "outer", "tensordot", "vecdot"]

from xarray.namedarray._array_api.manipulation_functions import (
broadcast_arrays,
broadcast_to,
concat,
# broadcast_arrays,
# broadcast_to,
# concat,
expand_dims,
flip,
moveaxis,
# flip,
# moveaxis,
permute_dims,
reshape,
roll,
squeeze,
stack,
# roll,
# squeeze,
# stack,
)

__all__ += [
"broadcast_arrays",
"broadcast_to",
"concat",
# "broadcast_arrays",
# "broadcast_to",
# "concat",
"expand_dims",
"flip",
"moveaxis",
# "flip",
# "moveaxis",
"permute_dims",
"reshape",
"roll",
"squeeze",
"stack",
# "roll",
# "squeeze",
# "stack",
]

from xarray.namedarray._array_api.searching_functions import argmax, argmin, where
# from xarray.namedarray._array_api.searching_functions import (
# argmax,
# argmin,
# where,
# )

__all__ += ["argmax", "argmin", "where"]
# __all__ += [
# "argmax",
# "argmin",
# "where",
# ]

from xarray.namedarray._array_api.statistical_functions import max, mean, min, prod, sum
from xarray.namedarray._array_api.statistical_functions import (
# max,
mean,
# min,
# prod,
# sum,
)

__all__ += ["max", "mean", "min", "prod", "sum"]
__all__ += [
# "max",
"mean",
# "min",
# "prod",
# "sum",
]

from xarray.namedarray._array_api.utility_functions import all, any
from xarray.namedarray._array_api.utility_functions import (
all,
any,
)

__all__ += ["all", "any"]
__all__ += [
"all",
"any",
]
5 changes: 4 additions & 1 deletion xarray/namedarray/_array_api/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import TYPE_CHECKING, Any, ModuleType
from __future__ import annotations

from types import ModuleType
from typing import TYPE_CHECKING, Any

import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion xarray/namedarray/_array_api/indexing_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from xarray.namedarray._array_api._utils import _get_data_namespace

sdf = _get_data_namespace()
sdf = _get_data_namespace
3 changes: 3 additions & 0 deletions xarray/namedarray/_array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

from typing import Any

from xarray.namedarray._array_api._utils import _get_data_namespace
from xarray.namedarray._array_api.creation_functions import asarray

from xarray.namedarray._typing import (
Default,
_arrayapi,
Expand Down
2 changes: 1 addition & 1 deletion xarray/namedarray/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ def nbytes(self) -> _IntOrUnknown:
If the underlying data array does not include ``nbytes``, estimates
the bytes consumed based on the ``size`` and ``dtype``.
"""
from xarray.namedarray._array_api import _get_data_namespace
from xarray.namedarray._array_api._utils import _get_data_namespace

if hasattr(self._data, "nbytes"):
return self._data.nbytes # type: ignore[no-any-return]
Expand Down

0 comments on commit 5a3778c

Please sign in to comment.