diff --git a/nireports/interfaces/mosaic.py b/nireports/interfaces/mosaic.py index 76fa34ca..3648e219 100644 --- a/nireports/interfaces/mosaic.py +++ b/nireports/interfaces/mosaic.py @@ -111,7 +111,7 @@ def _run_interface(self, runtime): class _PlotMosaicInputSpec(_PlotBaseInputSpec): bbox_mask_file = File(exists=True, desc="brain mask") - only_noise = traits.Bool(False, desc="plot only noise") + only_noise = traits.Bool(False, usedefault=True, desc="plot only noise") view = traits.List( traits.Enum("axial", "sagittal", "coronal"), value=["axial", "sagittal"], diff --git a/nireports/reportlets/mosaic.py b/nireports/reportlets/mosaic.py index 9b382718..cfa4b53a 100644 --- a/nireports/reportlets/mosaic.py +++ b/nireports/reportlets/mosaic.py @@ -23,6 +23,7 @@ # STATEMENT OF CHANGES: This file was ported carrying over full git history from # NiPreps projects licensed under the Apache-2.0 terms. """Base components to generate mosaic-like reportlets.""" +from warnings import warn from uuid import uuid4 from os import path as op import math @@ -265,8 +266,7 @@ def plot_slice( ax=None, vmax=None, vmin=None, - annotate=False, - swapaxes=False, + annotate=None, ): if isinstance(cmap, (str, bytes)): cmap = get_cmap(cmap) @@ -283,9 +283,10 @@ def plot_slice( if spacing is None: spacing = [1.0, 1.0] - if swapaxes: - dslice = np.swapaxes(dslice, 0, 1) - spacing = (spacing[1], spacing[0]) + # Always swap axes because imshow defines the image as (M, N) + # where M are rows (i.e., Y axis) and N are columns (X axis) + dslice = np.swapaxes(dslice, 0, 1) + spacing = (spacing[1], spacing[0]) ax.imshow( dslice, @@ -305,11 +306,11 @@ def plot_slice( bgcolor = cmap(min(vmin, 0.0)) fgcolor = cmap(vmax) - if annotate: + if annotate is not None: ax.text( 0.95, 0.95, - "R", + annotate[0], color=fgcolor, transform=ax.transAxes, horizontalalignment="center", @@ -320,7 +321,7 @@ def plot_slice( ax.text( 0.05, 0.95, - "L", + annotate[1], color=fgcolor, transform=ax.transAxes, horizontalalignment="center", @@ -498,14 +499,21 @@ def plot_mosaic( vmin=None, vmax=None, cmap="Greys_r", - plot_sagittal=True, + plot_sagittal=False, fig=None, maxrows=16, views=("axial", "sagittal", None), ): """Plot a mosaic of 2D cuts.""" - VIEW_AXES_ORDER = (2, 1, 0) + VIEW_AXES_ORDER = { + "axial": (0, 1, 2), + "sagittal": (2, 0, 1), + "coronal": (0, 2, 1), + } + + if isinstance(views, str): + views = (views, None, None) if len(views) != 3: _views = [None, None, None] @@ -522,72 +530,53 @@ def plot_mosaic( if not hasattr(img, "shape"): nii = nb.as_closest_canonical(nb.load(img)) img_data = nii.get_fdata() - zooms = nii.header.get_zooms() + zooms = nii.header.get_zooms()[:3] else: img_data = img zooms = [1.0, 1.0, 1.0] out_file = "mosaic.svg" - shape = img_data.shape[:3] - view_hratios = { - "axial": 1.0, - "coronal": (zooms[2] * shape[2]) / (zooms[1] * shape[1]), - "sagittal": (zooms[2] * shape[2]) / (zooms[1] * shape[1]), - } - view_x = {"axial": 0, "coronal": 0, "sagittal": 1} - view_y = {"axial": 1, "coronal": 2, "sagittal": 2} - if plot_sagittal and views[1] is None and views[0] != "sagittal": + warn("Argument ``plot_sagittal`` for plot_mosaic() should not be used.") views = (views[0], "sagittal", None) - # Select the axis through which we cut the planes - axes_order = [ - ["sagittal", "coronal", "axial"].index(views[0]), - ] + # Create mask for bounding box + bbox_data = None + if bbox_mask_file is not None: + bbox_data = np.asanyarray( + nb.as_closest_canonical(nb.load(bbox_mask_file)).dataobj + ) > 1e-3 + elif img_data.shape[-1] > (ncols * maxrows): + lowthres = np.percentile(img_data, 5) + bbox_data = np.ones_like(img_data) + bbox_data[img_data <= lowthres] = 0 - if views[1] is not None: - axes_order.append(["sagittal", "coronal", "axial"].index(views[1])) + if bbox_data is not None: + img_data = _bbox(img_data, bbox_data) - # If 3D, complete last axis - if img_data.ndim > 3: - raise RuntimeError("Dataset has more than three dimensions") - elif img_data.ndim == 3: - axes_order += list(set(range(3)) - set(axes_order)) + shape = np.array(img_data.shape[:3]) + extents = shape * zooms - # Remove extra dimensions - img_data = np.moveaxis( - np.squeeze(img_data), - axes_order, - VIEW_AXES_ORDER[:len(axes_order)], - ) + view_x = {"axial": 0, "coronal": 0, "sagittal": 1} + view_y = {"axial": 1, "coronal": 2, "sagittal": 2} + axannotation = {"axial": "RL", "coronal": "RL", "sagittal": "AP"} + + nrows = min((shape[-1] + 1) // ncols, maxrows) + + # create figures + if fig is None: + fig = plt.figure(layout=None) # Load overlay if present if overlay_mask: - overlay_data = np.moveaxis( - nb.as_closest_canonical(nb.load(overlay_mask)).get_fdata(), - axes_order, - VIEW_AXES_ORDER[:len(axes_order)], - ) - - # Create mask for bounding box - if bbox_mask_file is not None: - bbox_data = np.moveaxis( - nb.as_closest_canonical(nb.load(bbox_mask_file)).get_fdata(), - axes_order, - VIEW_AXES_ORDER[:len(axes_order)], - ) - img_data = _bbox(img_data, bbox_data) - elif img_data.shape[-1] > (ncols * maxrows): - lowthres = np.percentile(img_data, 5) - mask_file = np.ones_like(img_data) - mask_file[img_data <= lowthres] = 0 - img_data = _bbox(img_data, mask_file) + overlay_data = nb.as_closest_canonical(nb.load(overlay_mask)).get_fdata() - nrows = min((img_data.shape[-1] + 1) // ncols, maxrows) + if bbox_data is not None: + overlay_data = _bbox(overlay_data, bbox_data) # Decimate if too many values z_vals = np.unique(np.linspace( - 0, img_data.shape[-1] - 1, num=(ncols * nrows), dtype=int, endpoint=True, + 0, shape[-1] - 1, num=(ncols * nrows), dtype=int, endpoint=True, )) n_gs = sum(bool(v) for v in views) @@ -595,49 +584,36 @@ def plot_mosaic( main_mosaic_idx[:len(z_vals)] = z_vals main_mosaic_idx = main_mosaic_idx.reshape(nrows, ncols) - swapaxes = views in ( - ("axial", "coronal", None), - ("axial", "coronal", "sagittal"), - ("coronal", "axial", None), - ("coronal", "axial", "sagittal"), - ("sagittal", "axial", "coronal"), - ("sagittal", "axial", None), - - ) - - nrows = [nrows, 1, 1] - - # create figures - - if fig is None: - fig = plt.figure(layout=None) - fig_height = [] panel_width = [] + ncols = [ncols] + view_spacing = [] for ii, vv in enumerate(views): if vv is None: break axis_x = view_x[vv] axis_y = view_y[vv] + view_spacing.append((zooms[axis_x], zooms[axis_y])) + + view_rows = nrows + if ii > 0: + ncols.append(int(panel_width[0] // extents[axis_x])) + view_rows = 1 - fig_height.append(zooms[axis_y] * shape[axis_y] * nrows[ii]) - panel_width.append(zooms[axis_x] * shape[axis_x]) + fig_height.append(extents[axis_y] * view_rows) + panel_width.append(extents[axis_x] * ncols[-1]) - fig_ratio = sum(fig_height) / (panel_width[0] * ncols) + fig_ratio = sum(fig_height) / panel_width[0] fig.set_size_inches(20, 20 * fig_ratio) - height_ratios = [ - view_hratios[r] * nrows[i] - for i, r in enumerate(views) if r is not None - ] subfigs = GridSpec( nrows=n_gs, ncols=1, - top=0.96, - bottom=0.01, - hspace=0.08, - height_ratios=height_ratios if len(height_ratios) > 1 else [1], + # top=0.96, + # bottom=0.01, + hspace=0.001, + height_ratios=np.array(fig_height) / fig_height[0] ) est_vmin, est_vmax = _get_limits(img_data, only_plot_noise=only_plot_noise) @@ -646,10 +622,28 @@ def plot_mosaic( if not vmax: vmax = est_vmax - slice_spacing = [vs for i, vs in enumerate(zooms) if i != axes_order[0]] - # Fill in the main mosaic panel - panel_axs = subfigs[0].subgridspec(nrows[0], ncols, hspace=0.01, wspace=0.00001) + panel_axs = subfigs[0].subgridspec( + nrows, + ncols[0], + hspace=0.0001, + wspace=0.0001, + ) + + # Mosaic view 1: nrows x ncols (main view) + view_data = np.moveaxis( + np.squeeze(img_data), + (0, 1, 2), + VIEW_AXES_ORDER[views[0]], + ) + + if overlay_mask: + view_overlay_data = np.moveaxis( + overlay_data, + (0, 1, 2), + VIEW_AXES_ORDER[views[0]], + ) + for ii, row_slices in enumerate(main_mosaic_idx): for jj, z_val in enumerate(row_slices): if z_val < 0: @@ -660,15 +654,14 @@ def plot_mosaic( panel_axs[ii, jj].set_rasterized(True) plot_slice( - img_data[:, :, z_val], + view_data[:, :, z_val], vmin=vmin, vmax=vmax, cmap=cmap, ax=ax, - spacing=slice_spacing, + spacing=view_spacing[0], label=f"{z_val:d}", - swapaxes=swapaxes, - annotate=annotate and views[0] in ("axial", "coronal"), + annotate=axannotation[views[0]] if annotate else None, ) if overlay_mask: @@ -679,74 +672,79 @@ def plot_mosaic( alphas = np.linspace(0, 0.75, msk_cmap.N + 3) msk_cmap._lut[:, -1] = alphas plot_slice( - overlay_data[:, :, z_val], + view_overlay_data[:, :, z_val], vmin=0, vmax=1, cmap=msk_cmap, ax=panel_axs[ii, jj], - spacing=slice_spacing, - swapaxes=swapaxes, + spacing=view_spacing[0], ) if views[1] is not None: - ncols_2 = math.floor((panel_width[0] * ncols) / panel_width[1]) - 1 - slice_spacing = [vs for i, vs in enumerate(zooms) if i != axes_order[1]] - step = max(int(img_data.shape[1] / (ncols_2 + 1)), 1) - start = step - stop = img_data.shape[1] - step - panel_axs = subfigs[1].subgridspec(1, ncols_2, wspace=0.0001) - - swapaxes = views in ( - ("axial", "coronal", None), - ("axial", "coronal", "sagittal"), - ("axial", "sagittal", "coronal"), - ("axial", "sagittal", None), - ("coronal", "axial", None), - ("coronal", "axial", "sagittal"), + # Mosaic view 2 + view_data = np.moveaxis( + np.squeeze(img_data), + (0, 1, 2), + VIEW_AXES_ORDER[views[1]], ) - y_vals = np.linspace(start, stop, num=ncols_2, dtype=int, endpoint=True) + if overlay_mask: + view_overlay_data = np.moveaxis( + overlay_data, + (0, 1, 2), + VIEW_AXES_ORDER[views[1]], + ) + step = max(int(view_data.shape[2] / (ncols[1] + 1)), 1) + start = step + stop = view_data.shape[2] - step + panel_axs = subfigs[1].subgridspec(1, ncols[1], wspace=0.0001) + + y_vals = np.linspace(start, stop, num=ncols[1], dtype=int, endpoint=True) for jj, slice_val in enumerate(y_vals): ax = fig.add_subplot(panel_axs[jj]) plot_slice( - img_data[:, slice_val, :], + view_data[:, :, slice_val], vmin=vmin, vmax=vmax, cmap=cmap, ax=ax, label=f"{slice_val:d}", - spacing=slice_spacing, - swapaxes=swapaxes, - annotate=annotate and views[1] in ("axial", "coronal"), + spacing=view_spacing[1], + annotate=axannotation[views[1]] if annotate else None, ) if views[1] is not None and views[2] is not None: - ncols_3 = math.floor((panel_width[0] * ncols) / panel_width[2]) - 1 - slice_spacing = [vs for i, vs in enumerate(zooms) if i != axes_order[2]] - step = max(int(img_data.shape[0] / (ncols_3 + 1)), 1) - start = step - stop = img_data.shape[0] - step - panel_axs = subfigs[2].subgridspec(1, ncols_3, wspace=0.0001) - - swapaxes = views in ( - ("axial", "coronal", "sagittal"), - ("axial", "sagittal", "coronal"), - ("coronal", "sagittal", "axial"), + # Mosaic view 2 + view_data = np.moveaxis( + np.squeeze(img_data), + (0, 1, 2), + VIEW_AXES_ORDER[views[2]], ) - x_vals = np.linspace(start, stop, num=ncols_3, dtype=int, endpoint=True) + if overlay_mask: + view_overlay_data = np.moveaxis( + overlay_data, + (0, 1, 2), + VIEW_AXES_ORDER[views[2]], + ) + + step = max(int(view_data.shape[2] / (ncols[2] + 1)), 1) + start = step + stop = view_data.shape[2] - step + panel_axs = subfigs[2].subgridspec(1, ncols[2], wspace=0.0001) + + x_vals = np.linspace(start, stop, num=ncols[2], dtype=int, endpoint=True) for jj, slice_val in enumerate(x_vals): ax = fig.add_subplot(panel_axs[jj]) plot_slice( - img_data[slice_val, ...], + view_data[:, :, slice_val], vmin=vmin, vmax=vmax, cmap=cmap, ax=ax, label=f"{slice_val:d}", - spacing=slice_spacing, - swapaxes=swapaxes, - annotate=annotate and views[2] in ("axial", "coronal"), + spacing=view_spacing[2], + annotate=axannotation[views[2]] if annotate else None, ) if title: @@ -761,4 +759,5 @@ def plot_mosaic( out_file = op.abspath(fname + "_mosaic.svg") fig.savefig(out_file, format="svg", dpi=300, bbox_inches="tight") + plt.close(fig) return out_file