Skip to content

Commit

Permalink
Merge pull request #298 from rsagroup/rdms-to-pandas
Browse files Browse the repository at this point in the history
RDMs to Pandas DataFrame
  • Loading branch information
JasperVanDenBosch authored Feb 26, 2023
2 parents 8182ac0 + d0b35e3 commit 2361bc5
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 10 deletions.
42 changes: 42 additions & 0 deletions src/rsatoolbox/io/pandas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Conversions from rsatoolbox classes to pandas table objects
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from pandas import DataFrame
import numpy
from numpy import asarray
if TYPE_CHECKING:
from rsatoolbox.rdm.rdms import RDMs


def rdms_to_df(rdms: RDMs) -> DataFrame:
    """Build a long-form DataFrame from an RDMs object

    Every row is one dissimilarity value (one pattern pair of one RDM).
    The table has a column for:

    - dissimilarity
    - each rdm descriptor
    - two for each pattern descriptor, suffixed by _1 and _2 respectively

    Descriptors named `index` are stored as `rdm_index` or
    `pattern_index_1` / `pattern_index_2`.
    Multiple RDMs are stacked row-wise.

    See also the `RDMs.to_df()` method which calls this function

    Args:
        rdms (RDMs): the object to convert

    Returns:
        DataFrame: long-form pandas DataFrame with
            dissimilarities and descriptors.
    """
    n_rdms, n_pairs = rdms.dissimilarities.shape
    columns = {'dissimilarity': rdms.dissimilarities.ravel()}

    # one value per RDM, repeated for every pattern pair of that RDM
    for name, values in rdms.rdm_descriptors.items():
        # 'index' has special meaning in a DataFrame, so use another label
        label = 'rdm_index' if name == 'index' else name
        columns[label] = numpy.repeat(values, n_pairs)

    # one value per pattern, looked up for both members of each pair
    for name, values in rdms.pattern_descriptors.items():
        first, second = numpy.triu_indices(len(values), 1)
        # 'index' has special meaning in a DataFrame, so use another label
        label = 'pattern_index' if name == 'index' else name
        per_pattern = asarray(values)
        # the pair layout is identical for every RDM, so tile it n_rdms times
        columns[f'{label}_1'] = numpy.tile(per_pattern[first], n_rdms)
        columns[f'{label}_2'] = numpy.tile(per_pattern[second], n_rdms)

    return DataFrame(columns)
11 changes: 11 additions & 0 deletions src/rsatoolbox/rdm/rdms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from copy import deepcopy
from collections.abc import Iterable
import numpy as np
from rsatoolbox.io.pandas import rdms_to_df
from rsatoolbox.rdm.combine import _mean
from rsatoolbox.util.rdm_utils import batch_to_vectors
from rsatoolbox.util.rdm_utils import batch_to_matrices
Expand Down Expand Up @@ -400,6 +401,16 @@ def to_dict(self):
rdm_dict['dissimilarity_measure'] = self.dissimilarity_measure
return rdm_dict

def to_df(self):
    """Convert this RDMs object to a new long-form pandas DataFrame

    Delegates to `rsatoolbox.io.pandas.rdms_to_df`; see that function
    for details on the resulting table.

    Returns:
        pandas.DataFrame: The DataFrame for this RDMs object
    """
    frame = rdms_to_df(self)
    return frame

def reorder(self, new_order):
"""Reorder the patterns according to the index in new_order
Expand Down
13 changes: 7 additions & 6 deletions src/rsatoolbox/rdm/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
# -*- coding: utf-8 -*-
""" transforms, which can be applied to RDMs
"""

from __future__ import annotations
from copy import deepcopy
import numpy as np
from scipy.stats import rankdata
from .rdms import RDMs


def rank_transform(rdms, method='average'):
def rank_transform(rdms: RDMs, method='average'):
""" applies a rank_transform and generates a new RDMs object
This assigns a rank to each dissimilarity estimate in the RDM,
deals with rank ties and saves ranks as new dissimilarity estimates.
Expand All @@ -30,9 +30,9 @@ def rank_transform(rdms, method='average'):
dissimilarities = rdms.get_vectors()
dissimilarities = np.array([rankdata(dissimilarities[i], method=method)
for i in range(rdms.n_rdm)])
measure = rdms.dissimilarity_measure
if not measure[-7:] == '(ranks)':
measure = measure + ' (ranks)'
measure = rdms.dissimilarity_measure or ''
if '(ranks)' not in measure:
measure = (measure + ' (ranks)').strip()
rdms_new = RDMs(dissimilarities,
dissimilarity_measure=measure,
descriptors=deepcopy(rdms.descriptors),
Expand Down Expand Up @@ -103,8 +103,9 @@ def transform(rdms, fun):
"""
dissimilarities = rdms.get_vectors()
dissimilarities = fun(dissimilarities)
meas = 'transformed ' + rdms.dissimilarity_measure
rdms_new = RDMs(dissimilarities,
dissimilarity_measure='transformed ' + rdms.dissimilarity_measure,
dissimilarity_measure=meas,
descriptors=deepcopy(rdms.descriptors),
rdm_descriptors=deepcopy(rdms.rdm_descriptors),
pattern_descriptors=deepcopy(rdms.pattern_descriptors))
Expand Down
16 changes: 12 additions & 4 deletions tests/test_rdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,12 @@ def test_rank_transform(self):
self.assertEqual(rank_rdm.n_cond, rdms.n_cond)
self.assertEqual(rank_rdm.dissimilarity_measure, 'Euclidean (ranks)')

def test_rank_transform_unknown_measure(self):
    """Rank-transforming RDMs built without a measure labels them '(ranks)'
    """
    from rsatoolbox.rdm import rank_transform
    # no dissimilarity_measure is passed when constructing the RDMs
    unlabeled = rsr.RDMs(dissimilarities=np.zeros((8, 10)))
    ranked = rank_transform(unlabeled)
    self.assertEqual(ranked.dissimilarity_measure, '(ranks)')

def test_sqrt_transform(self):
from rsatoolbox.rdm import sqrt_transform
dis = np.zeros((8, 10))
Expand Down Expand Up @@ -463,17 +469,19 @@ def test_copy(self):
)
)
copy = orig.copy()
## We don't want a reference:
# We don't want a reference:
self.assertIsNot(copy, orig)
self.assertIsNot(copy.dissimilarities, orig.dissimilarities)
self.assertIsNot(
copy.pattern_descriptors.get('order'),
orig.pattern_descriptors.get('order')
)
## But check that attributes are equal
# But check that attributes are equal
assert_array_equal(copy.dissimilarities, orig.dissimilarities)
self.assertEqual(copy.dissimilarity_measure,
orig.dissimilarity_measure)
self.assertEqual(
copy.dissimilarity_measure,
orig.dissimilarity_measure
)
self.assertEqual(copy.descriptors, orig.descriptors)
self.assertEqual(copy.rdm_descriptors, orig.rdm_descriptors)
assert_array_equal(
Expand Down
40 changes: 40 additions & 0 deletions tests/test_rdms_pandas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations
from unittest import TestCase
from typing import TYPE_CHECKING, Union, List
from numpy.testing import assert_array_equal
import numpy
from pandas import Series, DataFrame
if TYPE_CHECKING:
from numpy.typing import NDArray


class RdmsToPandasTests(TestCase):
    """Tests for the conversion of RDMs objects to pandas DataFrames"""

    def assertValuesEqual(self,
                          actual: Series,
                          expected: Union[NDArray, List]):
        """Assert that the values of a Series match the expected sequence"""
        observed = numpy.asarray(actual.values)
        assert_array_equal(observed, expected)

    def test_to_df(self):
        """Convert an RDMs object to a pandas DataFrame

        Default is long form; multiple rdms are stacked row-wise.
        """
        from rsatoolbox.rdm.rdms import RDMs
        dissimilarities = numpy.random.rand(2, 6)
        rdms = RDMs(
            dissimilarities,
            rdm_descriptors={'xy': list('xy')},
            pattern_descriptors={'abcd': numpy.asarray(list('abcd'))},
        )
        df = rdms.to_df()
        self.assertIsInstance(df, DataFrame)
        self.assertEqual(len(df.columns), 7)
        # expected content per column: 2 rdms with 6 pattern pairs each
        expected_columns = [
            ('dissimilarity', dissimilarities.ravel()),
            ('rdm_index', ([0]*6) + ([1]*6)),
            ('xy', (['x']*6) + (['y']*6)),
            ('pattern_index_1', ([0]*3 + [1]*2 + [2]*1)*2),
            ('pattern_index_2', [1, 2, 3, 2, 3, 3]*2),
            ('abcd_1', list('aaabbc')*2),
            ('abcd_2', list('bcdcdd')*2),
        ]
        for column, expected in expected_columns:
            self.assertValuesEqual(df[column], expected)

0 comments on commit 2361bc5

Please sign in to comment.