diff --git a/src/rsatoolbox/vis/rdm_plot.py b/src/rsatoolbox/vis/rdm_plot.py index 6bd4d5c2..daf241e0 100755 --- a/src/rsatoolbox/vis/rdm_plot.py +++ b/src/rsatoolbox/vis/rdm_plot.py @@ -240,6 +240,10 @@ def show_rdm_panel( gridlines: Optional[npt.ArrayLike] = None, vmin: Optional[float] = None, vmax: Optional[float] = None, + overlay: Optional[NDArray] = None, + overlay_color: str = '#00ff0050', + contour: Optional[NDArray] = None, + contour_color: str = 'red' ) -> AxesImage: """show_rdm_panel. Add RDM heatmap to the axis ax. @@ -259,12 +263,21 @@ def show_rdm_panel( argument. vmax (float): Maximum intensity for colorbar mapping. matplotlib imshow argument. + overlay ((str, str) or NDArray): RDM descriptor name-value tuple, or vector (one value per pair) + which indicates whether to highlight the given cells + overlay_color (str): Color to use to highlight the pairs in the overlay argument. + Use RGBA to specify transparency. Default is 50% opaque green. + contour ((str, str) or NDArray): RDM descriptor name-value tuple, or vector (one value per pair) + which indicates whether to add a border to the given cells + contour_color (str): Color to use for a border around pairs in the contour argument. + Use RGBA to specify transparency. Default is red. Returns: matplotlib.image.AxesImage: Matplotlib handle. """ conf = SingleRdmPlot.from_show_rdm_panel_args(rdms, cmap, nanmask, - rdm_descriptor, gridlines, vmin, vmax) + rdm_descriptor, gridlines, vmin, vmax, overlay, overlay_color, + contour, contour_color) return _show_rdm_panel(conf, ax or plt.gca()) @@ -603,6 +616,10 @@ def for_single(self, index: int) -> SingleRdmPlot: conf.nanmask = self.nanmask conf.vmin = self.vmin conf.vmax = self.vmax + conf.overlay = self.overlay + conf.overlay_color = self.overlay_color + conf.contour = self.contour + conf.contour_color = self.contour_color if self.rdm_descriptor in conf.rdms.rdm_descriptors: conf.title = conf.rdms.rdm_descriptors[self.rdm_descriptor][0] else: @@ -621,17 +638,25 @@ class SingleRdmPlot: vmin: Optional[float] vmax: Optional[float] title: str + overlay: NDArray + overlay_color: str + contour: NDArray + contour_color: str @classmethod def from_show_rdm_panel_args( cls, rdms: RDMs, - cmap: Union[str, Colormap] = 'bone', - nanmask: Optional[NDArray] = None, - rdm_descriptor: Optional[str] = None, - gridlines: Optional[npt.ArrayLike] = None, - vmin: Optional[float] = None, - vmax: Optional[float] = None, + cmap: Union[str, Colormap], + nanmask: Optional[NDArray], + rdm_descriptor: Optional[str], + gridlines: Optional[npt.ArrayLike], + vmin: Optional[float], + vmax: Optional[float], + overlay: Optional[NDArray], + overlay_color: str, + contour: Optional[NDArray], + contour_color: str ) -> SingleRdmPlot: """Create an object from the original arguments to show_rdm_panel() """ @@ -654,4 +679,19 @@ def from_show_rdm_panel_args( conf.title = rdm_descriptor or '' conf.vmin = vmin conf.vmax = vmax + conf.overlay = conf.interpret_rdm_arg(overlay, rdms) + conf.overlay_color = overlay_color + conf.contour = conf.interpret_rdm_arg(contour, rdms) + conf.contour_color = contour_color return conf + + def interpret_rdm_arg(self, val: Optional[ArrayOrRdmDescriptor], rdms: RDMs) -> NDArray: + """Resolve argument that can be an rdm descriptor key/value pair or a utv + """ + if val is None: + return np.zeros(rdms.n_cond) + else: + if isinstance(val, np.ndarray): + return val + else: + return rdms.subset(*val).dissimilarities[0, :] diff --git a/tests/test_vis_rdm_plot_conf.py b/tests/test_vis_rdm_plot_conf.py index 94d3bdaa..7fd921c1 100644 --- a/tests/test_vis_rdm_plot_conf.py +++ b/tests/test_vis_rdm_plot_conf.py @@ -6,7 +6,7 @@ class TestRdmPlot(TestCase): - def test_nrow_ncolumn(self): + def test_from_show_rdm_args__nrow_ncolumn(self): from rsatoolbox.vis.rdm_plot import MultiRdmPlot rdms = Mock() rdms.n_cond = 3 @@ -37,7 +37,7 @@ def test_nrow_ncolumn(self): self.assertEqual(conf.n_column, 4) self.assertEqual(conf.n_row, 3) - def test_contour_overlay(self): + def test_from_show_rdm_args__multi_contour_overlay(self): from rsatoolbox.vis.rdm_plot import MultiRdmPlot rdms = Mock() rdms.n_cond = 3 @@ -69,6 +69,31 @@ def test_contour_overlay(self): contour = ('name', 'foo'), contour_color = 'red' ) - assert_array_equal(conf.overlay, numpy.array([0,0,0,1,1,1])) + assert_array_equal(conf.overlay, numpy.array([0, 0, 0, 1, 1, 1])) + assert_array_equal(conf.contour, numpy.array([0, 1, 0, 1, 0, 1])) + + def test_single_from_show_rdm_panel_args(self): + from rsatoolbox.vis.rdm_plot import SingleRdmPlot + rdms = Mock() + rdms.n_cond = 3 + rdms.n_rdm = 1 + rdms.get_matrices.return_value = numpy.zeros([1, 3, 3]) + rdms.rdm_descriptors = dict() + mask_rdm = Mock() + mask_rdm.dissimilarities = numpy.array([[0, 1, 0, 1, 0, 1]]) + rdms.dissimilarities = numpy.zeros([10, 6]) + conf = SingleRdmPlot.from_show_rdm_panel_args( + rdms = rdms, + cmap = 'bone', + nanmask = None, + rdm_descriptor = None, + gridlines = None, + vmin = None, + vmax = None, + overlay = numpy.array([0, 0, 0, 1, 1, 1]), + overlay_color='#00ff0050', + contour = numpy.array([0, 1, 0, 1, 0, 1]), + contour_color = 'red' + ) + assert_array_equal(conf.overlay, numpy.array([0, 0, 0, 1, 1, 1])) assert_array_equal(conf.contour, numpy.array([0, 1, 0, 1, 0, 1])) - \ No newline at end of file