Skip to content

Commit

Permalink
fix: overhaul of image manipulation to prepare plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed Jun 13, 2023
1 parent 284def8 commit a91765d
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 66 deletions.
2 changes: 1 addition & 1 deletion nireports/interfaces/mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _run_interface(self, runtime):

class _PlotMosaicInputSpec(_PlotBaseInputSpec):
bbox_mask_file = File(exists=True, desc="brain mask")
only_noise = traits.Bool(False, desc="plot only noise")
only_noise = traits.Bool(False, usedefault=True, desc="plot only noise")
view = traits.List(
traits.Enum("axial", "sagittal", "coronal"),
value=["axial", "sagittal"],
Expand Down
126 changes: 61 additions & 65 deletions nireports/reportlets/mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
# STATEMENT OF CHANGES: This file was ported carrying over full git history from
# NiPreps projects licensed under the Apache-2.0 terms.
"""Base components to generate mosaic-like reportlets."""
from warnings import warn
from uuid import uuid4
from os import path as op
import math
Expand Down Expand Up @@ -266,7 +267,6 @@ def plot_slice(
vmax=None,
vmin=None,
annotate=None,
swapaxes=False,
):
if isinstance(cmap, (str, bytes)):
cmap = get_cmap(cmap)
Expand All @@ -283,9 +283,10 @@ def plot_slice(
if spacing is None:
spacing = [1.0, 1.0]

if swapaxes:
dslice = np.swapaxes(dslice, 0, 1)
spacing = (spacing[1], spacing[0])
# Always swap axes because imshow defines the image as (M, N)
# where M are rows (i.e., Y axis) and N are columns (X axis)
dslice = np.swapaxes(dslice, 0, 1)
spacing = (spacing[1], spacing[0])

ax.imshow(
dslice,
Expand Down Expand Up @@ -505,7 +506,11 @@ def plot_mosaic(
):
"""Plot a mosaic of 2D cuts."""

VIEW_AXES_ORDER = (2, 1, 0)
VIEW_AXES_ORDER = {
"axial": (0, 1, 2),
"sagittal": (2, 0, 1),
"coronal": (0, 2, 1),
}

if isinstance(views, str):
views = (views, None, None)
Expand All @@ -532,18 +537,9 @@ def plot_mosaic(
out_file = "mosaic.svg"

if plot_sagittal and views[1] is None and views[0] != "sagittal":
warn("Argument ``plot_sagittal`` for plot_mosaic() should not be used.")
views = (views[0], "sagittal", None)

# Select the axis through which we cut the planes
axes_order = [
["sagittal", "coronal", "axial"].index(v)
for v in views if v is not None
]

# Complete axes
if len(axes_order) < 3:
axes_order += list(set(range(3)) - set(axes_order))

# Create mask for bounding box
bbox_data = None
if bbox_mask_file is not None:
Expand All @@ -567,40 +563,17 @@ def plot_mosaic(

nrows = min((shape[-1] + 1) // ncols, maxrows)

swapaxes = views in (
("axial", "coronal", None),
("axial", "coronal", "sagittal"),
("coronal", "axial", None),
("coronal", "axial", "sagittal"),
("sagittal", "axial", "coronal"),
("sagittal", "axial", None),

)

# create figures
if fig is None:
fig = plt.figure(layout=None)

# Remove extra dimensions
img_data = np.moveaxis(
np.squeeze(img_data),
axes_order,
VIEW_AXES_ORDER[:len(axes_order)],
)

# Load overlay if present
if overlay_mask:
overlay_data = nb.as_closest_canonical(nb.load(overlay_mask)).get_fdata()

if bbox_data is not None:
overlay_data = _bbox(overlay_data, bbox_data)

overlay_data = np.moveaxis(
overlay_data,
axes_order,
VIEW_AXES_ORDER[:len(axes_order)],
)

# Decimate if too many values
z_vals = np.unique(np.linspace(
0, shape[-1] - 1, num=(ncols * nrows), dtype=int, endpoint=True,
Expand Down Expand Up @@ -657,6 +630,20 @@ def plot_mosaic(
wspace=0.0001,
)

# Mosaic view 1: nrows x ncols (main view)
view_data = np.moveaxis(
np.squeeze(img_data),
(0, 1, 2),
VIEW_AXES_ORDER[views[0]],
)

if overlay_mask:
view_overlay_data = np.moveaxis(
overlay_data,
(0, 1, 2),
VIEW_AXES_ORDER[views[0]],
)

for ii, row_slices in enumerate(main_mosaic_idx):
for jj, z_val in enumerate(row_slices):
if z_val < 0:
Expand All @@ -667,14 +654,13 @@ def plot_mosaic(
panel_axs[ii, jj].set_rasterized(True)

plot_slice(
img_data[:, :, z_val],
view_data[:, :, z_val],
vmin=vmin,
vmax=vmax,
cmap=cmap,
ax=ax,
spacing=view_spacing[0],
label=f"{z_val:d}",
swapaxes=swapaxes,
annotate=axannotation[views[0]] if annotate else None,
)

Expand All @@ -686,69 +672,78 @@ def plot_mosaic(
alphas = np.linspace(0, 0.75, msk_cmap.N + 3)
msk_cmap._lut[:, -1] = alphas
plot_slice(
overlay_data[:, :, z_val],
view_overlay_data[:, :, z_val],
vmin=0,
vmax=1,
cmap=msk_cmap,
ax=panel_axs[ii, jj],
spacing=view_spacing[0],
swapaxes=swapaxes,
)

if views[1] is not None:
step = max(int(img_data.shape[1] / (ncols[1] + 1)), 1)
# Mosaic view 2
view_data = np.moveaxis(
np.squeeze(img_data),
(0, 1, 2),
VIEW_AXES_ORDER[views[1]],
)

if overlay_mask:
view_overlay_data = np.moveaxis(
overlay_data,
(0, 1, 2),
VIEW_AXES_ORDER[views[1]],
)
step = max(int(view_data.shape[2] / (ncols[1] + 1)), 1)
start = step
stop = img_data.shape[1] - step
stop = view_data.shape[2] - step
panel_axs = subfigs[1].subgridspec(1, ncols[1], wspace=0.0001)

swapaxes = views in (
("axial", "coronal", None),
("axial", "coronal", "sagittal"),
("axial", "sagittal", "coronal"),
("axial", "sagittal", None),
("coronal", "axial", None),
("coronal", "axial", "sagittal"),
)

y_vals = np.linspace(start, stop, num=ncols[1], dtype=int, endpoint=True)
for jj, slice_val in enumerate(y_vals):
ax = fig.add_subplot(panel_axs[jj])
plot_slice(
img_data[:, slice_val, :],
view_data[:, :, slice_val],
vmin=vmin,
vmax=vmax,
cmap=cmap,
ax=ax,
label=f"{slice_val:d}",
spacing=view_spacing[1],
swapaxes=swapaxes,
annotate=axannotation[views[1]] if annotate else None,
)

if views[1] is not None and views[2] is not None:
step = max(int(img_data.shape[0] / (ncols[2] + 1)), 1)
# Mosaic view 2
view_data = np.moveaxis(
np.squeeze(img_data),
(0, 1, 2),
VIEW_AXES_ORDER[views[2]],
)

if overlay_mask:
view_overlay_data = np.moveaxis(
overlay_data,
(0, 1, 2),
VIEW_AXES_ORDER[views[2]],
)

step = max(int(view_data.shape[2] / (ncols[2] + 1)), 1)
start = step
stop = img_data.shape[0] - step
stop = view_data.shape[2] - step
panel_axs = subfigs[2].subgridspec(1, ncols[2], wspace=0.0001)

swapaxes = views in (
("axial", "coronal", "sagittal"),
("axial", "sagittal", "coronal"),
("coronal", "sagittal", "axial"),
)

x_vals = np.linspace(start, stop, num=ncols[2], dtype=int, endpoint=True)
for jj, slice_val in enumerate(x_vals):
ax = fig.add_subplot(panel_axs[jj])
plot_slice(
img_data[slice_val, ...],
view_data[:, :, slice_val],
vmin=vmin,
vmax=vmax,
cmap=cmap,
ax=ax,
label=f"{slice_val:d}",
spacing=view_spacing[2],
swapaxes=swapaxes,
annotate=axannotation[views[2]] if annotate else None,
)

Expand All @@ -764,4 +759,5 @@ def plot_mosaic(
out_file = op.abspath(fname + "_mosaic.svg")

fig.savefig(out_file, format="svg", dpi=300, bbox_inches="tight")
plt.close(fig)
return out_file

0 comments on commit a91765d

Please sign in to comment.