Skip to content

Commit

Permalink
contour and overlay args on single rdm plot
Browse files Browse the repository at this point in the history
  • Loading branch information
JasperVanDenBosch committed Oct 17, 2023
1 parent 4b03de6 commit b1b6571
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 11 deletions.
54 changes: 47 additions & 7 deletions src/rsatoolbox/vis/rdm_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,10 @@ def show_rdm_panel(
gridlines: Optional[npt.ArrayLike] = None,
vmin: Optional[float] = None,
vmax: Optional[float] = None,
overlay: Optional[NDArray] = None,
overlay_color: str = '#00ff0050',
contour: Optional[NDArray] = None,
contour_color: str = 'red'
) -> AxesImage:
"""show_rdm_panel. Add RDM heatmap to the axis ax.
Expand All @@ -259,12 +263,21 @@ def show_rdm_panel(
argument.
vmax (float): Maximum intensity for colorbar mapping. matplotlib imshow
argument.
overlay ((str, str) or NDArray): RDM descriptor name-value tuple, or vector (one value per pair)
which indicates whether to highlight the given cells
overlay_color (str): Color to use to highlight the pairs in the overlay argument.
Use RGBA to specify transparency. Default is 50% opaque green.
contour ((str, str) or NDArray): RDM descriptor name-value tuple, or vector (one value per pair)
which indicates whether to add a border to the given cells
contour_color (str): Color to use for a border around pairs in the contour argument.
Use RGBA to specify transparency. Default is red.
Returns:
matplotlib.image.AxesImage: Matplotlib handle.
"""
conf = SingleRdmPlot.from_show_rdm_panel_args(rdms, cmap, nanmask,
rdm_descriptor, gridlines, vmin, vmax)
rdm_descriptor, gridlines, vmin, vmax, overlay, overlay_color,
contour, contour_color)
return _show_rdm_panel(conf, ax or plt.gca())


Expand Down Expand Up @@ -603,6 +616,10 @@ def for_single(self, index: int) -> SingleRdmPlot:
conf.nanmask = self.nanmask
conf.vmin = self.vmin
conf.vmax = self.vmax
conf.overlay = self.overlay
conf.overlay_color = self.overlay_color
conf.contour = self.contour
conf.contour_color = self.contour_color
if self.rdm_descriptor in conf.rdms.rdm_descriptors:
conf.title = conf.rdms.rdm_descriptors[self.rdm_descriptor][0]
else:
Expand All @@ -621,17 +638,25 @@ class SingleRdmPlot:
vmin: Optional[float]
vmax: Optional[float]
title: str
overlay: NDArray
overlay_color: str
contour: NDArray
contour_color: str

@classmethod
def from_show_rdm_panel_args(
cls,
rdms: RDMs,
cmap: Union[str, Colormap] = 'bone',
nanmask: Optional[NDArray] = None,
rdm_descriptor: Optional[str] = None,
gridlines: Optional[npt.ArrayLike] = None,
vmin: Optional[float] = None,
vmax: Optional[float] = None,
cmap: Union[str, Colormap],
nanmask: Optional[NDArray],
rdm_descriptor: Optional[str],
gridlines: Optional[npt.ArrayLike],
vmin: Optional[float],
vmax: Optional[float],
overlay: Optional[NDArray],
overlay_color: str,
contour: Optional[NDArray],
contour_color: str
) -> SingleRdmPlot:
"""Create an object from the original arguments to show_rdm_panel()
"""
Expand All @@ -654,4 +679,19 @@ def from_show_rdm_panel_args(
conf.title = rdm_descriptor or ''
conf.vmin = vmin
conf.vmax = vmax
conf.overlay = conf.interpret_rdm_arg(overlay, rdms)
conf.overlay_color = overlay_color
conf.contour = conf.interpret_rdm_arg(contour, rdms)
conf.contour_color = contour_color
return conf

def interpret_rdm_arg(self, val: Optional[ArrayOrRdmDescriptor], rdms: RDMs) -> NDArray:
"""Resolve argument that can be an rdm descriptor key/value pair or a utv
"""
if val is None:
return np.zeros(rdms.n_cond)
else:
if isinstance(val, np.ndarray):
return val
else:
return rdms.subset(*val).dissimilarities[0, :]
33 changes: 29 additions & 4 deletions tests/test_vis_rdm_plot_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class TestRdmPlot(TestCase):

def test_nrow_ncolumn(self):
def test_from_show_rdm_args__nrow_ncolumn(self):
from rsatoolbox.vis.rdm_plot import MultiRdmPlot
rdms = Mock()
rdms.n_cond = 3
Expand Down Expand Up @@ -37,7 +37,7 @@ def test_nrow_ncolumn(self):
self.assertEqual(conf.n_column, 4)
self.assertEqual(conf.n_row, 3)

def test_contour_overlay(self):
def test_from_show_rdm_args__multi_contour_overlay(self):
from rsatoolbox.vis.rdm_plot import MultiRdmPlot
rdms = Mock()
rdms.n_cond = 3
Expand Down Expand Up @@ -69,6 +69,31 @@ def test_contour_overlay(self):
contour = ('name', 'foo'),
contour_color = 'red'
)
assert_array_equal(conf.overlay, numpy.array([0,0,0,1,1,1]))
assert_array_equal(conf.overlay, numpy.array([0, 0, 0, 1, 1, 1]))
assert_array_equal(conf.contour, numpy.array([0, 1, 0, 1, 0, 1]))

def test_single_from_show_rdm_panel_args(self):
from rsatoolbox.vis.rdm_plot import SingleRdmPlot
rdms = Mock()
rdms.n_cond = 3
rdms.n_rdm = 1
rdms.get_matrices.return_value = numpy.zeros([1, 3, 3])
rdms.rdm_descriptors = dict()
mask_rdm = Mock()
mask_rdm.dissimilarities = numpy.array([[0, 1, 0, 1, 0, 1]])
rdms.dissimilarities = numpy.zeros([10, 6])
conf = SingleRdmPlot.from_show_rdm_panel_args(
rdms = rdms,
cmap = 'bone',
nanmask = None,
rdm_descriptor = None,
gridlines = None,
vmin = None,
vmax = None,
overlay = numpy.array([0, 0, 0, 1, 1, 1]),
overlay_color='#00ff0050',
contour = numpy.array([0, 1, 0, 1, 0, 1]),
contour_color = 'red'
)
assert_array_equal(conf.overlay, numpy.array([0, 0, 0, 1, 1, 1]))
assert_array_equal(conf.contour, numpy.array([0, 1, 0, 1, 0, 1]))

0 comments on commit b1b6571

Please sign in to comment.