From d15868daa773cbcc8a509063b1ee40127483738f Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Wed, 7 Jun 2023 10:46:28 -0400 Subject: [PATCH 1/4] enh: add ratio calculation to plotting --- nireports/reportlets/mosaic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nireports/reportlets/mosaic.py b/nireports/reportlets/mosaic.py index 9b382718..22bc1ff2 100644 --- a/nireports/reportlets/mosaic.py +++ b/nireports/reportlets/mosaic.py @@ -522,7 +522,7 @@ def plot_mosaic( if not hasattr(img, "shape"): nii = nb.as_closest_canonical(nb.load(img)) img_data = nii.get_fdata() - zooms = nii.header.get_zooms() + zooms = nii.header.get_zooms()[:3] else: img_data = img zooms = [1.0, 1.0, 1.0] @@ -530,8 +530,8 @@ def plot_mosaic( shape = img_data.shape[:3] view_hratios = { - "axial": 1.0, - "coronal": (zooms[2] * shape[2]) / (zooms[1] * shape[1]), + "axial": (zooms[1] * shape[1]) / (zooms[0] * shape[0]), + "coronal": (zooms[2] * shape[2]) / (zooms[0] * shape[0]), "sagittal": (zooms[2] * shape[2]) / (zooms[1] * shape[1]), } view_x = {"axial": 0, "coronal": 0, "sagittal": 1} From 585febfb7185b37cfb4a5174028de41d0a8e83f3 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Wed, 7 Jun 2023 11:22:33 -0400 Subject: [PATCH 2/4] fix: calculate ratios considering the mask --- nireports/reportlets/mosaic.py | 177 +++++++++++++++++---------------- 1 file changed, 89 insertions(+), 88 deletions(-) diff --git a/nireports/reportlets/mosaic.py b/nireports/reportlets/mosaic.py index 22bc1ff2..ab79ec61 100644 --- a/nireports/reportlets/mosaic.py +++ b/nireports/reportlets/mosaic.py @@ -265,7 +265,7 @@ def plot_slice( ax=None, vmax=None, vmin=None, - annotate=False, + annotate=None, swapaxes=False, ): if isinstance(cmap, (str, bytes)): @@ -305,11 +305,11 @@ def plot_slice( bgcolor = cmap(min(vmin, 0.0)) fgcolor = cmap(vmax) - if annotate: + if annotate is not None: ax.text( 0.95, 0.95, - "R", + annotate[0], color=fgcolor, transform=ax.transAxes, horizontalalignment="center", @@ -320,7 +320,7 @@ def plot_slice( ax.text( 0.05, 0.95, - "L", + annotate[1], color=fgcolor, transform=ax.transAxes, horizontalalignment="center", @@ -498,7 +498,7 @@ def plot_mosaic( vmin=None, vmax=None, cmap="Greys_r", - plot_sagittal=True, + plot_sagittal=False, fig=None, maxrows=16, views=("axial", "sagittal", None), @@ -507,6 +507,9 @@ def plot_mosaic( VIEW_AXES_ORDER = (2, 1, 0) + if isinstance(views, str): + views = (views, None, None) + if len(views) != 3: _views = [None, None, None] @@ -528,32 +531,54 @@ def plot_mosaic( zooms = [1.0, 1.0, 1.0] out_file = "mosaic.svg" - shape = img_data.shape[:3] - view_hratios = { - "axial": (zooms[1] * shape[1]) / (zooms[0] * shape[0]), - "coronal": (zooms[2] * shape[2]) / (zooms[0] * shape[0]), - "sagittal": (zooms[2] * shape[2]) / (zooms[1] * shape[1]), - } - view_x = {"axial": 0, "coronal": 0, "sagittal": 1} - view_y = {"axial": 1, "coronal": 2, "sagittal": 2} - if plot_sagittal and views[1] is None and views[0] != "sagittal": views = (views[0], "sagittal", None) # Select the axis through which we cut the planes axes_order = [ - ["sagittal", "coronal", "axial"].index(views[0]), + ["sagittal", "coronal", "axial"].index(v) + for v in views if v is not None ] - if views[1] is not None: - axes_order.append(["sagittal", "coronal", "axial"].index(views[1])) - - # If 3D, complete last axis - if img_data.ndim > 3: - raise RuntimeError("Dataset has more than three dimensions") - elif img_data.ndim == 3: + # Complete axes + if len(axes_order) < 3: axes_order += list(set(range(3)) - set(axes_order)) + # Create mask for bounding box + bbox_data = None + if bbox_mask_file is not None: + bbox_data = np.asanyarray(nb.as_closest_canonical(nb.load(bbox_mask_file))) > 1e-3 + elif img_data.shape[-1] > (ncols * maxrows): + lowthres = np.percentile(img_data, 5) + bbox_data = np.ones_like(img_data) + bbox_data[img_data <= lowthres] = 0 + + if bbox_data is not None: + img_data = _bbox(img_data, bbox_data) + + shape = np.array(img_data.shape[:3]) + extents = shape * zooms + + view_x = {"axial": 0, "coronal": 0, "sagittal": 1} + view_y = {"axial": 1, "coronal": 2, "sagittal": 2} + axannotation = {"axial": "RL", "coronal": "RL", "sagittal": "AP"} + + nrows = min((shape[-1] + 1) // ncols, maxrows) + + swapaxes = views in ( + ("axial", "coronal", None), + ("axial", "coronal", "sagittal"), + ("coronal", "axial", None), + ("coronal", "axial", "sagittal"), + ("sagittal", "axial", "coronal"), + ("sagittal", "axial", None), + + ) + + # create figures + if fig is None: + fig = plt.figure(layout=None) + # Remove extra dimensions img_data = np.moveaxis( np.squeeze(img_data), @@ -563,31 +588,20 @@ def plot_mosaic( # Load overlay if present if overlay_mask: - overlay_data = np.moveaxis( - nb.as_closest_canonical(nb.load(overlay_mask)).get_fdata(), - axes_order, - VIEW_AXES_ORDER[:len(axes_order)], - ) + overlay_data = nb.as_closest_canonical(nb.load(overlay_mask)).get_fdata() - # Create mask for bounding box - if bbox_mask_file is not None: - bbox_data = np.moveaxis( - nb.as_closest_canonical(nb.load(bbox_mask_file)).get_fdata(), + if bbox_data is not None: + overlay_data = _bbox(overlay_data, bbox_data) + + overlay_data = np.moveaxis( + overlay_data, axes_order, VIEW_AXES_ORDER[:len(axes_order)], ) - img_data = _bbox(img_data, bbox_data) - elif img_data.shape[-1] > (ncols * maxrows): - lowthres = np.percentile(img_data, 5) - mask_file = np.ones_like(img_data) - mask_file[img_data <= lowthres] = 0 - img_data = _bbox(img_data, mask_file) - - nrows = min((img_data.shape[-1] + 1) // ncols, maxrows) # Decimate if too many values z_vals = np.unique(np.linspace( - 0, img_data.shape[-1] - 1, num=(ncols * nrows), dtype=int, endpoint=True, + 0, shape[-1] - 1, num=(ncols * nrows), dtype=int, endpoint=True, )) n_gs = sum(bool(v) for v in views) @@ -595,49 +609,36 @@ def plot_mosaic( main_mosaic_idx[:len(z_vals)] = z_vals main_mosaic_idx = main_mosaic_idx.reshape(nrows, ncols) - swapaxes = views in ( - ("axial", "coronal", None), - ("axial", "coronal", "sagittal"), - ("coronal", "axial", None), - ("coronal", "axial", "sagittal"), - ("sagittal", "axial", "coronal"), - ("sagittal", "axial", None), - - ) - - nrows = [nrows, 1, 1] - - # create figures - - if fig is None: - fig = plt.figure(layout=None) - fig_height = [] panel_width = [] + ncols = [ncols] + view_spacing = [] for ii, vv in enumerate(views): if vv is None: break axis_x = view_x[vv] axis_y = view_y[vv] + view_spacing.append((zooms[axis_x], zooms[axis_y])) + + view_rows = nrows + if ii > 0: + ncols.append(int(panel_width[0] // extents[axis_x])) + view_rows = 1 - fig_height.append(zooms[axis_y] * shape[axis_y] * nrows[ii]) - panel_width.append(zooms[axis_x] * shape[axis_x]) + fig_height.append(extents[axis_y] * view_rows) + panel_width.append(extents[axis_x] * ncols[-1]) - fig_ratio = sum(fig_height) / (panel_width[0] * ncols) + fig_ratio = sum(fig_height) / panel_width[0] fig.set_size_inches(20, 20 * fig_ratio) - height_ratios = [ - view_hratios[r] * nrows[i] - for i, r in enumerate(views) if r is not None - ] subfigs = GridSpec( nrows=n_gs, ncols=1, - top=0.96, - bottom=0.01, - hspace=0.08, - height_ratios=height_ratios if len(height_ratios) > 1 else [1], + # top=0.96, + # bottom=0.01, + hspace=0.001, + height_ratios=np.array(fig_height) / fig_height[0] ) est_vmin, est_vmax = _get_limits(img_data, only_plot_noise=only_plot_noise) @@ -646,10 +647,14 @@ def plot_mosaic( if not vmax: vmax = est_vmax - slice_spacing = [vs for i, vs in enumerate(zooms) if i != axes_order[0]] - # Fill in the main mosaic panel - panel_axs = subfigs[0].subgridspec(nrows[0], ncols, hspace=0.01, wspace=0.00001) + panel_axs = subfigs[0].subgridspec( + nrows, + ncols[0], + hspace=0.0001, + wspace=0.0001, + ) + for ii, row_slices in enumerate(main_mosaic_idx): for jj, z_val in enumerate(row_slices): if z_val < 0: @@ -665,10 +670,10 @@ def plot_mosaic( vmax=vmax, cmap=cmap, ax=ax, - spacing=slice_spacing, + spacing=view_spacing[0], label=f"{z_val:d}", swapaxes=swapaxes, - annotate=annotate and views[0] in ("axial", "coronal"), + annotate=axannotation[views[0]] if annotate else None, ) if overlay_mask: @@ -684,17 +689,15 @@ def plot_mosaic( vmax=1, cmap=msk_cmap, ax=panel_axs[ii, jj], - spacing=slice_spacing, + spacing=view_spacing[0], swapaxes=swapaxes, ) if views[1] is not None: - ncols_2 = math.floor((panel_width[0] * ncols) / panel_width[1]) - 1 - slice_spacing = [vs for i, vs in enumerate(zooms) if i != axes_order[1]] - step = max(int(img_data.shape[1] / (ncols_2 + 1)), 1) + step = max(int(img_data.shape[1] / (ncols[1] + 1)), 1) start = step stop = img_data.shape[1] - step - panel_axs = subfigs[1].subgridspec(1, ncols_2, wspace=0.0001) + panel_axs = subfigs[1].subgridspec(1, ncols[1], wspace=0.0001) swapaxes = views in ( ("axial", "coronal", None), @@ -705,7 +708,7 @@ def plot_mosaic( ("coronal", "axial", "sagittal"), ) - y_vals = np.linspace(start, stop, num=ncols_2, dtype=int, endpoint=True) + y_vals = np.linspace(start, stop, num=ncols[1], dtype=int, endpoint=True) for jj, slice_val in enumerate(y_vals): ax = fig.add_subplot(panel_axs[jj]) plot_slice( @@ -715,18 +718,16 @@ def plot_mosaic( cmap=cmap, ax=ax, label=f"{slice_val:d}", - spacing=slice_spacing, + spacing=view_spacing[1], swapaxes=swapaxes, - annotate=annotate and views[1] in ("axial", "coronal"), + annotate=axannotation[views[1]] if annotate else None, ) if views[1] is not None and views[2] is not None: - ncols_3 = math.floor((panel_width[0] * ncols) / panel_width[2]) - 1 - slice_spacing = [vs for i, vs in enumerate(zooms) if i != axes_order[2]] - step = max(int(img_data.shape[0] / (ncols_3 + 1)), 1) + step = max(int(img_data.shape[0] / (ncols[2] + 1)), 1) start = step stop = img_data.shape[0] - step - panel_axs = subfigs[2].subgridspec(1, ncols_3, wspace=0.0001) + panel_axs = subfigs[2].subgridspec(1, ncols[2], wspace=0.0001) swapaxes = views in ( ("axial", "coronal", "sagittal"), @@ -734,7 +735,7 @@ def plot_mosaic( ("coronal", "sagittal", "axial"), ) - x_vals = np.linspace(start, stop, num=ncols_3, dtype=int, endpoint=True) + x_vals = np.linspace(start, stop, num=ncols[2], dtype=int, endpoint=True) for jj, slice_val in enumerate(x_vals): ax = fig.add_subplot(panel_axs[jj]) plot_slice( @@ -744,9 +745,9 @@ def plot_mosaic( cmap=cmap, ax=ax, label=f"{slice_val:d}", - spacing=slice_spacing, + spacing=view_spacing[2], swapaxes=swapaxes, - annotate=annotate and views[2] in ("axial", "coronal"), + annotate=axannotation[views[2]] if annotate else None, ) if title: From 284def8ebe6a5f9d1431379477bbbc7984e8bbb0 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Mon, 12 Jun 2023 16:47:02 +0200 Subject: [PATCH 3/4] fix: bad mask comparison --- nireports/reportlets/mosaic.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nireports/reportlets/mosaic.py b/nireports/reportlets/mosaic.py index ab79ec61..93ba339b 100644 --- a/nireports/reportlets/mosaic.py +++ b/nireports/reportlets/mosaic.py @@ -547,7 +547,9 @@ def plot_mosaic( # Create mask for bounding box bbox_data = None if bbox_mask_file is not None: - bbox_data = np.asanyarray(nb.as_closest_canonical(nb.load(bbox_mask_file))) > 1e-3 + bbox_data = np.asanyarray( + nb.as_closest_canonical(nb.load(bbox_mask_file)).dataobj + ) > 1e-3 elif img_data.shape[-1] > (ncols * maxrows): lowthres = np.percentile(img_data, 5) bbox_data = np.ones_like(img_data) From a91765d4b592ceac1209f40c3261380c9081155c Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Tue, 13 Jun 2023 11:20:26 +0200 Subject: [PATCH 4/4] fix: overhaul of image manipulation to prepare plotting --- nireports/interfaces/mosaic.py | 2 +- nireports/reportlets/mosaic.py | 126 ++++++++++++++++----------------- 2 files changed, 62 insertions(+), 66 deletions(-) diff --git a/nireports/interfaces/mosaic.py b/nireports/interfaces/mosaic.py index 76fa34ca..3648e219 100644 --- a/nireports/interfaces/mosaic.py +++ b/nireports/interfaces/mosaic.py @@ -111,7 +111,7 @@ def _run_interface(self, runtime): class _PlotMosaicInputSpec(_PlotBaseInputSpec): bbox_mask_file = File(exists=True, desc="brain mask") - only_noise = traits.Bool(False, desc="plot only noise") + only_noise = traits.Bool(False, usedefault=True, desc="plot only noise") view = traits.List( traits.Enum("axial", "sagittal", "coronal"), value=["axial", "sagittal"], diff --git a/nireports/reportlets/mosaic.py b/nireports/reportlets/mosaic.py index 93ba339b..cfa4b53a 100644 --- a/nireports/reportlets/mosaic.py +++ b/nireports/reportlets/mosaic.py @@ -23,6 +23,7 @@ # STATEMENT OF CHANGES: This file was ported carrying over full git history from # NiPreps projects licensed under the Apache-2.0 terms. """Base components to generate mosaic-like reportlets.""" +from warnings import warn from uuid import uuid4 from os import path as op import math @@ -266,7 +267,6 @@ def plot_slice( vmax=None, vmin=None, annotate=None, - swapaxes=False, ): if isinstance(cmap, (str, bytes)): cmap = get_cmap(cmap) @@ -283,9 +283,10 @@ def plot_slice( if spacing is None: spacing = [1.0, 1.0] - if swapaxes: - dslice = np.swapaxes(dslice, 0, 1) - spacing = (spacing[1], spacing[0]) + # Always swap axes because imshow defines the image as (M, N) + # where M are rows (i.e., Y axis) and N are columns (X axis) + dslice = np.swapaxes(dslice, 0, 1) + spacing = (spacing[1], spacing[0]) ax.imshow( dslice, @@ -505,7 +506,11 @@ def plot_mosaic( ): """Plot a mosaic of 2D cuts.""" - VIEW_AXES_ORDER = (2, 1, 0) + VIEW_AXES_ORDER = { + "axial": (0, 1, 2), + "sagittal": (2, 0, 1), + "coronal": (0, 2, 1), + } if isinstance(views, str): views = (views, None, None) @@ -532,18 +537,9 @@ def plot_mosaic( out_file = "mosaic.svg" if plot_sagittal and views[1] is None and views[0] != "sagittal": + warn("Argument ``plot_sagittal`` for plot_mosaic() should not be used.") views = (views[0], "sagittal", None) - # Select the axis through which we cut the planes - axes_order = [ - ["sagittal", "coronal", "axial"].index(v) - for v in views if v is not None - ] - - # Complete axes - if len(axes_order) < 3: - axes_order += list(set(range(3)) - set(axes_order)) - # Create mask for bounding box bbox_data = None if bbox_mask_file is not None: @@ -567,27 +563,10 @@ def plot_mosaic( nrows = min((shape[-1] + 1) // ncols, maxrows) - swapaxes = views in ( - ("axial", "coronal", None), - ("axial", "coronal", "sagittal"), - ("coronal", "axial", None), - ("coronal", "axial", "sagittal"), - ("sagittal", "axial", "coronal"), - ("sagittal", "axial", None), - - ) - # create figures if fig is None: fig = plt.figure(layout=None) - # Remove extra dimensions - img_data = np.moveaxis( - np.squeeze(img_data), - axes_order, - VIEW_AXES_ORDER[:len(axes_order)], - ) - # Load overlay if present if overlay_mask: overlay_data = nb.as_closest_canonical(nb.load(overlay_mask)).get_fdata() @@ -595,12 +574,6 @@ def plot_mosaic( if bbox_data is not None: overlay_data = _bbox(overlay_data, bbox_data) - overlay_data = np.moveaxis( - overlay_data, - axes_order, - VIEW_AXES_ORDER[:len(axes_order)], - ) - # Decimate if too many values z_vals = np.unique(np.linspace( 0, shape[-1] - 1, num=(ncols * nrows), dtype=int, endpoint=True, @@ -657,6 +630,20 @@ def plot_mosaic( wspace=0.0001, ) + # Mosaic view 1: nrows x ncols (main view) + view_data = np.moveaxis( + np.squeeze(img_data), + (0, 1, 2), + VIEW_AXES_ORDER[views[0]], + ) + + if overlay_mask: + view_overlay_data = np.moveaxis( + overlay_data, + (0, 1, 2), + VIEW_AXES_ORDER[views[0]], + ) + for ii, row_slices in enumerate(main_mosaic_idx): for jj, z_val in enumerate(row_slices): if z_val < 0: @@ -667,14 +654,13 @@ def plot_mosaic( panel_axs[ii, jj].set_rasterized(True) plot_slice( - img_data[:, :, z_val], + view_data[:, :, z_val], vmin=vmin, vmax=vmax, cmap=cmap, ax=ax, spacing=view_spacing[0], label=f"{z_val:d}", - swapaxes=swapaxes, annotate=axannotation[views[0]] if annotate else None, ) @@ -686,69 +672,78 @@ def plot_mosaic( alphas = np.linspace(0, 0.75, msk_cmap.N + 3) msk_cmap._lut[:, -1] = alphas plot_slice( - overlay_data[:, :, z_val], + view_overlay_data[:, :, z_val], vmin=0, vmax=1, cmap=msk_cmap, ax=panel_axs[ii, jj], spacing=view_spacing[0], - swapaxes=swapaxes, ) if views[1] is not None: - step = max(int(img_data.shape[1] / (ncols[1] + 1)), 1) + # Mosaic view 2 + view_data = np.moveaxis( + np.squeeze(img_data), + (0, 1, 2), + VIEW_AXES_ORDER[views[1]], + ) + + if overlay_mask: + view_overlay_data = np.moveaxis( + overlay_data, + (0, 1, 2), + VIEW_AXES_ORDER[views[1]], + ) + step = max(int(view_data.shape[2] / (ncols[1] + 1)), 1) start = step - stop = img_data.shape[1] - step + stop = view_data.shape[2] - step panel_axs = subfigs[1].subgridspec(1, ncols[1], wspace=0.0001) - swapaxes = views in ( - ("axial", "coronal", None), - ("axial", "coronal", "sagittal"), - ("axial", "sagittal", "coronal"), - ("axial", "sagittal", None), - ("coronal", "axial", None), - ("coronal", "axial", "sagittal"), - ) - y_vals = np.linspace(start, stop, num=ncols[1], dtype=int, endpoint=True) for jj, slice_val in enumerate(y_vals): ax = fig.add_subplot(panel_axs[jj]) plot_slice( - img_data[:, slice_val, :], + view_data[:, :, slice_val], vmin=vmin, vmax=vmax, cmap=cmap, ax=ax, label=f"{slice_val:d}", spacing=view_spacing[1], - swapaxes=swapaxes, annotate=axannotation[views[1]] if annotate else None, ) if views[1] is not None and views[2] is not None: - step = max(int(img_data.shape[0] / (ncols[2] + 1)), 1) + # Mosaic view 2 + view_data = np.moveaxis( + np.squeeze(img_data), + (0, 1, 2), + VIEW_AXES_ORDER[views[2]], + ) + + if overlay_mask: + view_overlay_data = np.moveaxis( + overlay_data, + (0, 1, 2), + VIEW_AXES_ORDER[views[2]], + ) + + step = max(int(view_data.shape[2] / (ncols[2] + 1)), 1) start = step - stop = img_data.shape[0] - step + stop = view_data.shape[2] - step panel_axs = subfigs[2].subgridspec(1, ncols[2], wspace=0.0001) - swapaxes = views in ( - ("axial", "coronal", "sagittal"), - ("axial", "sagittal", "coronal"), - ("coronal", "sagittal", "axial"), - ) - x_vals = np.linspace(start, stop, num=ncols[2], dtype=int, endpoint=True) for jj, slice_val in enumerate(x_vals): ax = fig.add_subplot(panel_axs[jj]) plot_slice( - img_data[slice_val, ...], + view_data[:, :, slice_val], vmin=vmin, vmax=vmax, cmap=cmap, ax=ax, label=f"{slice_val:d}", spacing=view_spacing[2], - swapaxes=swapaxes, annotate=axannotation[views[2]] if annotate else None, ) @@ -764,4 +759,5 @@ def plot_mosaic( out_file = op.abspath(fname + "_mosaic.svg") fig.savefig(out_file, format="svg", dpi=300, bbox_inches="tight") + plt.close(fig) return out_file