Skip to content

Commit

Permalink
add support for hashable dims (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Dec 21, 2023
1 parent f51ac73 commit d91642d
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 62 deletions.
9 changes: 4 additions & 5 deletions docs/source/background/einops.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ of the elements respectively: `->`, space as delimiter and parenthesis:
side in the einops notation is only used to label the dimensions.
In fact, 5/7 examples in https://einops.rocks/api/rearrange/ fall in this category.
This is not necessary when working with xarray objects.
* In xarray dimension names can be any {term}`hashable <xarray:name>`. `xarray-einstats` only
supports strings as dimension names, but the space can't be used as delimiter.
* In xarray dimension names can be any {term}`hashable <xarray:name>`.
* In xarray dimensions are labeled and the order doesn't matter.
This might seem the same as the first reason but it is not. When splitting
or stacking dimensions you need (and want) the names of both parent and children dimensions.
Expand All @@ -25,8 +24,8 @@ of the elements respectively: `->`, space as delimiter and parenthesis:

However, there are also many cases in which dimension names in xarray will be strings
without any spaces nor parenthesis in them. So similarly to the option of
doing `da.stack(dim=("dim1", "dim2"))` which can't be used for all valid
dimension names but is generally easier to write and less error prone,
doing `da.stack(dim=["dim1", "dim2"])` which can't be used for all valid
dimension names but is generally easier to write and less error prone,
`xarray_einstats.einops` also provides two possible syntaxes.

The guiding principle of the einops module is to take the input expressions
Expand All @@ -37,7 +36,7 @@ labeled, we can take advantage of that during the translation process
and thus support "partial" expressions that cover only the dimensions
that will be modified.

Another important consideration is to take into account that _in xarray_,
Another important consideration is to take into account that _in xarray_
dimension order should not matter, hence the constraint of using dicts
on the left side. Imposing this constraint also
makes our job of filling in the "partial" expressions much easier.
Expand Down
94 changes: 46 additions & 48 deletions src/xarray_einstats/einops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""
import warnings
from collections.abc import Hashable

import einops
import xarray as xr
Expand Down Expand Up @@ -61,7 +62,7 @@ def process_pattern_list(redims, handler, allow_dict=True, allow_list=True):
allow_dict, allow_list : bool, optional
Whether or not to allow lists or dicts as elements of ``redims``.
When processing ``in_dims`` for example we need the names of
the variables to be decomposed so dicts are required and lists/tuples
the variables to be decomposed so dicts are required and lists
are not accepted.
Returns
Expand All @@ -85,14 +86,14 @@ def process_pattern_list(redims, handler, allow_dict=True, allow_list=True):
from xarray_einstats.einops import process_pattern_list, DimHandler
handler = DimHandler()
process_pattern_list(["a", {"b": ("c", "d")}, ("e", "f", "g")], handler)
process_pattern_list(["a", {"b": ["c", "d"]}, ["e", "f", "g"]], handler)
"""
out = []
out_names = []
txt = []
for subitem in redims:
if isinstance(subitem, str):
if isinstance(subitem, Hashable):
out.append(subitem)
out_names.append(subitem)
txt.append(handler.get_name(subitem))
Expand All @@ -103,8 +104,10 @@ def process_pattern_list(redims, handler, allow_dict=True, allow_list=True):
f"found {len(subitem)}: {subitem.keys()}"
)
key, values = list(subitem.items())[0]
if isinstance(values, str):
raise ValueError("Found values of type str in a pattern dict, use xarray.rename")
if isinstance(values, Hashable):
raise ValueError(
"Found values of hashable type in a pattern dict, use xarray.rename"
)
out.extend(values)
out_names.append(key)
txt.append(f"( {handler.get_names(values)} )")
Expand Down Expand Up @@ -182,7 +185,7 @@ def translate_pattern(pattern):
return dims


def _rearrange(da, out_dims, in_dims=None, **kwargs):
def _rearrange(da, out_dims, in_dims=None, dim_lengths=None):
"""Wrap `einops.rearrange <https://einops.rocks/api/rearrange/>`_.
This is the function that actually interfaces with ``einops``.
Expand All @@ -198,11 +201,14 @@ def _rearrange(da, out_dims, in_dims=None, **kwargs):
See docstring of :func:`~xarray_einstats.einops.rearrange`
in_dims : list of str or dict, optional
See docstring of :func:`~xarray_einstats.einops.rearrange`
kwargs : dict, optional
dim_lengths : dict, optional
Entries whose key is a dimension name in ``out_dims``
(that is, hashables or dict keys) are passed to einops.rearrange;
the rest of the entries are passed to :func:`xarray.apply_ufunc`.
"""
if dim_lengths is None:
dim_lengths = {}

da_dims = da.dims

handler = DimHandler()
Expand Down Expand Up @@ -231,9 +237,9 @@ def _rearrange(da, out_dims, in_dims=None, **kwargs):
{non_core_pattern} {handler.get_names(missing_out_dims)} {out_pattern}"

axes_lengths = {
handler.rename_kwarg(k): v for k, v in kwargs.items() if k in out_names + out_dims
handler.rename_kwarg(k): v for k, v in dim_lengths.items() if k in out_names + out_dims
}
kwargs = {k: v for k, v in kwargs.items() if k not in out_names + out_dims}
kwargs = {k: v for k, v in dim_lengths.items() if k not in out_names + out_dims}
return xr.apply_ufunc(
einops.rearrange,
da,
Expand All @@ -245,7 +251,7 @@ def _rearrange(da, out_dims, in_dims=None, **kwargs):
)


def rearrange(da, pattern, pattern_in=None, **kwargs):
def rearrange(da, pattern, pattern_in=None, dim_lengths=None, **dim_lengths_kwargs):
"""Expose `einops.rearrange <https://einops.rocks/api/rearrange/>`_ with an xarray-like API.
It has two possible syntaxes which are independent and somewhat complementary.
Expand All @@ -268,12 +274,12 @@ def rearrange(da, pattern, pattern_in=None, **kwargs):
a default name.
If `pattern` is not a string, then it must be a list where each of its elements
is one of: ``str``, ``list`` (to stack those dimensions and give them an
arbitrary name) or ``dict of {str: list}`` (to stack the dimensions indicated
is one of: :term:`python:hashable`, ``list`` (to stack those dimensions and
give them an arbitrary name) or ``dict`` (to stack the dimensions indicated
as values of the dictionary and name the resulting dimensions with the key).
`pattern` is then interpreted as the output side of the einops pattern. See
TODO for more details.
`pattern` is then interpreted as the output side of the einops pattern.
See :ref:`about_einops` for more details.
pattern_in : list of [str or dict], optional
The input pattern for the dimensions. It can only be provided if `pattern`
is a ``list``. Also, note this is only necessary if you want to split some dimensions.
Expand All @@ -282,28 +288,22 @@ def rearrange(da, pattern, pattern_in=None, **kwargs):
with the only difference that ``list`` elements are not allowed, the same way
that ``(dim1 dim2)=dim`` is required on the left hand side when using string
patterns.
kwargs : dict, optional
Passed to :func:`xarray_einstats.einops.rearrange`
dim_lengths, **dim_lengths_kwargs : dict, optional
If the keys are dimensions present in `pattern` they will be passed to
`einops.rearrange <https://einops.rocks/api/rearrange/>`_, otherwise,
they are passed to :func:`xarray.apply_ufunc`.
Returns
-------
xarray.DataArray
Notes
-----
Unlike for general xarray objects, where dimension
names can be :term:`hashable <xarray:name>` here
dimension names are not recommended but required to be
strings for both cases. Future releases however might
support this when using lists as `pattern`, comment
on :issue:`50` if you are interested in the feature
or could help implement it.
See Also
--------
xarray_einstats.einops.reduce
"""
if dim_lengths is None:
dim_lengths = {}
dim_lengths = {**dim_lengths, **dim_lengths_kwargs}
if isinstance(pattern, str):
if "->" in pattern:
in_pattern, out_pattern = pattern.split("->")
Expand All @@ -312,11 +312,11 @@ def rearrange(da, pattern, pattern_in=None, **kwargs):
out_pattern = pattern
in_dims = None
out_dims = translate_pattern(out_pattern)
return _rearrange(da, out_dims=out_dims, in_dims=in_dims, **kwargs)
return _rearrange(da, out_dims=pattern, in_dims=pattern_in, **kwargs)
return _rearrange(da, out_dims=out_dims, in_dims=in_dims, dim_lengths=dim_lengths)
return _rearrange(da, out_dims=pattern, in_dims=pattern_in, dim_lengths=dim_lengths)


def _reduce(da, reduction, out_dims, in_dims=None, **kwargs):
def _reduce(da, reduction, out_dims, in_dims=None, dim_lengths=None):
"""Wrap `einops.reduce <https://einops.rocks/api/reduce/>`_.
This is the function that actually interfaces with ``einops``.
Expand All @@ -338,11 +338,14 @@ def _reduce(da, reduction, out_dims, in_dims=None, **kwargs):
in_dims : list of str or dict, optional
The input pattern for the dimensions.
This is only necessary if you want to split some dimensions.
kwargs : dict, optional
dim_lengths : dict, optional
Entries whose key is a dimension name in ``out_dims`` or ``in_dims``
(including dict keys) are passed to einops.reduce;
the rest of the entries are passed to :func:`xarray.apply_ufunc`.
"""
if dim_lengths is None:
dim_lengths = {}

da_dims = da.dims

handler = DimHandler()
Expand All @@ -361,8 +364,8 @@ def _reduce(da, reduction, out_dims, in_dims=None, **kwargs):
pattern = f"{handler.get_names(missing_in_dims)} {in_pattern} -> {out_pattern}"

all_dims = set(out_dims + out_names + in_names + in_dims)
axes_lengths = {handler.rename_kwarg(k): v for k, v in kwargs.items() if k in all_dims}
kwargs = {k: v for k, v in kwargs.items() if k not in all_dims}
axes_lengths = {handler.rename_kwarg(k): v for k, v in dim_lengths.items() if k in all_dims}
kwargs = {k: v for k, v in dim_lengths.items() if k not in all_dims}
return xr.apply_ufunc(
einops.reduce,
da,
Expand All @@ -375,7 +378,7 @@ def _reduce(da, reduction, out_dims, in_dims=None, **kwargs):
)


def reduce(da, pattern, reduction, pattern_in=None, **kwargs):
def reduce(da, pattern, reduction, pattern_in=None, dim_lengths=None, **dim_lengths_kwargs):
"""Expose `einops.reduce <https://einops.rocks/api/reduce/>`_ with an xarray-like API.
It has two possible syntaxes which are independent and somewhat complementary.
Expand Down Expand Up @@ -412,27 +415,22 @@ def reduce(da, pattern, reduction, pattern_in=None, **kwargs):
The syntax and interpretation is the same as the case when `pattern` is a list,
with the only difference that ``list`` elements are not allowed, the same way
that ``(dim1 dim2)=dim`` is required on the left hand side when using string
patterns.
kwargs : dict, optional
Passed to :func:`xarray_einstats.einops.reduce`
dim_lengths, **dim_lengths_kwargs : dict, optional
If the keys are dimensions present in `pattern` they will be passed to
`einops.reduce <https://einops.rocks/api/reduce/>`_, otherwise,
they are passed to :func:`xarray.apply_ufunc`.
Returns
-------
xarray.DataArray
Notes
-----
Unlike for general xarray objects, where dimension
names can be :term:`hashable <xarray:name>` here
dimension names are not recommended but required to be
strings for both cases. Future releases however might
support this when using lists as `pattern`, comment
on :issue:`50` if you are interested in the feature
or could help implement it.
See Also
--------
xarray_einstats.einops.rearrange
"""
if dim_lengths is None:
dim_lengths = {}
dim_lengths = {**dim_lengths, **dim_lengths_kwargs}
if isinstance(pattern, str):
if "->" in pattern:
in_pattern, out_pattern = pattern.split("->")
Expand All @@ -441,8 +439,8 @@ def reduce(da, pattern, reduction, pattern_in=None, **kwargs):
out_pattern = pattern
in_dims = None
out_dims = translate_pattern(out_pattern)
return _reduce(da, reduction, out_dims=out_dims, in_dims=in_dims, **kwargs)
return _reduce(da, reduction, out_dims=pattern, in_dims=pattern_in, **kwargs)
return _reduce(da, reduction, out_dims=out_dims, in_dims=in_dims, dim_lengths=dim_lengths)
return _reduce(da, reduction, out_dims=pattern, in_dims=pattern_in, dim_lengths=dim_lengths)


def raw_reduce(*args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion src/xarray_einstats/numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def ecdf(da, dims=None, *, npoints=None, **kwargs):
dims = da.dims
elif isinstance(dims, str):
dims = [dims]
total_points = np.product([da.sizes[d] for d in dims])
total_points = np.prod([da.sizes[d] for d in dims])
if npoints is None:
npoints = min(total_points, 200)
x = xr.DataArray(np.linspace(0, 1, npoints), dims=["quantile"])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_einops_accessor_rearrange(data):

@pytest.mark.skipif(find_spec("einops") is None, reason="einops must be installed")
def test_einops_accessor_reduce(data):
pattern_in = [{"batch (hh.mm)": ("d1", "d2")}]
pattern_in = [{"batch (hh.mm)": ["d1", "d2"]}]
pattern = ["d1", "subject"]
kwargs = {"d2": 2}
input_data = data.rename({"batch": "batch (hh.mm)"})
Expand Down
31 changes: 26 additions & 5 deletions tests/test_einops.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ class TestRearrange:
"args",
(
(
{"pattern": [{"dex": ("drug dose (mg)", "experiment")}]},
{"pattern": [{"dex": ["drug dose (mg)", "experiment"]}]},
((4, 6, 8 * 15), ["batch", "subject", "dex"]),
),
(
{
"pattern_in": [{"drug dose (mg)": ("d1", "d2")}],
"pattern_in": [{"drug dose (mg)": ["d1", "d2"]}],
"pattern": ["d1", "d2", "batch"],
"d1": 2,
"d2": 4,
Expand All @@ -80,6 +80,16 @@ def test_rearrange(self, data, args):
assert out_da.shape == shape
assert list(out_da.dims) == dims

def test_rearrange_tuple_dim(self, data):
out_da = rearrange(
data.rename(drug=("drug dose", "mg")),
pattern_in=[{("drug dose", "mg"): [("d", 1), ("d", 2)]}],
pattern=[("d", 1), ("d", 2), "batch"],
dim_lengths={("d", 1): 2, ("d", 2): 4},
)
assert out_da.shape == (6, 15, 2, 4, 4)
assert list(out_da.dims) == ["subject", "experiment", ("d", 1), ("d", 2), "batch"]


class TestRawReduce:
@pytest.mark.parametrize(
Expand Down Expand Up @@ -110,16 +120,16 @@ class TestReduce:
),
(
{
"pattern_in": [{"batch (hh.mm)": ("d1", "d2")}],
"pattern_in": [{"batch (hh.mm)": ["d1", "d2"]}],
"pattern": ["d1", "subject"],
"d2": 2,
},
((2, 6), ["d1", "subject"]),
),
(
{
"pattern_in": [{"drug": ("d1", "d2")}, {"batch (hh.mm)": ("b1", "b2")}],
"pattern": ["subject", ("b1", "d1")],
"pattern_in": [{"drug": ["d1", "d2"]}, {"batch (hh.mm)": ["b1", "b2"]}],
"pattern": ["subject", ["b1", "d1"]],
"d2": 4,
"b2": 2,
},
Expand All @@ -132,3 +142,14 @@ def test_reduce(self, data, args):
out_da = reduce(data.rename({"batch": "batch (hh.mm)"}), reduction="mean", **kwargs)
assert out_da.shape == shape
assert list(out_da.dims) == dims

def test_reduce_tuple_dim(self, data):
out_da = reduce(
data.rename(drug=("drug dose", "mg")),
reduction="mean",
pattern_in=[{("drug dose", "mg"): [("d", 1), ("d", 2)]}],
pattern=["subject", ("d", 2), "batch"],
dim_lengths={("d", 1): 2, ("d", 2): 4},
)
assert out_da.shape == (6, 4, 4)
assert list(out_da.dims) == ["subject", ("d", 2), "batch"]
4 changes: 2 additions & 2 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def test_svd(self, matrices):
u_da, s_da, vh_da = svd(matrices, dims=("dim", "dim2"), out_append="_bis")
s_full = xr.zeros_like(matrices)
idx = xr.DataArray(np.arange(len(matrices["dim"])), dims="pointwise_sel")
s_full.loc[{"dim": idx, "dim2": idx}] = s_da
s_full.loc[{"dim": idx, "dim2": idx}] = s_da.rename(dim="pointwise_sel")
compare = matmul(
matmul(u_da, s_full, dims=[["dim", "dim_bis"], ["dim", "dim2"]]),
vh_da,
Expand All @@ -259,7 +259,7 @@ def test_svd_non_square(self, matrices):
s_full = xr.zeros_like(matrices)
# experiment is shorter than dim
idx = xr.DataArray(np.arange(len(matrices["experiment"])), dims="pointwise_sel")
s_full.loc[{"experiment": idx, "dim": idx}] = s_da.transpose("batch", "experiment", "dim2")
s_full.loc[{"experiment": idx, "dim": idx}] = s_da.rename(experiment="pointwise_sel")
compare = matmul(
matmul(u_da, s_full, dims=[["experiment", "experiment_bis"], ["experiment", "dim"]]),
vh_da,
Expand Down

0 comments on commit d91642d

Please sign in to comment.