From 8db0b92d4f8543185d8971af6360f75cdb30918b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Heiko=20Sch=C3=BCtt?= Date: Wed, 28 Jun 2023 15:31:51 -0400 Subject: [PATCH 1/2] added possibility to run calc_unbalanced for rdm movies --- src/rsatoolbox/rdm/calc.py | 78 +++++++-------------------- src/rsatoolbox/rdm/calc_unbalanced.py | 2 +- src/rsatoolbox/util/build_rdm.py | 65 ++++++++++++++++++++++ 3 files changed, 86 insertions(+), 59 deletions(-) create mode 100644 src/rsatoolbox/util/build_rdm.py diff --git a/src/rsatoolbox/rdm/calc.py b/src/rsatoolbox/rdm/calc.py index 03f8a992..dfd871d5 100644 --- a/src/rsatoolbox/rdm/calc.py +++ b/src/rsatoolbox/rdm/calc.py @@ -9,11 +9,13 @@ from copy import deepcopy from typing import TYPE_CHECKING, Optional, Tuple import numpy as np -from rsatoolbox.rdm.rdms import RDMs from rsatoolbox.rdm.rdms import concat +from rsatoolbox.rdm.calc_unbalanced import calc_rdm_unbalanced from rsatoolbox.rdm.combine import from_partials from rsatoolbox.data import average_dataset_by from rsatoolbox.util.rdm_utils import _extract_triu_ +from rsatoolbox.util.build_rdm import _build_rdms + if TYPE_CHECKING: from rsatoolbox.data.base import DatasetBase from numpy.typing import NDArray @@ -101,7 +103,7 @@ def calc_rdm(dataset, method='euclidean', descriptor=None, noise=None, def calc_rdm_movie( dataset, method='euclidean', descriptor=None, noise=None, cv_descriptor=None, prior_lambda=1, prior_weight=0.1, - time_descriptor='time', bins=None): + time_descriptor='time', bins=None, unbalanced=False): """ calculates an RDM movie from an input TemporalDataset @@ -121,6 +123,8 @@ def calc_rdm_movie( dimension in dataset.time_descriptors. Defaults to 'time'. bins (array-like): list of bins, with bins[i] containing the vector of time-points for the i-th bin. Defaults to no binning. + unbalanced (bool): if set to True use calc_rdm_unbalanced, + else and by default use calc_rdm Returns: rsatoolbox.rdm.rdms.RDMs: RDMs object with RDM movie @@ -156,11 +160,20 @@ def calc_rdm_movie( rdms = [] for dat in splited_data: dat_single = dat.convert_to_dataset(time_descriptor) - rdms.append(calc_rdm(dat_single, method=method, - descriptor=descriptor, noise=noise, - cv_descriptor=cv_descriptor, - prior_lambda=prior_lambda, - prior_weight=prior_weight)) + if unbalanced: + rdms.append(calc_rdm_unbalanced( + dat_single, method=method, + descriptor=descriptor, noise=noise, + cv_descriptor=cv_descriptor, + prior_lambda=prior_lambda, + prior_weight=prior_weight)) + else: + rdms.append(calc_rdm( + dat_single, method=method, + descriptor=descriptor, noise=noise, + cv_descriptor=cv_descriptor, + prior_lambda=prior_lambda, + prior_weight=prior_weight)) rdm = concat(rdms) rdm.rdm_descriptors[time_descriptor] = time @@ -487,54 +500,3 @@ def _check_noise(noise, n_channel): else: raise ValueError('noise(s) must have shape n_channel x n_channel') return noise - - -def _build_rdms( - utv: NDArray, - ds: DatasetBase, - method: str, - obs_desc_name: str | None, - obs_desc_vals: Optional[NDArray] = None, - cv: Optional[NDArray] = None, - noise: Optional[NDArray] = None - ) -> RDMs: - rdms = RDMs( - dissimilarities=np.array([utv]), - dissimilarity_measure=method, - rdm_descriptors=deepcopy(ds.descriptors) - ) - if (obs_desc_vals is None) and (obs_desc_name is not None): - # obtain the unique values in the target obs descriptor - _, obs_desc_vals, _ = average_dataset_by(ds, obs_desc_name) - - if _averaging_occurred(ds, obs_desc_name, obs_desc_vals): - orig_obs_desc_vals = np.asarray(ds.obs_descriptors[obs_desc_name]) - for dname, dvals in ds.obs_descriptors.items(): - dvals = np.asarray(dvals) - avg_dvals = np.full_like(obs_desc_vals, np.nan, dtype=dvals.dtype) - for i, v in enumerate(obs_desc_vals): - subset = dvals[orig_obs_desc_vals == v] - if len(set(subset)) > 1: - break - avg_dvals[i] = subset[0] - else: - rdms.pattern_descriptors[dname] = avg_dvals - else: - rdms.pattern_descriptors = deepcopy(ds.obs_descriptors) - # Additional rdm_descriptors - if noise is not None: - rdms.descriptors['noise'] = noise - if cv is not None: - rdms.descriptors['cv_descriptor'] = cv - return rdms - - -def _averaging_occurred( - ds: DatasetBase, - obs_desc_name: str | None, - obs_desc_vals: NDArray | None - ) -> bool: - if obs_desc_name is None: - return False - orig_obs_desc_vals = ds.obs_descriptors[obs_desc_name] - return len(obs_desc_vals) != len(orig_obs_desc_vals) diff --git a/src/rsatoolbox/rdm/calc_unbalanced.py b/src/rsatoolbox/rdm/calc_unbalanced.py index 4e11d98d..8c8112c5 100755 --- a/src/rsatoolbox/rdm/calc_unbalanced.py +++ b/src/rsatoolbox/rdm/calc_unbalanced.py @@ -14,9 +14,9 @@ import numpy as np from rsatoolbox.rdm.rdms import RDMs from rsatoolbox.rdm.rdms import concat -from rsatoolbox.rdm.calc import _build_rdms from rsatoolbox.util.data_utils import get_unique_inverse from rsatoolbox.util.matrix import row_col_indicator_rdm +from rsatoolbox.util.build_rdm import _build_rdms from rsatoolbox.cengine.similarity import calc_one, calc if TYPE_CHECKING: from rsatoolbox.data.base import DatasetBase diff --git a/src/rsatoolbox/util/build_rdm.py b/src/rsatoolbox/util/build_rdm.py new file mode 100644 index 00000000..e4d7e1e4 --- /dev/null +++ b/src/rsatoolbox/util/build_rdm.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +"""helper methods to create RDMs at the end of calculations""" + +from __future__ import annotations +from typing import TYPE_CHECKING, Optional +import numpy as np +from copy import deepcopy +from rsatoolbox.rdm.rdms import RDMs +from rsatoolbox.data import average_dataset_by + +if TYPE_CHECKING: + from rsatoolbox.data.base import DatasetBase + from numpy.typing import NDArray + + +def _build_rdms( + utv: NDArray, + ds: DatasetBase, + method: str, + obs_desc_name: str | None, + obs_desc_vals: Optional[NDArray] = None, + cv: Optional[NDArray] = None, + noise: Optional[NDArray] = None + ) -> RDMs: + rdms = RDMs( + dissimilarities=np.array([utv]), + dissimilarity_measure=method, + rdm_descriptors=deepcopy(ds.descriptors) + ) + if (obs_desc_vals is None) and (obs_desc_name is not None): + # obtain the unique values in the target obs descriptor + _, obs_desc_vals, _ = average_dataset_by(ds, obs_desc_name) + + if _averaging_occurred(ds, obs_desc_name, obs_desc_vals): + orig_obs_desc_vals = np.asarray(ds.obs_descriptors[obs_desc_name]) + for dname, dvals in ds.obs_descriptors.items(): + dvals = np.asarray(dvals) + avg_dvals = np.full_like(obs_desc_vals, np.nan, dtype=dvals.dtype) + for i, v in enumerate(obs_desc_vals): + subset = dvals[orig_obs_desc_vals == v] + if len(set(subset)) > 1: + break + avg_dvals[i] = subset[0] + else: + rdms.pattern_descriptors[dname] = avg_dvals + else: + rdms.pattern_descriptors = deepcopy(ds.obs_descriptors) + # Additional rdm_descriptors + if noise is not None: + rdms.descriptors['noise'] = noise + if cv is not None: + rdms.descriptors['cv_descriptor'] = cv + return rdms + + +def _averaging_occurred( + ds: DatasetBase, + obs_desc_name: str | None, + obs_desc_vals: NDArray | None + ) -> bool: + if obs_desc_name is None: + return False + orig_obs_desc_vals = ds.obs_descriptors[obs_desc_name] + return len(obs_desc_vals) != len(orig_obs_desc_vals) From 15ae7a948f237afc4cd4786e8b39456140450306 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Heiko=20Sch=C3=BCtt?= Date: Wed, 20 Sep 2023 18:29:05 +0200 Subject: [PATCH 2/2] Import order build_rdm.py --- src/rsatoolbox/util/build_rdm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rsatoolbox/util/build_rdm.py b/src/rsatoolbox/util/build_rdm.py index e4d7e1e4..e09c21ac 100644 --- a/src/rsatoolbox/util/build_rdm.py +++ b/src/rsatoolbox/util/build_rdm.py @@ -4,8 +4,8 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional -import numpy as np from copy import deepcopy +import numpy as np from rsatoolbox.rdm.rdms import RDMs from rsatoolbox.data import average_dataset_by