Skip to content

Commit

Permalink
Merge branch 'main' into fmri2
Browse files Browse the repository at this point in the history
  • Loading branch information
JasperVanDenBosch committed Sep 15, 2023
2 parents 0201560 + ddd7f67 commit 366a621
Show file tree
Hide file tree
Showing 23 changed files with 1,404 additions and 94 deletions.
844 changes: 844 additions & 0 deletions demos/demo_eeg.ipynb

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Binary file added demos/demo_eeg_data/annotate.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demos/demo_eeg_data/shared0140_nsd11797.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demos/demo_eeg_data/shared0936_nsd67830.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demos/demo_eeg_data/shared0944_nsd68742.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 0 additions & 3 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,3 @@ As an example, we assume to have measured data from 10 trials, each with six EEG
Beyond the functions to manipulate the data provided by ``rsatoolbox.data.Dataset``, the ``rsatoolbox.data.TemporalDataset`` class provides the following functions:
``split_time``, ``subset_time``, ``bin_time``, ``convert_to_dataset``.


TODO: TIPS TO IMPORT FMRI / EEG ETC DATA
6 changes: 6 additions & 0 deletions docs/source/demo_eeg.nblink
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"path": "../../demos/demo_eeg.ipynb",
"extra-media": [
"../../demos/demo_eeg_data"
]
}
1 change: 1 addition & 0 deletions docs/source/demos.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ Demos
demo_unbalanced
demo_temporal
demo_meg_mne
demo_eeg
demo_searchlight
rescale_partials
7 changes: 7 additions & 0 deletions docs/source/rsatoolbox.data.ops.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
rsatoolbox.data.ops module
==========================

.. automodule:: rsatoolbox.data.ops
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/rsatoolbox.io.mne.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
rsatoolbox.io.mne module
========================

.. automodule:: rsatoolbox.io.mne
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/source/rsatoolbox.io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Submodules

rsatoolbox.io.hdf5
rsatoolbox.io.meadows
rsatoolbox.io.mne
rsatoolbox.io.pandas
rsatoolbox.io.pkl

Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ matplotlib
h5py
tqdm
joblib
importlib-metadata>=6.0.0; python_version < "3.8"
importlib-metadata>=6.0.0; python_version < "3.8"
typing-extensions; python_version < "3.8"
1 change: 0 additions & 1 deletion src/rsatoolbox/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from .dataset import TemporalDataset
from .dataset import load_dataset
from .dataset import dataset_from_dict
from .dataset import merge_subsets
from .computations import average_dataset
from .computations import average_dataset_by
from .noise import cov_from_residuals
Expand Down
104 changes: 71 additions & 33 deletions src/rsatoolbox/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@

from __future__ import annotations
from typing import List, Optional
from warnings import warn
from copy import deepcopy
import numpy as np
from pandas import DataFrame
from rsatoolbox.data.ops import merge_datasets
from rsatoolbox.util.data_utils import get_unique_unsorted
from rsatoolbox.util.data_utils import get_unique_inverse
from rsatoolbox.util.descriptor_utils import check_descriptor_length_error
from rsatoolbox.util.descriptor_utils import subset_descriptor
from rsatoolbox.util.descriptor_utils import num_index
from rsatoolbox.util.descriptor_utils import format_descriptor
from rsatoolbox.util.descriptor_utils import parse_input_descriptor
from rsatoolbox.util.descriptor_utils import append_obs_descriptors
from rsatoolbox.util.descriptor_utils import desc_eq
from rsatoolbox.io.hdf5 import read_dict_hdf5
from rsatoolbox.io.pkl import read_dict_pkl
Expand Down Expand Up @@ -244,8 +245,8 @@ def odd_even_split(self, obs_desc):
ds_part = self.split_obs(obs_desc)
odd_list = ds_part[0::2]
even_list = ds_part[1::2]
odd_split = merge_subsets(odd_list)
even_split = merge_subsets(even_list)
odd_split = merge_datasets(odd_list)
even_split = merge_datasets(even_list)
return odd_split, even_split

def nested_odd_even_split(self, l1_obs_desc, l2_obs_desc):
Expand Down Expand Up @@ -282,8 +283,8 @@ def nested_odd_even_split(self, l1_obs_desc, l2_obs_desc):
odd_split, even_split = partition.odd_even_split(l2_obs_desc)
odd_list.append(odd_split)
even_list.append(even_split)
odd_split = merge_subsets(odd_list)
even_split = merge_subsets(even_list)
odd_split = merge_datasets(odd_list)
even_split = merge_datasets(even_list)
return odd_split, even_split

@staticmethod
Expand Down Expand Up @@ -685,17 +686,54 @@ def subset_time(self, by, t_from, t_to):
time_descriptors=time_descriptors)
return dataset

def convert_to_dataset(self, by):
""" converts to Dataset long format.
time dimension is absorbed into observation dimension
def sort_by(self, by):
    """Sort the dataset in place by a given observation descriptor.

    Rows of `measurements` and all entries in `obs_descriptors` are
    reordered together, so they stay aligned after sorting.

    Args:
        by(String): the obs_descriptor by which the dataset shall be sorted

    Returns:
        ---
    """
    desc = self.obs_descriptors[by]
    # stable sort: observations with equal descriptor values keep their
    # original relative order (the default quicksort gives no such
    # guarantee and may scramble trials within a condition)
    order = np.argsort(desc, kind='stable')
    self.measurements = self.measurements[order]
    self.obs_descriptors = subset_descriptor(self.obs_descriptors, order)

def time_as_channels(self) -> Dataset:
    """Converts this to a standard Dataset "long format",
    where timepoints are represented as additional channels.

    The 3D measurements of shape (n_obs, n_channels, n_timepoints) are
    flattened to (n_obs, n_channels * n_timepoints), with the timepoint
    index varying fastest within each channel.

    Returns:
        Dataset: dataset with n_channels * n_timepoints channels
    """
    n_obs, n_chans, n_tps = self.measurements.shape
    old_chn_des = self.channel_descriptors
    # each channel descriptor value covers n_tps consecutive new columns
    chn_des = {k: np.repeat(v, n_tps) for (k, v) in old_chn_des.items()}
    # time descriptors cycle once per channel, matching the reshape order
    for k, v in self.time_descriptors.items():
        chn_des[k] = np.tile(v, n_chans)
    return Dataset(
        measurements=self.measurements.reshape(n_obs, -1),
        descriptors=deepcopy(self.descriptors),
        obs_descriptors=deepcopy(self.obs_descriptors),
        channel_descriptors=chn_des
    )

def time_as_observations(self, by='time') -> Dataset:
"""Converts this to a standard Dataset "long format",
where timepoints are represented as additional observations.
Args:
by (str): the descriptor which indicates the time dimension in
the time_descriptor.
Returns:
Dataset
"""
time = get_unique_unsorted(self.time_descriptors[by])

Expand Down Expand Up @@ -735,6 +773,24 @@ def convert_to_dataset(self, by):
channel_descriptors=channel_descriptors)
return dataset

def convert_to_dataset(self, by):
    """Deprecated alias for `TemporalDataset.time_as_observations()`.

    Absorbs the time dimension into the observation dimension,
    producing a standard Dataset in long format.

    Args:
        by(String): the descriptor which indicates the time dimension in
            the time_descriptor

    Returns:
        Dataset
    """
    msg = ('Deprecated: [TemporalDataset.convert_to_dataset()]. Replace by '
           '[TemporalDataset.time_as_observations()]')
    warn(msg, DeprecationWarning)
    return self.time_as_observations(by)

def to_dict(self):
""" Generates a dictionary which contains the information to
recreate the TemporalDataset object. Used for saving to disc
Expand Down Expand Up @@ -818,6 +874,8 @@ def merge_subsets(dataset_list):
(e.g., as generated by the subset_* methods). Assumes that descriptors,
channel descriptors and number of channels per observation match.
Deprecated. Use `rsatoolbox.data.ops.merge_datasets` instead.
Args:
dataset_list (list):
List containing rsatoolbox datasets
Expand All @@ -826,26 +884,6 @@ def merge_subsets(dataset_list):
merged_dataset (Dataset):
rsatoolbox dataset created from all datasets in dataset_list
"""
assert isinstance(dataset_list, list), "Provided object is not a list."
assert "Dataset" in str(type(dataset_list[0])), \
"Provided list does not only contain Dataset objects."
baseline_ds = dataset_list[0]
descriptors = baseline_ds.descriptors.copy()
channel_descriptors = baseline_ds.channel_descriptors.copy()
measurements = baseline_ds.measurements.copy()
obs_descriptors = baseline_ds.obs_descriptors.copy()

for ds in dataset_list[1:]:
assert "Dataset" in str(type(ds)), \
"Provided list does not only contain Dataset objects."
assert descriptors == ds.descriptors.copy(), \
"Dataset descriptors do not match."
measurements = np.append(measurements, ds.measurements, axis=0)
obs_descriptors = append_obs_descriptors(obs_descriptors,
ds.obs_descriptors.copy())

merged_dataset = Dataset(measurements,
descriptors=descriptors,
obs_descriptors=obs_descriptors,
channel_descriptors=channel_descriptors)
return merged_dataset
warn('Deprecated: [rsatoolbox.data.dataset.merge_subsets()]. Replace by '
'[rsatoolbox.data.ops.merge_datasets()]', DeprecationWarning)
return merge_datasets(dataset_list)
97 changes: 97 additions & 0 deletions src/rsatoolbox/data/ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""Operations on multiple Datasets
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Union, List, Set, overload
from copy import deepcopy
from warnings import warn
try:
from typing import Literal # pylint: disable=ungrouped-imports
except ImportError:
from typing_extensions import Literal
from numpy import concatenate, repeat
import rsatoolbox
if TYPE_CHECKING:
DESC_LEVEL = Union[Literal['obs'], Literal['set']]
from rsatoolbox.data.dataset import Dataset, TemporalDataset


@overload
def merge_datasets(sets: List[TemporalDataset]) -> TemporalDataset:
    ...


@overload
def merge_datasets(sets: List[Dataset]) -> Dataset:
    ...


def merge_datasets(sets: Union[List[Dataset], List[TemporalDataset]]
                   ) -> Union[Dataset, TemporalDataset]:
    """Concatenate measurements to create one Dataset of the same type

    Only descriptors that exist on all subsets are assigned to the merged
    dataset.
    Dataset-level `descriptors` that are identical across subsets will be
    passed on, those that vary will become `obs_descriptors`.
    Channel and Time descriptors must be identical across subsets.

    Args:
        sets (Union[List[Dataset], List[TemporalDataset]]): List of Dataset
            or TemporalDataset objects. Must all be the same type.

    Returns:
        Union[Dataset, TemporalDataset]: The new dataset combining measurements
            and descriptors from the given subset datasets.
    """
    if not sets:
        warn('[merge_datasets] Received empty list, returning empty Dataset')
        return rsatoolbox.data.dataset.Dataset(measurements=[])
    if len({type(s) for s in sets}) > 1:
        raise ValueError('All datasets must be of the same type')
    first = sets[0]
    # concatenate pre-allocates, so this is a performant way to combine
    measurements = concatenate([ds.measurements for ds in sets], axis=0)
    # keep only obs descriptors that every subset provides
    obs_descriptors = {
        key: concatenate([ds.obs_descriptors[key] for ds in sets])
        for key in _shared_descriptors(sets, 'obs')
    }
    set_descriptors = dict()
    for key in _shared_descriptors(sets):
        if len({s.descriptors[key] for s in sets}) == 1:
            # same value everywhere: stays a dataset-level descriptor
            set_descriptors[key] = first.descriptors[key]
        else:
            # value varies across subsets: expand it per observation
            obs_descriptors[key] = repeat(
                [ds.descriptors[key] for ds in sets],
                [ds.n_obs for ds in sets]
            )
    # check TemporalDataset first, since it inherits from Dataset
    if isinstance(first, rsatoolbox.data.dataset.TemporalDataset):
        return rsatoolbox.data.dataset.TemporalDataset(
            measurements=measurements,
            descriptors=set_descriptors,
            obs_descriptors=obs_descriptors,
            channel_descriptors=deepcopy(first.channel_descriptors),
            time_descriptors=deepcopy(first.time_descriptors),
        )
    if isinstance(first, rsatoolbox.data.dataset.Dataset):
        return rsatoolbox.data.dataset.Dataset(
            measurements=measurements,
            descriptors=set_descriptors,
            obs_descriptors=obs_descriptors,
            channel_descriptors=deepcopy(first.channel_descriptors)
        )
    raise ValueError('Unsupported Dataset type')


def _shared_descriptors(
datasets: Union[List[Dataset], List[TemporalDataset]],
level: DESC_LEVEL = 'set') -> Set[str]:
"""Find descriptors that all datasets have in common
"""
if level == 'set':
each_keys = [set(d.descriptors.keys()) for d in datasets]
else:
each_keys = [set(d.obs_descriptors.keys()) for d in datasets]
return set.intersection(*each_keys)
62 changes: 62 additions & 0 deletions src/rsatoolbox/io/mne.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations
from typing import Optional, Dict, TYPE_CHECKING
from os.path import basename
from rsatoolbox.data.dataset import TemporalDataset
if TYPE_CHECKING:
from mne.epochs import EpochsFIF


def read_epochs(fpath: str) -> TemporalDataset:
    """Create TemporalDataset from epochs in mne FIF file

    Args:
        fpath (str): Full path to epochs file

    Returns:
        TemporalDataset: dataset with epochs
    """
    # imported lazily so that mne remains an optional dependency
    # pylint: disable-next=import-outside-toplevel
    from mne import read_epochs as mne_read_epochs
    fname = basename(fpath)
    epochs = mne_read_epochs(fpath, preload=True, verbose='error')
    # record the filename plus any BIDS entities it encodes
    descriptors = dict(filename=fname, **descriptors_from_bids_filename(fname))
    return dataset_from_epochs(epochs, descriptors)


def dataset_from_epochs(
            epochs: EpochsFIF,
            descriptors: Optional[Dict] = None
        ) -> TemporalDataset:
    """Create TemporalDataset from MNE epochs object

    Args:
        epochs (EpochsFIF): mne epochs object with data loaded
        descriptors (Optional[Dict]): dataset-level descriptors to attach,
            defaults to an empty dict

    Returns:
        TemporalDataset: dataset with epochs
    """
    descriptors = descriptors or dict()
    return TemporalDataset(
        measurements=epochs.get_data(),
        descriptors=descriptors,
        # third column of mne's events array holds the event code
        obs_descriptors=dict(event=epochs.events[:, 2]),
        channel_descriptors=dict(name=epochs.ch_names),
        time_descriptors=dict(time=epochs.times)
    )


def descriptors_from_bids_filename(fname: str) -> Dict[str, str]:
    """parse a filename for BIDS-style entities

    Looks for underscore-separated `key-value` segments for the
    keys `sub`, `run` and `task`; if a key occurs more than once,
    the last occurrence wins.

    Args:
        fname (str): filename

    Returns:
        Dict[str, str]: sub, run or task descriptors
    """
    entities = dict()
    segments = fname.split('_')
    for key in ('sub', 'run', 'task'):
        prefix = key + '-'
        for segment in segments:
            if segment.startswith(prefix):
                entities[key] = segment[len(prefix):]
    return entities
Loading

0 comments on commit 366a621

Please sign in to comment.