From a91765d4b592ceac1209f40c3261380c9081155c Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Tue, 13 Jun 2023 11:20:26 +0200 Subject: [PATCH] fix: overhaul of image manipulation to prepare plotting --- nireports/interfaces/mosaic.py | 2 +- nireports/reportlets/mosaic.py | 126 ++++++++++++++++----------------- 2 files changed, 62 insertions(+), 66 deletions(-) diff --git a/nireports/interfaces/mosaic.py b/nireports/interfaces/mosaic.py index 76fa34ca..3648e219 100644 --- a/nireports/interfaces/mosaic.py +++ b/nireports/interfaces/mosaic.py @@ -111,7 +111,7 @@ def _run_interface(self, runtime): class _PlotMosaicInputSpec(_PlotBaseInputSpec): bbox_mask_file = File(exists=True, desc="brain mask") - only_noise = traits.Bool(False, desc="plot only noise") + only_noise = traits.Bool(False, usedefault=True, desc="plot only noise") view = traits.List( traits.Enum("axial", "sagittal", "coronal"), value=["axial", "sagittal"], diff --git a/nireports/reportlets/mosaic.py b/nireports/reportlets/mosaic.py index 93ba339b..cfa4b53a 100644 --- a/nireports/reportlets/mosaic.py +++ b/nireports/reportlets/mosaic.py @@ -23,6 +23,7 @@ # STATEMENT OF CHANGES: This file was ported carrying over full git history from # NiPreps projects licensed under the Apache-2.0 terms. """Base components to generate mosaic-like reportlets.""" +from warnings import warn from uuid import uuid4 from os import path as op import math @@ -266,7 +267,6 @@ def plot_slice( vmax=None, vmin=None, annotate=None, - swapaxes=False, ): if isinstance(cmap, (str, bytes)): cmap = get_cmap(cmap) @@ -283,9 +283,10 @@ def plot_slice( if spacing is None: spacing = [1.0, 1.0] - if swapaxes: - dslice = np.swapaxes(dslice, 0, 1) - spacing = (spacing[1], spacing[0]) + # Always swap axes because imshow defines the image as (M, N) + # where M are rows (i.e., Y axis) and N are columns (X axis) + dslice = np.swapaxes(dslice, 0, 1) + spacing = (spacing[1], spacing[0]) ax.imshow( dslice, @@ -505,7 +506,11 @@ def plot_mosaic( ): """Plot a mosaic of 2D cuts.""" - VIEW_AXES_ORDER = (2, 1, 0) + VIEW_AXES_ORDER = { + "axial": (0, 1, 2), + "sagittal": (2, 0, 1), + "coronal": (0, 2, 1), + } if isinstance(views, str): views = (views, None, None) @@ -532,18 +537,9 @@ def plot_mosaic( out_file = "mosaic.svg" if plot_sagittal and views[1] is None and views[0] != "sagittal": + warn("Argument ``plot_sagittal`` for plot_mosaic() should not be used.") views = (views[0], "sagittal", None) - # Select the axis through which we cut the planes - axes_order = [ - ["sagittal", "coronal", "axial"].index(v) - for v in views if v is not None - ] - - # Complete axes - if len(axes_order) < 3: - axes_order += list(set(range(3)) - set(axes_order)) - # Create mask for bounding box bbox_data = None if bbox_mask_file is not None: @@ -567,27 +563,10 @@ def plot_mosaic( nrows = min((shape[-1] + 1) // ncols, maxrows) - swapaxes = views in ( - ("axial", "coronal", None), - ("axial", "coronal", "sagittal"), - ("coronal", "axial", None), - ("coronal", "axial", "sagittal"), - ("sagittal", "axial", "coronal"), - ("sagittal", "axial", None), - - ) - # create figures if fig is None: fig = plt.figure(layout=None) - # Remove extra dimensions - img_data = np.moveaxis( - np.squeeze(img_data), - axes_order, - VIEW_AXES_ORDER[:len(axes_order)], - ) - # Load overlay if present if overlay_mask: overlay_data = nb.as_closest_canonical(nb.load(overlay_mask)).get_fdata() @@ -595,12 +574,6 @@ def plot_mosaic( if bbox_data is not None: overlay_data = _bbox(overlay_data, bbox_data) - overlay_data = np.moveaxis( - overlay_data, - axes_order, - VIEW_AXES_ORDER[:len(axes_order)], - ) - # Decimate if too many values z_vals = np.unique(np.linspace( 0, shape[-1] - 1, num=(ncols * nrows), dtype=int, endpoint=True, @@ -657,6 +630,20 @@ def plot_mosaic( wspace=0.0001, ) + # Mosaic view 1: nrows x ncols (main view) + view_data = np.moveaxis( + np.squeeze(img_data), + (0, 1, 2), + VIEW_AXES_ORDER[views[0]], + ) + + if overlay_mask: + view_overlay_data = np.moveaxis( + overlay_data, + (0, 1, 2), + VIEW_AXES_ORDER[views[0]], + ) + for ii, row_slices in enumerate(main_mosaic_idx): for jj, z_val in enumerate(row_slices): if z_val < 0: @@ -667,14 +654,13 @@ def plot_mosaic( panel_axs[ii, jj].set_rasterized(True) plot_slice( - img_data[:, :, z_val], + view_data[:, :, z_val], vmin=vmin, vmax=vmax, cmap=cmap, ax=ax, spacing=view_spacing[0], label=f"{z_val:d}", - swapaxes=swapaxes, annotate=axannotation[views[0]] if annotate else None, ) @@ -686,69 +672,78 @@ def plot_mosaic( alphas = np.linspace(0, 0.75, msk_cmap.N + 3) msk_cmap._lut[:, -1] = alphas plot_slice( - overlay_data[:, :, z_val], + view_overlay_data[:, :, z_val], vmin=0, vmax=1, cmap=msk_cmap, ax=panel_axs[ii, jj], spacing=view_spacing[0], - swapaxes=swapaxes, ) if views[1] is not None: - step = max(int(img_data.shape[1] / (ncols[1] + 1)), 1) + # Mosaic view 2 + view_data = np.moveaxis( + np.squeeze(img_data), + (0, 1, 2), + VIEW_AXES_ORDER[views[1]], + ) + + if overlay_mask: + view_overlay_data = np.moveaxis( + overlay_data, + (0, 1, 2), + VIEW_AXES_ORDER[views[1]], + ) + step = max(int(view_data.shape[2] / (ncols[1] + 1)), 1) start = step - stop = img_data.shape[1] - step + stop = view_data.shape[2] - step panel_axs = subfigs[1].subgridspec(1, ncols[1], wspace=0.0001) - swapaxes = views in ( - ("axial", "coronal", None), - ("axial", "coronal", "sagittal"), - ("axial", "sagittal", "coronal"), - ("axial", "sagittal", None), - ("coronal", "axial", None), - ("coronal", "axial", "sagittal"), - ) - y_vals = np.linspace(start, stop, num=ncols[1], dtype=int, endpoint=True) for jj, slice_val in enumerate(y_vals): ax = fig.add_subplot(panel_axs[jj]) plot_slice( - img_data[:, slice_val, :], + view_data[:, :, slice_val], vmin=vmin, vmax=vmax, cmap=cmap, ax=ax, label=f"{slice_val:d}", spacing=view_spacing[1], - swapaxes=swapaxes, annotate=axannotation[views[1]] if annotate else None, ) if views[1] is not None and views[2] is not None: - step = max(int(img_data.shape[0] / (ncols[2] + 1)), 1) + # Mosaic view 2 + view_data = np.moveaxis( + np.squeeze(img_data), + (0, 1, 2), + VIEW_AXES_ORDER[views[2]], + ) + + if overlay_mask: + view_overlay_data = np.moveaxis( + overlay_data, + (0, 1, 2), + VIEW_AXES_ORDER[views[2]], + ) + + step = max(int(view_data.shape[2] / (ncols[2] + 1)), 1) start = step - stop = img_data.shape[0] - step + stop = view_data.shape[2] - step panel_axs = subfigs[2].subgridspec(1, ncols[2], wspace=0.0001) - swapaxes = views in ( - ("axial", "coronal", "sagittal"), - ("axial", "sagittal", "coronal"), - ("coronal", "sagittal", "axial"), - ) - x_vals = np.linspace(start, stop, num=ncols[2], dtype=int, endpoint=True) for jj, slice_val in enumerate(x_vals): ax = fig.add_subplot(panel_axs[jj]) plot_slice( - img_data[slice_val, ...], + view_data[:, :, slice_val], vmin=vmin, vmax=vmax, cmap=cmap, ax=ax, label=f"{slice_val:d}", spacing=view_spacing[2], - swapaxes=swapaxes, annotate=axannotation[views[2]] if annotate else None, ) @@ -764,4 +759,5 @@ def plot_mosaic( out_file = op.abspath(fname + "_mosaic.svg") fig.savefig(out_file, format="svg", dpi=300, bbox_inches="tight") + plt.close(fig) return out_file