From d91642de7ad1d58a23b4c3b97e1c9e91395f5ee4 Mon Sep 17 00:00:00 2001 From: Oriol Abril-Pla Date: Thu, 21 Dec 2023 20:58:47 +0100 Subject: [PATCH] add support for hashable dims (#56) --- docs/source/background/einops.md | 9 ++- src/xarray_einstats/einops.py | 94 ++++++++++++++++---------------- src/xarray_einstats/numba.py | 2 +- tests/test_accessors.py | 2 +- tests/test_einops.py | 31 +++++++++-- tests/test_linalg.py | 4 +- 6 files changed, 80 insertions(+), 62 deletions(-) diff --git a/docs/source/background/einops.md b/docs/source/background/einops.md index 17eab74..f726c8f 100644 --- a/docs/source/background/einops.md +++ b/docs/source/background/einops.md @@ -13,8 +13,7 @@ of the elements respectively: `->`, space as delimiter and parenthesis: side in the einops notation is only used to label the dimensions. In fact, 5/7 examples in https://einops.rocks/api/rearrange/ fall in this category. This is not necessary when working with xarray objects. -* In xarray dimension names can be any {term}`hashable `. `xarray-einstats` only - supports strings as dimension names, but the space can't be used as delimiter. +* In xarray dimension names can be any {term}`hashable `. * In xarray dimensions are labeled and the order doesn't matter. This might seem the same as the first reason but it is not. When splitting or stacking dimensions you need (and want) the names of both parent and children dimensions. @@ -25,8 +24,8 @@ of the elements respectively: `->`, space as delimiter and parenthesis: However, there are also many cases in which dimension names in xarray will be strings without any spaces nor parenthesis in them. So similarly to the option of -doing `da.stack(dim=("dim1", "dim2"))` which can't be used for all valid -dimension names but is generally easier to write and less error prone, +doing `da.stack(dim=["dim1", "dim2"])` which can't be used for all valid +dimension names but is generally easier to write and less error prone than, `xarray_einstats.einops` also provides two possible syntaxes. The guiding principle of the einops module is to take the input expressions @@ -37,7 +36,7 @@ labeled, we can take advantage of that during the translation process and thus support "partial" expressions that cover only the dimensions that will be modified. -Another important consideration is to take into account that _in xarray_, +Another important consideration is to take into account that _in xarray_ dimension order should not matter, hence the constraint of using dicts on the left side. Imposing this constraint also makes our job of filling in the "partial" expressions much easier. diff --git a/src/xarray_einstats/einops.py b/src/xarray_einstats/einops.py index 738fb3b..97f35da 100644 --- a/src/xarray_einstats/einops.py +++ b/src/xarray_einstats/einops.py @@ -14,6 +14,7 @@ """ import warnings +from collections.abc import Hashable import einops import xarray as xr @@ -61,7 +62,7 @@ def process_pattern_list(redims, handler, allow_dict=True, allow_list=True): allow_dict, allow_list : bool, optional Whether or not to allow lists or dicts as elements of ``redims``. When processing ``in_dims`` for example we need the names of - the variables to be decomposed so dicts are required and lists/tuples + the variables to be decomposed so dicts are required and lists are not accepted. Returns @@ -85,14 +86,14 @@ def process_pattern_list(redims, handler, allow_dict=True, allow_list=True): from xarray_einstats.einops import process_pattern_list, DimHandler handler = DimHandler() - process_pattern_list(["a", {"b": ("c", "d")}, ("e", "f", "g")], handler) + process_pattern_list(["a", {"b": ["c", "d"]}, ["e", "f", "g"]], handler) """ out = [] out_names = [] txt = [] for subitem in redims: - if isinstance(subitem, str): + if isinstance(subitem, Hashable): out.append(subitem) out_names.append(subitem) txt.append(handler.get_name(subitem)) @@ -103,8 +104,10 @@ def process_pattern_list(redims, handler, allow_dict=True, allow_list=True): f"found {len(subitem)}: {subitem.keys()}" ) key, values = list(subitem.items())[0] - if isinstance(values, str): - raise ValueError("Found values of type str in a pattern dict, use xarray.rename") + if isinstance(values, Hashable): + raise ValueError( + "Found values of hashable type in a pattern dict, use xarray.rename" + ) out.extend(values) out_names.append(key) txt.append(f"( {handler.get_names(values)} )") @@ -182,7 +185,7 @@ def translate_pattern(pattern): return dims -def _rearrange(da, out_dims, in_dims=None, **kwargs): +def _rearrange(da, out_dims, in_dims=None, dim_lengths=None): """Wrap `einops.rearrange `_. This is the function that actually interfaces with ``einops``. @@ -198,11 +201,14 @@ def _rearrange(da, out_dims, in_dims=None, **kwargs): See docstring of :func:`~xarray_einstats.einops.rearrange` in_dims : list of str or dict, optional See docstring of :func:`~xarray_einstats.einops.rearrange` - kwargs : dict, optional + dim_lengths : dict, optional kwargs with key equal to dimension names in ``out_dims`` (that is, strings or dict keys) are passed to einops.rearrange the rest of keys are passed to :func:`xarray.apply_ufunc` """ + if dim_lengths is None: + dim_lengths = {} + da_dims = da.dims handler = DimHandler() @@ -231,9 +237,9 @@ def _rearrange(da, out_dims, in_dims=None, **kwargs): {non_core_pattern} {handler.get_names(missing_out_dims)} {out_pattern}" axes_lengths = { - handler.rename_kwarg(k): v for k, v in kwargs.items() if k in out_names + out_dims + handler.rename_kwarg(k): v for k, v in dim_lengths.items() if k in out_names + out_dims } - kwargs = {k: v for k, v in kwargs.items() if k not in out_names + out_dims} + kwargs = {k: v for k, v in dim_lengths.items() if k not in out_names + out_dims} return xr.apply_ufunc( einops.rearrange, da, @@ -245,7 +251,7 @@ def _rearrange(da, out_dims, in_dims=None, **kwargs): ) -def rearrange(da, pattern, pattern_in=None, **kwargs): +def rearrange(da, pattern, pattern_in=None, dim_lengths=None, **dim_lengths_kwargs): """Expose `einops.rearrange `_ with an xarray-like API. It has two possible syntaxes which are independent and somewhat complementary. @@ -268,12 +274,12 @@ def rearrange(da, pattern, pattern_in=None, **kwargs): a default name. If `pattern` is not a string, then it must be a list where each of its elements - is one of: ``str``, ``list`` (to stack those dimensions and give them an - arbitrary name) or ``dict of {str: list}`` (to stack the dimensions indicated + is one of: :term:`python:hashable`, ``list`` (to stack those dimensions and + give them an arbitrary name) or ``dict`` (to stack the dimensions indicated as values of the dictionary and name the resulting dimensions with the key). - `pattern` is then interpreted as the output side of the einops pattern. See - TODO for more details. + `pattern` is then interpreted as the output side of the einops pattern. + See :ref:`about_einops` for more details. pattern_in : list of [str or dict], optional The input pattern for the dimensions. It can only be provided if `pattern` is a ``list``. Also, note this is only necessary if you want to split some dimensions. @@ -282,28 +288,22 @@ def rearrange(da, pattern, pattern_in=None, **kwargs): with the only difference that ``list`` elements are not allowed, the same way that ``(dim1 dim2)=dim`` is required on the left hand side when using string patterns. - kwargs : dict, optional - Passed to :func:`xarray_einstats.einops.rearrange` + dim_lengths, **dim_lengths_kwargs : dict, optional + If the keys are dimensions present in `pattern` they will be passed to + `einops.rearrange `_, otherwise, + they are passed to :func:`xarray.apply_ufunc`. Returns ------- xarray.DataArray - Notes - ----- - Unlike for general xarray objects, where dimension - names can be :term:`hashable ` here - dimension names are not recommended but required to be - strings for both cases. Future releases however might - support this when using lists as `pattern`, comment - on :issue:`50` if you are interested in the feature - or could help implement it. - - See Also -------- xarray_einstats.einops.reduce """ + if dim_lengths is None: + dim_lengths = {} + dim_lengths = {**dim_lengths, **dim_lengths_kwargs} if isinstance(pattern, str): if "->" in pattern: in_pattern, out_pattern = pattern.split("->") @@ -312,11 +312,11 @@ def rearrange(da, pattern, pattern_in=None, **kwargs): out_pattern = pattern in_dims = None out_dims = translate_pattern(out_pattern) - return _rearrange(da, out_dims=out_dims, in_dims=in_dims, **kwargs) - return _rearrange(da, out_dims=pattern, in_dims=pattern_in, **kwargs) + return _rearrange(da, out_dims=out_dims, in_dims=in_dims, dim_lengths=dim_lengths) + return _rearrange(da, out_dims=pattern, in_dims=pattern_in, dim_lengths=dim_lengths) -def _reduce(da, reduction, out_dims, in_dims=None, **kwargs): +def _reduce(da, reduction, out_dims, in_dims=None, dim_lengths=None): """Wrap `einops.reduce `_. This is the function that actually interfaces with ``einops``. @@ -338,11 +338,14 @@ def _reduce(da, reduction, out_dims, in_dims=None, **kwargs): in_dims : list of str or dict, optional The input pattern for the dimensions. This is only necessary if you want to split some dimensions. - kwargs : dict, optional + dim_lengths : dict, optional kwargs with key equal to dimension names in ``out_dims`` (that is, strings or dict keys) are passed to einops.rearrange the rest of keys are passed to :func:`xarray.apply_ufunc` """ + if dim_lengths is None: + dim_lengths = {} + da_dims = da.dims handler = DimHandler() @@ -361,8 +364,8 @@ def _reduce(da, reduction, out_dims, in_dims=None, **kwargs): pattern = f"{handler.get_names(missing_in_dims)} {in_pattern} -> {out_pattern}" all_dims = set(out_dims + out_names + in_names + in_dims) - axes_lengths = {handler.rename_kwarg(k): v for k, v in kwargs.items() if k in all_dims} - kwargs = {k: v for k, v in kwargs.items() if k not in all_dims} + axes_lengths = {handler.rename_kwarg(k): v for k, v in dim_lengths.items() if k in all_dims} + kwargs = {k: v for k, v in dim_lengths.items() if k not in all_dims} return xr.apply_ufunc( einops.reduce, da, @@ -375,7 +378,7 @@ def _reduce(da, reduction, out_dims, in_dims=None, **kwargs): ) -def reduce(da, pattern, reduction, pattern_in=None, **kwargs): +def reduce(da, pattern, reduction, pattern_in=None, dim_lengths=None, **dim_lengths_kwargs): """Expose `einops.reduce `_ with an xarray-like API. It has two possible syntaxes which are independent and somewhat complementary. @@ -412,27 +415,22 @@ def reduce(da, pattern, reduction, pattern_in=None, **kwargs): The syntax and interpretation is the same as the case when `pattern` is a list, with the only difference that ``list`` elements are not allowed, the same way that ``(dim1 dim2)=dim`` is required on the left hand side when using string - kwargs : dict, optional - Passed to :func:`xarray_einstats.einops.reduce` + dim_lengths, **dim_lengths_kwargs : dict, optional + If the keys are dimensions present in `pattern` they will be passed to + `einops.reduce `_, otherwise, + they are passed to :func:`xarray.apply_ufunc`. Returns ------- xarray.DataArray - Notes - ----- - Unlike for general xarray objects, where dimension - names can be :term:`hashable ` here - dimension names are not recommended but required to be - strings for both cases. Future releases however might - support this when using lists as `pattern`, comment - on :issue:`50` if you are interested in the feature - or could help implement it. - See Also -------- xarray_einstats.einops.rearrange """ + if dim_lengths is None: + dim_lengths = {} + dim_lengths = {**dim_lengths, **dim_lengths_kwargs} if isinstance(pattern, str): if "->" in pattern: in_pattern, out_pattern = pattern.split("->") @@ -441,8 +439,8 @@ def reduce(da, pattern, reduction, pattern_in=None, **kwargs): out_pattern = pattern in_dims = None out_dims = translate_pattern(out_pattern) - return _reduce(da, reduction, out_dims=out_dims, in_dims=in_dims, **kwargs) - return _reduce(da, reduction, out_dims=pattern, in_dims=pattern_in, **kwargs) + return _reduce(da, reduction, out_dims=out_dims, in_dims=in_dims, dim_lengths=dim_lengths) + return _reduce(da, reduction, out_dims=pattern, in_dims=pattern_in, dim_lengths=dim_lengths) def raw_reduce(*args, **kwargs): diff --git a/src/xarray_einstats/numba.py b/src/xarray_einstats/numba.py index 3c42604..b807f10 100644 --- a/src/xarray_einstats/numba.py +++ b/src/xarray_einstats/numba.py @@ -268,7 +268,7 @@ def ecdf(da, dims=None, *, npoints=None, **kwargs): dims = da.dims elif isinstance(dims, str): dims = [dims] - total_points = np.product([da.sizes[d] for d in dims]) + total_points = np.prod([da.sizes[d] for d in dims]) if npoints is None: npoints = min(total_points, 200) x = xr.DataArray(np.linspace(0, 1, npoints), dims=["quantile"]) diff --git a/tests/test_accessors.py b/tests/test_accessors.py index f5da4ae..2556760 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -88,7 +88,7 @@ def test_einops_accessor_rearrange(data): @pytest.mark.skipif(find_spec("einops") is None, reason="einops must be installed") def test_einops_accessor_reduce(data): - pattern_in = [{"batch (hh.mm)": ("d1", "d2")}] + pattern_in = [{"batch (hh.mm)": ["d1", "d2"]}] pattern = ["d1", "subject"] kwargs = {"d2": 2} input_data = data.rename({"batch": "batch (hh.mm)"}) diff --git a/tests/test_einops.py b/tests/test_einops.py index 049dcea..9de94c0 100644 --- a/tests/test_einops.py +++ b/tests/test_einops.py @@ -60,12 +60,12 @@ class TestRearrange: "args", ( ( - {"pattern": [{"dex": ("drug dose (mg)", "experiment")}]}, + {"pattern": [{"dex": ["drug dose (mg)", "experiment"]}]}, ((4, 6, 8 * 15), ["batch", "subject", "dex"]), ), ( { - "pattern_in": [{"drug dose (mg)": ("d1", "d2")}], + "pattern_in": [{"drug dose (mg)": ["d1", "d2"]}], "pattern": ["d1", "d2", "batch"], "d1": 2, "d2": 4, @@ -80,6 +80,16 @@ def test_rearrange(self, data, args): assert out_da.shape == shape assert list(out_da.dims) == dims + def test_rearrange_tuple_dim(self, data): + out_da = rearrange( + data.rename(drug=("drug dose", "mg")), + pattern_in=[{("drug dose", "mg"): [("d", 1), ("d", 2)]}], + pattern=[("d", 1), ("d", 2), "batch"], + dim_lengths={("d", 1): 2, ("d", 2): 4}, + ) + assert out_da.shape == (6, 15, 2, 4, 4) + assert list(out_da.dims) == ["subject", "experiment", ("d", 1), ("d", 2), "batch"] + class TestRawReduce: @pytest.mark.parametrize( @@ -110,7 +120,7 @@ class TestReduce: ), ( { - "pattern_in": [{"batch (hh.mm)": ("d1", "d2")}], + "pattern_in": [{"batch (hh.mm)": ["d1", "d2"]}], "pattern": ["d1", "subject"], "d2": 2, }, @@ -118,8 +128,8 @@ class TestReduce: ), ( { - "pattern_in": [{"drug": ("d1", "d2")}, {"batch (hh.mm)": ("b1", "b2")}], - "pattern": ["subject", ("b1", "d1")], + "pattern_in": [{"drug": ["d1", "d2"]}, {"batch (hh.mm)": ["b1", "b2"]}], + "pattern": ["subject", ["b1", "d1"]], "d2": 4, "b2": 2, }, @@ -132,3 +142,14 @@ def test_reduce(self, data, args): out_da = reduce(data.rename({"batch": "batch (hh.mm)"}), reduction="mean", **kwargs) assert out_da.shape == shape assert list(out_da.dims) == dims + + def test_reduce_tuple_dim(self, data): + out_da = reduce( + data.rename(drug=("drug dose", "mg")), + reduction="mean", + pattern_in=[{("drug dose", "mg"): [("d", 1), ("d", 2)]}], + pattern=["subject", ("d", 2), "batch"], + dim_lengths={("d", 1): 2, ("d", 2): 4}, + ) + assert out_da.shape == (6, 4, 4) + assert list(out_da.dims) == ["subject", ("d", 2), "batch"] diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 7a0b953..29e2a1f 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -246,7 +246,7 @@ def test_svd(self, matrices): u_da, s_da, vh_da = svd(matrices, dims=("dim", "dim2"), out_append="_bis") s_full = xr.zeros_like(matrices) idx = xr.DataArray(np.arange(len(matrices["dim"])), dims="pointwise_sel") - s_full.loc[{"dim": idx, "dim2": idx}] = s_da + s_full.loc[{"dim": idx, "dim2": idx}] = s_da.rename(dim="pointwise_sel") compare = matmul( matmul(u_da, s_full, dims=[["dim", "dim_bis"], ["dim", "dim2"]]), vh_da, @@ -259,7 +259,7 @@ def test_svd_non_square(self, matrices): s_full = xr.zeros_like(matrices) # experiment is shorter than dim idx = xr.DataArray(np.arange(len(matrices["experiment"])), dims="pointwise_sel") - s_full.loc[{"experiment": idx, "dim": idx}] = s_da.transpose("batch", "experiment", "dim2") + s_full.loc[{"experiment": idx, "dim": idx}] = s_da.rename(experiment="pointwise_sel") compare = matmul( matmul(u_da, s_full, dims=[["experiment", "experiment_bis"], ["experiment", "dim"]]), vh_da,