From 5860c77505024c1009d33789d33ac65144de0461 Mon Sep 17 00:00:00 2001
From: georgeyiasemis
Date: Thu, 21 Sep 2023 00:35:25 +0200
Subject: [PATCH] Add vsharp (2D & 3D), 3D transformers/engine, unet 3D, CMR
 challenge code (dataset, cartesian subsampling), new metrics, transformer
 models

---
 direct/common/subsample.py | 273 +++
 direct/data/datasets.py | 294 +++
 direct/data/datasets_config.py | 17 +-
 direct/data/h5_data.py | 13 +-
 direct/data/mri_transforms.py | 370 +++-
 direct/data/transforms.py | 10 +-
 direct/engine.py | 29 +-
 direct/functionals/__init__.py | 2 +
 direct/functionals/hfen.py | 322 +++
 direct/functionals/psnr.py | 25 +-
 direct/functionals/snr.py | 66 +
 direct/functionals/ssim.py | 62 +-
 direct/nn/conv/conv.py | 267 +++
 direct/nn/get_nn_model_config.py | 66 +-
 direct/nn/lpd/config.py | 14 +
 direct/nn/lpd/lpd.py | 33 +-
 direct/nn/mri_models.py | 353 ++-
 direct/nn/transformers/__init__.py | 4 +
 direct/nn/transformers/config.py | 135 ++
 direct/nn/transformers/transformers.py | 968 ++++++++
 direct/nn/transformers/transformers_engine.py | 221 ++
 direct/nn/transformers/uformer.py | 1961 +++++++++++++++++
 direct/nn/transformers/utils.py | 206 ++
 direct/nn/transformers/vision_transformers.py | 701 ++++++
 direct/nn/types.py | 2 +
 direct/nn/unet/config.py | 23 +
 direct/nn/unet/unet_2d.py | 137 +-
 direct/nn/unet/unet_3d.py | 329 +++
 direct/nn/varsplitnet/config.py | 93 +-
 direct/nn/varsplitnet/varsplitnet.py | 7 +-
 direct/nn/vsharp/__init__.py | 2 +
 direct/nn/vsharp/config.py | 136 ++
 direct/nn/vsharp/vsharp.py | 601 +++++
 direct/nn/vsharp/vsharp_engine.py | 273 +++
 direct/predict.py | 3 +-
 direct/train.py | 4 +-
 direct/types.py | 10 +
 direct/utils/__init__.py | 23 +-
 tests/tests_common/test_subsample.py | 44 +
 tests/tests_data/test_mri_transforms.py | 261 ++-
 tests/tests_functionals/test_hfen.py | 42 +
 tests/tests_functionals/test_snr.py | 36 +
 tests/tests_functionals/test_ssim.py | 32 +-
 tests/tests_nn/test_conjgradnet_engine.py | 1 +
 tests/tests_nn/test_iterdualnet_engine.py | 2 +-
 tests/tests_nn/test_jointicnet_engine.py | 2 +-
 tests/tests_nn/test_kikinet_engine.py | 2 +-
 tests/tests_nn/test_lpd_engine.py | 1 +
 tests/tests_nn/test_multidomainnet_engine.py | 2 +-
 tests/tests_nn/test_recurrentvarnet_engine.py | 1 +
 tests/tests_nn/test_rim_engine.py | 1 +
 tests/tests_nn/test_unet_engine.py | 1 +
 tests/tests_nn/test_varnet_engine.py | 1 +
 tests/tests_nn/test_varsplitnet_engine.py | 1 +
 tests/tests_nn/test_vsharp.py | 130 ++
 tests/tests_nn/test_vsharp_engine.py | 140 ++
 tests/tests_nn/test_xpdnet_engine.py | 1 +
 57 files changed, 8523 insertions(+), 233 deletions(-)
 create mode 100644 direct/functionals/hfen.py
 create mode 100644 direct/functionals/snr.py
 create mode 100644 direct/nn/transformers/__init__.py
 create mode 100644 direct/nn/transformers/config.py
 create mode 100644 direct/nn/transformers/transformers.py
 create mode 100644 direct/nn/transformers/transformers_engine.py
 create mode 100644 direct/nn/transformers/uformer.py
 create mode 100644 direct/nn/transformers/utils.py
 create mode 100644 direct/nn/transformers/vision_transformers.py
 create mode 100644 direct/nn/unet/unet_3d.py
 create mode 100644 direct/nn/vsharp/__init__.py
 create mode 100644 direct/nn/vsharp/config.py
 create mode 100644 direct/nn/vsharp/vsharp.py
 create mode 100644 direct/nn/vsharp/vsharp_engine.py
 create mode 100644 tests/tests_functionals/test_hfen.py
 create mode 100644 tests/tests_functionals/test_snr.py
 create mode 100644 tests/tests_nn/test_vsharp.py
 create mode 100644
tests/tests_nn/test_vsharp_engine.py diff --git a/direct/common/subsample.py b/direct/common/subsample.py index 51144238..d3ea52a7 100644 --- a/direct/common/subsample.py +++ b/direct/common/subsample.py @@ -28,6 +28,9 @@ __all__ = ( "CalgaryCampinasMaskFunc", + "CartesianRandomMaskFunc", + "CartesianEquispacedMaskFunc", + "CartesianMagicMaskFunc", "FastMRIRandomMaskFunc", "FastMRIEquispacedMaskFunc", "FastMRIMagicMaskFunc", @@ -243,6 +246,84 @@ def mask_func( return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) +class CartesianRandomMaskFunc(FastMRIRandomMaskFunc): + r"""Cartesian random vertical line mask function. + + Similar to :class:`FastMRIRandomMaskFunc, but instead of center fraction (`center_fractions`) representing + the fraction of center lines to the original size, here, it represents the actual number of center lines. + """ + + def __init__( + self, + accelerations: Union[List[Number], Tuple[Number, ...]], + center_fractions: Optional[Union[List[int], Tuple[int, ...]]] = None, + uniform_range: bool = False, + ): + """Inits :class:`CartesianRandomMaskFunc`. + + Parameters + ---------- + accelerations: Union[List[Number], Tuple[Number, ...]] + Amount of under-sampling_mask. An acceleration of 4 retains 25% of the k-space, the method is given by + mask_type. Has to be the same length as center_fractions if uniform_range is not True. + center_fractions: list or tuple of ints, optional + Number of low-frequency (center) columns to be retained. + If multiple values are provided, then one of these numbers is chosen uniformly each time. + uniform_range: bool + If True then an acceleration will be uniformly sampled between the two values. Default: True. + """ + super().__init__( + accelerations=accelerations, + center_fractions=center_fractions, + uniform_range=uniform_range, + ) + + def mask_func( + self, + shape: Union[List[int], Tuple[int, ...]], + return_acs: bool = False, + seed: Optional[Union[int, Iterable[int]]] = None, + ) -> torch.Tensor: + """Creates a random vertical Cartesian mask. + + Parameters + ---------- + shape: list or tuple of ints + The shape of the mask to be created. The shape should at least 3 dimensions. + Samples are drawn along the second last dimension. + return_acs: bool + Return the autocalibration signal region as a mask. + seed: int or iterable of ints or None (optional) + Seed for the random number generator. Setting the seed ensures the same mask is generated + each time for the same shape. Default: None. + + + Returns + ------- + mask: torch.Tensor + The sampling mask. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + + num_center_lines, acceleration = self.choose_acceleration() + print(num_center_lines, acceleration) + + mask = self.center_mask_func(num_cols, num_center_lines) + + if return_acs: + return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + + # Create the mask + prob = (num_cols / acceleration - num_center_lines) / (num_cols - num_center_lines) + mask = mask | (self.rng.uniform(size=num_cols) < prob) + + return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + + class FastMRIEquispacedMaskFunc(FastMRIMaskFunc): r"""Equispaced vertical line mask function. @@ -324,6 +405,90 @@ def mask_func( return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) +class CartesianEquispacedMaskFunc(FastMRIEquispacedMaskFunc): + r"""Cartesian equispaced vertical line mask function. 
+ + Similar to :class:`FastMRIEquispacedMaskFunc, but instead of center fraction (`center_fractions`) representing + the fraction of center lines to the original size, here, it represents the actual number of center lines. + """ + + def __init__( + self, + accelerations: Union[List[Number], Tuple[Number, ...]], + center_fractions: Optional[Union[List[int], Tuple[int, ...]]] = None, + uniform_range: bool = False, + ): + """Inits :class:`CartesianEquispacedMaskFunc`. + + Parameters + ---------- + accelerations: Union[List[Number], Tuple[Number, ...]] + Amount of under-sampling_mask. An acceleration of 4 retains 25% of the k-space, the method is given by + mask_type. Has to be the same length as center_fractions if uniform_range is not True. + center_fractions: list or tuple of ints, optional + Number of low-frequency (center) columns to be retained. + If multiple values are provided, then one of these numbers is chosen uniformly each time. + uniform_range: bool + If True then an acceleration will be uniformly sampled between the two values. Default: True. + """ + super().__init__( + accelerations=accelerations, + center_fractions=center_fractions, + uniform_range=uniform_range, + ) + + def mask_func( + self, + shape: Union[List[int], Tuple[int, ...]], + return_acs: bool = False, + seed: Optional[Union[int, Iterable[int]]] = None, + ) -> torch.Tensor: + """Creates an equispaced vertical Cartesian mask. + + Parameters + ---------- + shape: list or tuple of ints + The shape of the mask to be created. The shape should at least 3 dimensions. + Samples are drawn along the second last dimension. + return_acs: bool + Return the autocalibration signal region as a mask. + seed: int or iterable of ints or None (optional) + Seed for the random number generator. Setting the seed ensures the same mask is generated + each time for the same shape. Default: None. + + + Returns + ------- + mask: torch.Tensor + The sampling mask. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + + num_center_lines, acceleration = self.choose_acceleration() + + num_center_lines = int(num_center_lines) + mask = self.center_mask_func(num_cols, num_center_lines) + + if return_acs: + return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_center_lines - num_cols)) / ( + num_center_lines * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + + class FastMRIMagicMaskFunc(FastMRIMaskFunc): """Vertical line mask function as implemented in [1]_. @@ -422,6 +587,114 @@ def mask_func( return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) +class CartesianMagicMaskFunc(FastMRIMagicMaskFunc): + r"""Cartesian equispaced mask function as implemented in [1]_. + + Similar to :class:`FastMRIMagicMaskFunc, but instead of center fraction (`center_fractions`) representing + the fraction of center lines to the original size, here, it represents the actual number of center lines. + + References + ---------- + .. [1] Defazio, Aaron. 
“Offset Sampling Improves Deep Learning Based Accelerated MRI Reconstructions by + Exploiting Symmetry.” ArXiv:1912.01101 [Cs, Eess], Feb. 2020. arXiv.org, http://arxiv.org/abs/1912.01101. + """ + + def __init__( + self, + accelerations: Union[List[Number], Tuple[Number, ...]], + center_fractions: Optional[Union[List[int], Tuple[int, ...]]] = None, + uniform_range: bool = False, + ): + """Inits :class:`CartesianMagicMaskFunc`. + + Parameters + ---------- + accelerations: Union[List[Number], Tuple[Number, ...]] + Amount of under-sampling_mask. An acceleration of 4 retains 25% of the k-space, the method is given by + mask_type. Has to be the same length as center_fractions if uniform_range is not True. + center_fractions: list or tuple of ints, optional + Number of low-frequency (center) columns to be retained. + If multiple values are provided, then one of these numbers is chosen uniformly each time. + uniform_range: bool + If True then an acceleration will be uniformly sampled between the two values. Default: True. + """ + super().__init__( + accelerations=accelerations, + center_fractions=center_fractions, + uniform_range=uniform_range, + ) + + def mask_func( + self, + shape: Union[List[int], Tuple[int, ...]], + return_acs: bool = False, + seed: Optional[Union[int, Iterable[int]]] = None, + ) -> torch.Tensor: + r"""Creates an equispaced Cartesian mask that exploits conjugate symmetry. + + Parameters + ---------- + shape: list or tuple of ints + The shape of the mask to be created. The shape should at least 3 dimensions. + Samples are drawn along the second last dimension. + return_acs: bool + Return the autocalibration signal region as a mask. + seed: int or iterable of ints or None (optional) + Seed for the random number generator. Setting the seed ensures the same mask is generated + each time for the same shape. Default: None. + + Returns + ------- + mask: torch.Tensor + The sampling mask. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + + num_center_lines, acceleration = self.choose_acceleration() + + # bound the number of low frequencies between 1 and target columns + target_cols_to_sample = int(round(num_cols / acceleration)) + num_center_lines = max(min(num_center_lines, target_cols_to_sample), 1) + + acs_mask = self.center_mask_func(num_cols, num_center_lines) + + if return_acs: + return torch.from_numpy(self._reshape_and_broadcast_mask(shape, acs_mask)) + + # adjust acceleration rate based on target acceleration. 
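+            # Worked numeric example (illustrative values, not taken from the patch): with num_cols=512,
+            # acceleration=4 and num_center_lines=24, target_cols_to_sample = int(round(512 / 4)) = 128,
+            # so adjusted_target_cols_to_sample = 128 - 24 = 104 and
+            # adjusted_acceleration = int(round(512 / 104)) = 5, i.e. outside the ACS region roughly
+            # every 5th column is sampled to compensate for the densely sampled center.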
+ adjusted_target_cols_to_sample = target_cols_to_sample - num_center_lines + adjusted_acceleration = 0 + if adjusted_target_cols_to_sample > 0: + adjusted_acceleration = int(round(num_cols / adjusted_target_cols_to_sample)) + + offset = self.rng.randint(0, high=adjusted_acceleration) + + if offset % 2 == 0: + offset_pos = offset + 1 + offset_neg = offset + 2 + else: + offset_pos = offset - 1 + 3 + offset_neg = offset - 1 + 0 + + poslen = (num_cols + 1) // 2 + neglen = num_cols - (num_cols + 1) // 2 + mask_positive = np.zeros(poslen, dtype=bool) + mask_negative = np.zeros(neglen, dtype=bool) + + mask_positive[offset_pos::adjusted_acceleration] = True + mask_negative[offset_neg::adjusted_acceleration] = True + mask_negative = np.flip(mask_negative) + + mask = np.fft.fftshift(np.concatenate((mask_positive, mask_negative))) + mask = np.logical_or(mask, acs_mask) + + return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + + class CalgaryCampinasMaskFunc(BaseMaskFunc): BASE_URL = "https://s3.aiforoncology.nl/direct-project/calgary_campinas_masks/" MASK_MD5S = { diff --git a/direct/data/datasets.py b/direct/data/datasets.py index b8446b66..e6530450 100644 --- a/direct/data/datasets.py +++ b/direct/data/datasets.py @@ -6,11 +6,13 @@ import contextlib import logging import pathlib +import re import sys import xml.etree.ElementTree as etree # nosec from enum import Enum from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import h5py import numpy as np from omegaconf import DictConfig from torch.utils.data import Dataset, IterableDataset @@ -20,6 +22,7 @@ from direct.data.sens import simulate_sensitivity_maps from direct.types import PathOrString from direct.utils import remove_keys, str_to_class +from direct.utils.dataset import get_filenames_for_datasets logger = logging.getLogger(__name__) @@ -395,6 +398,297 @@ def __broadcast_mask(self, kspace_shape, mask): return mask +class CMRxReconDataset(Dataset): + """CMRxRecon Challenge Dataset [1]_. + + References + ---------- + + .. [1] https://cmrxrecon.github.io/Challenge.html + """ + + def __init__( + self, + data_root: pathlib.Path, + transform: Optional[Callable] = None, + filenames_filter: Union[List[PathOrString], None] = None, + filenames_lists: Union[List[PathOrString], None] = None, + filenames_lists_root: Union[PathOrString, None] = None, + regex_filter: Optional[str] = None, + metadata: Optional[Dict[PathOrString, Dict]] = None, + kspace_key: str = "kspace_full", + extra_keys: Optional[Tuple] = None, + pass_attrs: bool = False, + text_description: Optional[str] = None, + compute_mask: bool = False, + kspace_context: Optional[str] = None, + ) -> None: + """Inits :class:`CMRxReconDataset`. + + Parameters + ---------- + data_root : pathlib.Path + Root directory to data. + transform : Callable, optional + A list of transforms to be applied on the generated samples. Default is None. + filenames_filter : Union[List[PathOrString], None] + List of filenames to include in the dataset, should be the same as the ones that can be derived from a glob + on the root. If set, will skip searching for files in the root. Default: None. + filenames_lists : Union[List[PathOrString], None] + List of paths pointing to `.lst` file(s) that contain file-names in `root` to filter. + Should be the same as the ones that can be derived from a glob on the root. If this is set, + this will override the `filenames_filter` option if not None. Defualt: None. 
+ filenames_lists_root : Union[PathOrString, None] + Root of `filenames_lists`. Ignored if `filename_lists` is None. Default: None. + regex_filter : str, optional + Regular expression filter on the absolute filename. Will be applied after any filenames filter. + Default: None. + metadata: dict, optional + If given, this dictionary will be passed to the output transform. Default: None. + kspace_key : str + Key to load the k-space. Typically, `kspace_full` for fully-sampled data, or `kspace_subxx` for + sub-sampled data. Default: `kspace_full`. + extra_keys: Tuple of strings + Add extra keys in h5 file to output. May be used to load sampling masks, e.g. `maskxx`. Default: None. + pass_attrs: bool + Pass the attributes saved in the h5 file. Default: False. + text_description: str + Description of dataset, can be useful for logging. + compute_mask : bool + If True, it will compute the sampling mask from data. This should be typically True at inference, where + data are already undersampled. This will also compute `acs_mask`, which is by default the 24 + center lines. Default: False. + kspace_context : str, optional + Can be either None, `time` or `slice`. If None, data will be loaded per slice or time-frame (2D data). + If `time`, all time frames(phases) per slice will be loaded (3D data). If `slice`, all sliced per time frame + will be loaded (3D data). Default: None. + + """ + self.logger = logging.getLogger(type(self).__name__) + + self.root = pathlib.Path(data_root) + self.filenames_filter = filenames_filter + + self.metadata = metadata + + self.text_description = text_description + + self.kspace_key = kspace_key + + self.data: List[Tuple] = [] + + self.volume_indices: Dict[pathlib.Path, range] = {} + + if kspace_context not in [None, "slice", "time"]: + raise ValueError(f"Attribute `kspace_context` can be None for 2D data or `slice` or `time` for 3D.") + + self.kspace_context = kspace_context + + self.ndim = 2 if self.kspace_context is None else 3 + + # If filenames_filter and filenames_lists are given, it will load files in filenames_filter + # and filenames_lists will be ignored. + if filenames_filter is None: + if filenames_lists is not None: + if filenames_lists_root is None: + e = "`filenames_lists` is passed but `filenames_lists_root` is None." + self.logger.error(e) + raise ValueError(e) + filenames = get_filenames_for_datasets( + lists=filenames_lists, files_root=filenames_lists_root, data_root=data_root + ) + self.logger.info("Attempting to load %s filenames from list(s).", len(filenames)) + else: + self.logger.info("Parsing directory %s for mat files.", self.root) + filenames = list(self.root.glob("*.mat")) + else: + self.logger.info("Attempting to load %s filenames.", len(filenames_filter)) + filenames = filenames_filter + + filenames = [pathlib.Path(_) for _ in filenames] + + if regex_filter: + filenames = [_ for _ in filenames if re.match(regex_filter, str(_))] + + if len(filenames) == 0: + warn = ( + f"Found 0 mat files in directory {self.root}." + if not self.text_description + else f"Found 0 mat files in directory {self.root} for dataset {self.text_description}." + ) + self.logger.warning(warn) + else: + self.logger.info("Using %s mat files in %s.", len(filenames), self.root) + + self.parse_filenames_data(filenames, extra_mats=None) # Collect information on the image masks_dict. 
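+        # After parsing, `self.data` holds one (filename, slice-or-frame index) tuple per dataset sample
+        # and `self.volume_indices` maps each file to the range of dataset indices it occupies
+        # (see `parse_filenames_data` below).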
+ self.pass_attrs = pass_attrs + self.extra_keys = extra_keys + + self.compute_mask = compute_mask + + self.transform = transform + + if self.text_description: + self.logger.info("Dataset description: %s.", self.text_description) + + def parse_filenames_data(self, filenames, extra_mats=None): + current_slice_number = 0 # This is required to keep track of where a volume is in the dataset + + for idx, filename in enumerate(filenames): + if len(filenames) < 5 or idx % (len(filenames) // 5) == 0 or len(filenames) == (idx + 1): + self.logger.info(f"Parsing: {(idx + 1) / len(filenames) * 100:.2f}%.") + try: + if not filename.exists(): + raise OSError(f"{filename} does not exist.") + kspace_shape = h5py.File(filename, "r")[self.kspace_key].shape # pylint: disable = E1101 + self.verify_extra_mat_integrity( + filename, kspace_shape, extra_mats=extra_mats + ) # pylint: disable = E1101 + except Exception as exc: + self.logger.warning("%s failed with Exception: %s. Skipping...", filename, exc) + continue + + if self.kspace_context is None: + num_slices = np.prod(kspace_shape[:2]) + elif self.kspace_context == "slice": + # Slice dimension second + num_slices = kspace_shape[0] + else: + # Time dimension first + num_slices = kspace_shape[1] + + self.data += [(filename, slc) for slc in range(num_slices)] + + self.volume_indices[filename] = range( + current_slice_number, + current_slice_number + num_slices, + ) + + current_slice_number += num_slices + + @staticmethod + def verify_extra_mat_integrity(image_fn, _, extra_mats): + if not extra_mats: + return + + for key in extra_mats: + mat_key, path = extra_mats[key] + extra_fn = path / image_fn.name + try: + with h5py.File(extra_fn, "r") as file: + _ = file[mat_key].shape + except Exception as exc: + raise ValueError(f"Reading of {extra_fn} for key {mat_key} failed: {exc}.") from exc + + def __len__(self): + return len(self.data) + + def get_slice_data(self, filename, slice_no, key, pass_attrs=False, extra_keys=None): + extra_data = {} + if not filename.exists(): + raise OSError(f"{filename} does not exist.") + + try: + data = h5py.File(filename, "r") + except Exception as e: + raise Exception(f"Reading filename {filename} caused exception: {e}") + + shape = data[key].shape + if self.kspace_context is None: + inds = {(i): (k, l) for i, (k, l) in enumerate([(k, l) for k in range(shape[0]) for l in range(shape[1])])} + ind = inds[slice_no] + curr_data = np.array(data[key][ind[0]][ind[1]]) + elif self.kspace_context == "slice": + # Slice dimension + curr_data = np.array(data[key][slice_no]) + else: + # Time dimension + curr_data = np.array(data[key][:, slice_no]) + + if pass_attrs: + extra_data["attrs"] = dict(data.attrs) + + if extra_keys: + for extra_key in self.extra_keys: + if extra_key == "attrs": + raise ValueError("attrs need to be passed by setting `pass_attrs = True`.") + extra_data[extra_key] = data[extra_key][()] + data.close() + return curr_data, extra_data + + def get_num_slices(self, filename): + num_slices = self.volume_indices[filename].stop - self.volume_indices[filename].start + return num_slices + + def __getitem__(self, idx: int) -> Dict[str, Any]: + filename, slice_no = self.data[idx] + filename = pathlib.Path(filename) + metadata = None if not self.metadata else self.metadata[filename.name] + + kspace, extra_data = self.get_slice_data( + filename, slice_no, key=self.kspace_key, pass_attrs=self.pass_attrs, extra_keys=self.extra_keys + ) + + kspace = kspace["real"] + 1j * kspace["imag"] + kspace = np.swapaxes(kspace, -1, -2) + + if 
kspace.ndim == 2: # Singlecoil data. + kspace = kspace[np.newaxis, ...] + + sample = {"kspace": kspace, "filename": str(filename), "slice_no": slice_no} + + if self.compute_mask: + nx, ny = kspace.shape[-2:] + sampling_mask = np.abs(kspace).sum(tuple(range(len(kspace.shape) - 2))) != 0 + assert tuple(sampling_mask.shape) == (nx, ny) + acs_mask = np.zeros((nx, ny), dtype=bool) + acs_mask[:, ny // 2 - 12 : ny // 2 + 12] = True + + sample["sampling_mask"] = sampling_mask[np.newaxis, ..., np.newaxis] + sample["acs_mask"] = acs_mask[np.newaxis, ..., np.newaxis] + + elif any("mask" in key for key in extra_data): + mask_keys = [key for key in extra_data if "mask" in key] + # This will load up randomly a mask if more than one keys + mask_key = np.random.choice(mask_keys) + + sampling_mask = np.array(extra_data[mask_key]).astype(bool) + for key in mask_keys: + del extra_data[key] + + ny, nx = sampling_mask.shape + sampling_mask = np.swapaxes(sampling_mask, -1, -2) + + acs_mask = np.zeros((nx, ny), dtype=bool) + acs_mask[:, ny // 2 - 12 : ny // 2 + 12] = True + + sample["sampling_mask"] = sampling_mask[np.newaxis, ..., np.newaxis] + sample["acs_mask"] = acs_mask[np.newaxis, ..., np.newaxis] + + if self.kspace_context and "sampling_mask" in sample: + sample["sampling_mask"] = sample["sampling_mask"][np.newaxis] + sample["acs_mask"] = sample["acs_mask"][np.newaxis] + + if metadata is not None: + sample["metadata"] = metadata + + sample.update(extra_data) + + shape = kspace.shape + sample["reconstruction_size"] = (int(np.round(shape[-2] / 3)), int(np.round(shape[-1] / 2)), 1) + if self.kspace_context: + # Add context dimension in reconstruction size without any crop + context_size = shape[0] + sample["reconstruction_size"] = (context_size,) + sample["reconstruction_size"] + # If context put coil dim first + sample["kspace"] = np.swapaxes(sample["kspace"], 0, 1) + + if self.transform: + sample = self.transform(sample) + + return sample + + class CalgaryCampinasDataset(H5SliceData): """Calgary-Campinas challenge dataset.""" diff --git a/direct/data/datasets_config.py b/direct/data/datasets_config.py index 2d4e4d32..b69b1955 100644 --- a/direct/data/datasets_config.py +++ b/direct/data/datasets_config.py @@ -38,6 +38,8 @@ class RandomAugmentationTransformsConfig(BaseConfig): random_flip: bool = False random_flip_type: Optional[str] = "random" random_flip_probability: Optional[float] = 0.5 + random_reverse: bool = False + random_reverse_probability: Optional[float] = 0.5 @dataclass @@ -48,7 +50,7 @@ class NormalizationTransformConfig(BaseConfig): @dataclass class TransformsConfig(BaseConfig): - masking: MaskingConfig = MaskingConfig() + masking: Optional[MaskingConfig] = MaskingConfig() cropping: CropTransformConfig = CropTransformConfig() random_augmentations: RandomAugmentationTransformsConfig = RandomAugmentationTransformsConfig() padding_eps: float = 0.001 @@ -82,6 +84,19 @@ class H5SliceConfig(DatasetConfig): filenames_lists_root: Optional[str] = None +@dataclass +class CMRxReconConfig(DatasetConfig): + regex_filter: Optional[str] = None + data_root: Optional[str] = None + filenames_filter: Optional[List[str]] = None + filenames_lists: Optional[List[str]] = None + filenames_lists_root: Optional[str] = None + kspace_key: str = "kspace_full" + compute_mask: bool = False + extra_keys: Optional[List[str]] = None + kspace_context: Optional[str] = None + + @dataclass class FastMRIConfig(H5SliceConfig): pass_attrs: bool = True diff --git a/direct/data/h5_data.py b/direct/data/h5_data.py index 
91306a7e..13e32820 100644 --- a/direct/data/h5_data.py +++ b/direct/data/h5_data.py @@ -129,12 +129,13 @@ def __init__( else: self.logger.info("Using %s h5 files in %s.", len(filenames), self.root) + self.sensitivity_maps = cast_as_path(sensitivity_maps) + self.parse_filenames_data( filenames, extra_h5s=pass_h5s, filter_slice=slice_data ) # Collect information on the image masks_dict. self.pass_h5s = pass_h5s - self.sensitivity_maps = cast_as_path(sensitivity_maps) self.pass_attrs = pass_attrs self.extra_keys = extra_keys self.pass_dictionaries = pass_dictionaries @@ -154,6 +155,12 @@ def parse_filenames_data(self, filenames, extra_h5s=None, filter_slice=None): try: kspace_shape = h5py.File(filename, "r")["kspace"].shape # pylint: disable = E1101 self.verify_extra_h5_integrity(filename, kspace_shape, extra_h5s=extra_h5s) # pylint: disable = E1101 + if self.sensitivity_maps: + _ = h5py.File(self.sensitivity_maps / filename.name, "r") + + except FileNotFoundError as exc: + self.logger.warning("%s sensitivity map not found. Failed with: %s. Skipping...", filename, exc) + continue except OSError as exc: self.logger.warning("%s failed with OSError: %s. Skipping...", filename, exc) @@ -216,7 +223,9 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: # If the sensitivity maps exist, load these if self.sensitivity_maps: - sensitivity_map, _ = self.get_slice_data(self.sensitivity_maps / filename.name, slice_no) + sensitivity_map, _ = self.get_slice_data( + self.sensitivity_maps / filename.name, slice_no, key="sensitivity_map" + ) sample["sensitivity_map"] = sensitivity_map if metadata is not None: diff --git a/direct/data/mri_transforms.py b/direct/data/mri_transforms.py index 6be976d4..42fb3164 100644 --- a/direct/data/mri_transforms.py +++ b/direct/data/mri_transforms.py @@ -12,12 +12,11 @@ import numpy as np import torch -import torchvision from direct.algorithms.mri_algorithms import EspiritCalibration from direct.data import transforms as T from direct.exceptions import ItemNotFoundException -from direct.types import DirectEnum, IntegerListOrTupleString, KspaceKey +from direct.types import DirectEnum, IntegerListOrTupleString, KspaceKey, TransformKey from direct.utils import DirectModule, DirectTransform from direct.utils.asserts import assert_complex @@ -80,7 +79,7 @@ def __init__( self, degrees: Sequence[int] = (-90, 90), p: float = 0.5, - kspace_key: KspaceKey = KspaceKey.kspace, + keys_to_rotate: Tuple[TransformKey, ...] = (TransformKey.kspace,), ): r"""Inits :class:`RandomRotation`. @@ -90,9 +89,9 @@ def __init__( Degrees of rotation. Must be a multiple of 90. If len(degrees) > 1, then a degree will be chosen at random. Default: (-90, 90). p: float - Probability of the backprojected :math:`k`-space being rotated. Default: 0.5 - kspace_key: KspaceKey - Default: KspaceKey.kspace. + Probability of rotation. Default: 0.5 + keys_to_rotate : tuple of TransformKeys + Keys to rotate. Default: "kspace". """ super().__init__() @@ -100,7 +99,7 @@ def __init__( self.degrees = degrees self.p = p - self.kspace_key = kspace_key + self.keys_to_rotate = keys_to_rotate def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: """Calls :class:`RandomRotation`. @@ -113,19 +112,22 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: Returns ------- Dict[str, Any] - Sample with rotated :math:`k`-space. + Sample with rotated values of `keys_to_rotate`. 
""" if random.SystemRandom().random() <= self.p: - kspace = T.view_as_complex(sample[self.kspace_key].clone()) degree = random.SystemRandom().choice(self.degrees) k = degree // 90 - rotated_kspace = torch.rot90(kspace, k=k, dims=(1, 2)) - sample[self.kspace_key] = T.view_as_real(rotated_kspace) + for key in self.keys_to_rotate: + if key in sample: + value = T.view_as_complex(sample[key].clone()) + sample[key] = T.view_as_real(torch.rot90(value, k=k, dims=(-2, -1))) # If rotated by multiples of (n + 1) * 90 degrees, reconstruction size also needs to change reconstruction_size = sample.get("reconstruction_size", None) if reconstruction_size and (k % 2) == 1: - sample["reconstruction_size"] = reconstruction_size[:2][::-1] + reconstruction_size[2:] + sample["reconstruction_size"] = ( + reconstruction_size[:-3] + reconstruction_size[-3:-1][::-1] + reconstruction_size[-1:] + ) return sample @@ -134,6 +136,7 @@ class RandomFlipType(DirectEnum): horizontal = "horizontal" vertical = "vertical" random = "random" + both = "both" class RandomFlip(DirectTransform): @@ -146,7 +149,7 @@ def __init__( self, flip: RandomFlipType = RandomFlipType.random, p: float = 0.5, - kspace_key: KspaceKey = KspaceKey.kspace, + keys_to_flip: Tuple[TransformKey, ...] = (TransformKey.kspace,), ): r"""Inits :class:`RandomFlip`. @@ -155,21 +158,15 @@ def __init__( flip : RandomFlipType Horizontal, vertical, or random choice of the two. Default: RandomFlipType.random. p : float - Probability of the backprojected :math:`k`-space being flipped. Default: 0.5 - kspace_key : KspaceKey - Default: KspaceKey.kspace. + Probability of flip. Default: 0.5 + keys_to_flip : tuple of TransformKeys + Keys to flip. Default: "kspace". """ super().__init__() - if flip == "horizontal": - self.flipper = torchvision.transforms.RandomHorizontalFlip(p=p) - elif flip == "vertical": - self.flipper = torchvision.transforms.RandomVerticalFlip(p=p) - else: - self.flipper = random.SystemRandom().choice( - [torchvision.transforms.RandomHorizontalFlip(p=p), torchvision.transforms.RandomVerticalFlip(p=p)] - ) - self.kpace_key = kspace_key + self.flip = flip + self.p = p + self.keys_to_flip = keys_to_flip def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: """Calls :class:`RandomFlip`. @@ -182,11 +179,84 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: Returns ------- Dict[str, Any] - Sample with flipped :math:`k`-space. + Sample with flipped values of `keys_to_flip`. """ - kspace = T.view_as_complex(sample[self.kpace_key].clone()) - kspace = self.flipper(kspace) - sample[self.kpace_key] = T.view_as_real(kspace) + if random.SystemRandom().random() <= self.p: + dims = ( + (-2,) + if self.flip == "horizontal" + else (-1,) + if self.flip == "vertical" + else (-2, -1) + if self.flip == "both" + else (random.SystemRandom().choice([-2, -1]),) + ) + + for key in self.keys_to_flip: + if key in sample: + value = T.view_as_complex(sample[key].clone()) + value = torch.flip(value, dims=dims) + sample[key] = T.view_as_real(value) + + return sample + + +class RandomReverse(DirectTransform): + r"""Random reverse of the order along a given dimension of a PyTorch tensor.""" + + def __init__( + self, + dim: int = 1, + p: float = 0.5, + keys_to_reverse: Tuple[TransformKey, ...] = (TransformKey.kspace,), + ): + r"""Inits :class:`RandomReverse`. + + Parameters + ---------- + dim : int + Dimension along to perform reversion. Typically, this is for time or slice dimension. Default: 2. + p : float + Probability of flip. 
Default: 0.5 + keys_to_reverse : tuple of TransformKeys + Keys to reverse. Default: "kspace". + """ + super().__init__() + + self.dim = dim + self.p = p + self.keys_to_reverse = keys_to_reverse + + def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Calls :class:`RandomReverse`. + + Parameters + ---------- + sample: Dict[str, Any] + Dict sample. + + Returns + ------- + Dict[str, Any] + Sample with flipped values of `keys_to_flip`. + """ + if random.SystemRandom().random() <= self.p: + dim = self.dim + for key in self.keys_to_reverse: + if key in sample: + tensor = sample[key].clone() + + if dim < 0: + dim += tensor.dim() + + tensor = T.view_as_complex(tensor) + + index = [slice(None)] * tensor.dim() + index[dim] = torch.arange(tensor.size(dim) - 1, -1, -1, dtype=torch.long) + + tensor = tensor[tuple(index)] + + sample[key] = T.view_as_real(tensor) return sample @@ -238,7 +308,7 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: Sample with `sampling_mask` key. """ if not self.shape: - shape = sample["kspace"].shape[1:] + shape = sample["kspace"].shape[-3:] elif any(_ is None for _ in self.shape): # Allow None as values. kspace_shape = list(sample["kspace"].shape[1:-1]) shape = tuple(_ if _ else kspace_shape[idx] for idx, _ in enumerate(self.shape)) + (2,) @@ -249,15 +319,17 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: sampling_mask = self.mask_func(shape=shape, seed=seed, return_acs=False) + if sample["kspace"].ndim == 5: + sampling_mask = sampling_mask.unsqueeze(0) + if "padding" in sample: sampling_mask = T.apply_padding(sampling_mask, sample["padding"]) - # Shape (1, height, width, 1) + # Shape (1, [1], height, width, 1) sample["sampling_mask"] = sampling_mask if self.return_acs: - kspace_shape = sample["kspace"].shape[1:] - sample["acs_mask"] = self.mask_func(shape=kspace_shape, seed=seed, return_acs=True) + sample["acs_mask"] = self.mask_func(shape=shape, seed=seed, return_acs=True) return sample @@ -404,20 +476,28 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: Cropped and masked sample. """ - kspace = sample["kspace"] # shape (coil, height, width, complex=2) + kspace = sample["kspace"] # shape (coil, [slice], height, width, complex=2) - backprojected_kspace = self.backward_operator(kspace, dim=(1, 2)) # shape (coil, height, width, complex=2) + dim = self.spatial_dims["2D"] if kspace.ndim == 4 else self.spatial_dims["3D"] + + backprojected_kspace = self.backward_operator(kspace, dim=dim) # shape (coil, height, width, complex=2) if isinstance(self.crop, IntegerListOrTupleString): crop_shape = IntegerListOrTupleString(self.crop) elif isinstance(self.crop, str): assert self.crop in sample, f"Not found {self.crop} key in sample." 
- crop_shape = sample[self.crop][:2] + crop_shape = sample[self.crop][:-1] else: - crop_shape = self.crop + if kspace.ndim == 5 and len(self.crop) == 2: + crop_shape = (kspace.shape[1],) + tuple(self.crop) + else: + crop_shape = tuple(self.crop) + cropper_data_list = [backprojected_kspace] + if "sensitivity_map" in sample: + cropper_data_list += [sample["sensitivity_map"]] cropper_args = { - "data_list": [backprojected_kspace], + "data_list": cropper_data_list, "crop_shape": crop_shape, "contiguous": False, } @@ -425,11 +505,24 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: cropper_args["seed"] = ( None if not self.random_crop_sampler_use_seed else tuple(map(ord, str(sample["filename"]))) ) - cropped_backprojected_kspace = self.crop_func(**cropper_args) + cropped_output = self.crop_func(**cropper_args) + if "sensitivity_map" in sample: + cropped_backprojected_kspace, sensitivity_map = cropped_output + sample["sensitivity_map"] = sensitivity_map + else: + cropped_backprojected_kspace = cropped_output + + if "sampling_mask" in sample: + sample["sampling_mask"] = T.complex_center_crop( + sample["sampling_mask"], (1,) + tuple(crop_shape)[1:] if kspace.ndim == 5 else crop_shape + ) + sample["acs_mask"] = T.complex_center_crop( + sample["acs_mask"], (1,) + tuple(crop_shape)[1:] if kspace.ndim == 5 else crop_shape + ) # Compute new k-space for the cropped_backprojected_kspace - # shape (coil, new_height, new_width, complex=2) - sample["kspace"] = self.forward_operator(cropped_backprojected_kspace, dim=(1, 2)) # The cropped kspace + # shape (coil, [slice], new_height, new_width, complex=2) + sample["kspace"] = self.forward_operator(cropped_backprojected_kspace, dim=dim) # The cropped kspace return sample @@ -450,7 +543,7 @@ def __init__( self, kspace_key: KspaceKey = KspaceKey.kspace, padding_key: str = "padding", - eps: float = 0.0001, + eps: Union[float, None] = 0.0001, ) -> None: """Inits :class:`ComputeZeroPadding`. @@ -485,10 +578,22 @@ def __call__(self, sample: Dict[str, Any], coil_dim: int = 0) -> Dict[str, Any]: sample : Dict[str, Any] Dict sample containing key `padding_key`. """ + if self.eps is None: + return sample + shape = sample[self.kspace_key].shape - kspace = T.modulus(sample[self.kspace_key]).sum(coil_dim) - padding = (kspace < torch.mean(kspace) * self.eps).to(kspace.device).unsqueeze(coil_dim).unsqueeze(-1) + kspace = T.modulus(sample[self.kspace_key].clone()).sum(coil_dim) + if len(shape) == 5: # Check if 3D data + # Assumes that slice dim is 0 + kspace = kspace.sum(0) + + padding = (kspace < (torch.mean(kspace) * self.eps)).to(kspace.device) + + if len(shape) == 5: + padding = padding.unsqueeze(0) + + padding = padding.unsqueeze(coil_dim).unsqueeze(-1) sample[self.padding_key] = padding return sample @@ -540,6 +645,7 @@ class ReconstructionType(str, Enum): complex_mod = "complex_mod" sense = "sense" sense_mod = "sense_mod" + ifft = "ifft" class ComputeImageModule(DirectModule): @@ -563,7 +669,7 @@ def __init__( backward_operator: callable The backward operator, e.g. some form of inverse FFT (centered or uncentered). type_reconstruction: ReconstructionType - Type of reconstruction. Can be "complex", "complex_mod", "sense", "sense_mod" or "rss". + Type of reconstruction. Can be "complex", "complex_mod", "sense", "sense_mod", "rss" or "ifft". Default: ReconstructionType.rss. 
""" super().__init__() @@ -587,10 +693,12 @@ def forward(self, sample: Dict[str, Any]) -> Dict[str, Any]: "rss", "complex_mod" or "sense_mod", and of shape(\*spatial_dims, complex_dim=2) otherwise. """ kspace_data = sample[self.kspace_key] - + dim = self.spatial_dims["2D"] if kspace_data.ndim == 5 else self.spatial_dims["3D"] # Get complex-valued data solution - image = self.backward_operator(kspace_data, dim=self.spatial_dims) - if self.type_reconstruction in [ + image = self.backward_operator(kspace_data, dim=dim) + if self.type_reconstruction == ReconstructionType.ifft: + sample[self.target_key] = image + elif self.type_reconstruction in [ ReconstructionType.complex, ReconstructionType.complex_mod, ]: @@ -651,14 +759,15 @@ def __call__(self, sample: Dict[str, Any], coil_dim: int = 0) -> Dict[str, Any]: Contains key `"body_coil_image`. """ kspace = sample["kspace"] - # We need to create an ACS mask based on the shape of this kspace, as it can be cropped. + # We need to create an ACS mask based on the shape of this kspace, as it can be cropped. seed = None if not self.use_seed else tuple(map(ord, str(sample["filename"]))) - kspace_shape = sample["kspace"].shape[1:] + kspace_shape = tuple(sample["kspace"].shape[-3:]) acs_mask = self.mask_func(shape=kspace_shape, seed=seed, return_acs=True) - + print(acs_mask.shape) kspace = acs_mask * kspace + 0.0 - acs_image = self.backward_operator(kspace, dim=(1, 2)) + dim = self.spatial_dims["2D"] if kspace.ndim == 4 else self.spatial_dims["3D"] + acs_image = self.backward_operator(kspace, dim=dim) sample["body_coil_image"] = T.root_sum_of_squares(acs_image, dim=coil_dim) return sample @@ -722,7 +831,7 @@ def __init__( super().__init__() self.backward_operator = backward_operator self.kspace_key = kspace_key - if type_of_map not in ["unit", "rss_estimate", "espirit"]: + if type_of_map not in ["unit", "rss_estimate", "espirit", "key_rss_estimate"]: raise ValueError( f"Expected type of map to be either `unit`, `rss_estimate`, `espirit`. Got {type_of_map}." ) @@ -783,8 +892,9 @@ def estimate_acs_image(self, sample: Dict[str, Any], width_dim: int = -2) -> tor kspace_acs = kspace_data * sample["acs_mask"] * gaussian_mask + 0.0 # Get complex-valued data solution - # Shape (batch, coil, height, width, complex=2) - acs_image = self.backward_operator(kspace_acs, dim=(2, 3)) + # Shape (batch, [slice], coil, height, width, complex=2) + dim = self.spatial_dims["2D"] if kspace_data.ndim == 5 else self.spatial_dims["3D"] + acs_image = self.backward_operator(kspace_acs, dim=dim) return acs_image @@ -820,7 +930,19 @@ def forward(self, sample: Dict[str, Any]) -> Dict[str, Any]: acs_image_rss = acs_image_rss.unsqueeze(self.coil_dim).unsqueeze(self.complex_dim) # Shape (batch, coil, height, width, complex=2) sensitivity_map = T.safe_divide(acs_image, acs_image_rss) + elif self.type_of_map == "key_rss_estimate": + dim = self.spatial_dims["2D"] if sample[self.kspace_key].ndim == 5 else self.spatial_dims["3D"] + image = self.backward_operator(sample[self.kspace_key], dim=dim) + image_rss = ( + T.root_sum_of_squares(image, dim=self.coil_dim).unsqueeze(self.coil_dim).unsqueeze(self.complex_dim) + ) + sensitivity_map = T.safe_divide(image, image_rss) else: + if sample[self.kspace_key].ndim > 5: + raise NotImplementedError( + "EstimateSensitivityMapModule is not yet implemented for " + "Espirit sensitivity map estimation for 3D data." 
+ ) sensitivity_map = self.espirit_calibrator(sample) sensitivity_map_norm = torch.sqrt( @@ -832,6 +954,42 @@ def forward(self, sample: Dict[str, Any]) -> Dict[str, Any]: return sample +class AddBooleanKeysModule(DirectModule): + """Adds keys with boolean values to sample.""" + + def __init__(self, keys: List[str], values: List[bool]): + """Inits :class:`AddBooleanKeysModule`. + + Parameters + ---------- + keys : List[str] + A list of keys to be added. + values : List[bool] + A list of values corresponding to the keys. + """ + super().__init__() + self.keys = keys + self.values = values + + def forward(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Adds boolean keys to the input sample dictionary. + + Parameters + ---------- + sample : Dict[str, Any] + The input sample dictionary. + + Returns + ------- + Dict[str, Any] + The modified sample with added boolean keys. + """ + for key, value in zip(self.keys, self.values): + sample[key] = value + + return sample + + class DeleteKeysModule(DirectModule): """Remove keys from the sample if present.""" @@ -979,21 +1137,21 @@ class ComputeScalingFactorModule(DirectModule): def __init__( self, - normalize_key: Union[None, str] = "masked_kspace", + normalize_key: Union[None, TransformKey] = TransformKey.masked_kspace, percentile: Union[None, float] = 0.99, - scaling_factor_key: str = "scaling_factor", + scaling_factor_key: TransformKey = TransformKey.scaling_factor, ): """Inits :class:`ComputeScalingFactorModule`. Parameters ---------- - normalize_key : str or None + normalize_key : TransformKey or None Key name to compute the data for. If the maximum has to be computed on the ACS, ensure the reconstruction on the ACS is available (typically `body_coil_image`). Default: "masked_kspace". percentile : float or None Rescale data with the given percentile. If None, the division is done by the maximum. Default: 0.99. - scaling_factor_key : str - Name of how the scaling factor will be stored. Default: 'scaling_factor'. + scaling_factor_key : TransformKey + Name of how the scaling factor will be stored. Default: "scaling_factor". """ super().__init__() self.normalize_key = normalize_key @@ -1041,14 +1199,14 @@ class NormalizeModule(DirectModule): def __init__( self, - scaling_factor_key: str = "scaling_factor", - keys_to_normalize: Optional[List[str]] = None, + scaling_factor_key: TransformKey = TransformKey.scaling_factor, + keys_to_normalize: Optional[List[TransformKey]] = None, ): """Inits :class:`NormalizeModule`. Parameters ---------- - scaling_factor_key : str + scaling_factor_key : TransformKey Name of scaling factor key expected in sample. Default: 'scaling_factor'. 
""" super().__init__() @@ -1179,6 +1337,8 @@ def __call__(self, sample): for k, v in sample.items(): if isinstance(v, (torch.Tensor, np.ndarray)): sample[k] = v[None] + else: + sample[k] = [v] sample = self._transform.forward(sample) @@ -1186,6 +1346,8 @@ def __call__(self, sample): for k, v in sample.items(): if isinstance(v, (torch.Tensor, np.ndarray)): sample[k] = v.squeeze(0) + else: + sample[k] = v[0] return sample @@ -1250,11 +1412,11 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: sample["sensitivity_map"] = T.to_tensor(sample["sensitivity_map"]).float() if "target" in sample: # Shape: 2D: (coil, height, width), 3D: (coil, slice, height, width) - sample["target"] = sample["target"] + sample["target"] = torch.from_numpy(sample["target"]).float() if "sampling_mask" in sample: - sample["sampling_mask"] = torch.from_numpy(sample["sampling_mask"]).byte() + sample["sampling_mask"] = torch.from_numpy(sample["sampling_mask"]).bool() if "acs_mask" in sample: - sample["acs_mask"] = torch.from_numpy(sample["acs_mask"]) + sample["acs_mask"] = torch.from_numpy(sample["acs_mask"]).bool() if "scaling_factor" in sample: sample["scaling_factor"] = torch.tensor(sample["scaling_factor"]).float() if "loglikelihood_scaling" in sample: @@ -1351,10 +1513,20 @@ def build_pre_mri_transforms( ] if random_rotation: mri_transforms += [ - RandomRotation(degrees=random_rotation_degrees, p=random_rotation_probability, kspace_key=KspaceKey.kspace) + RandomRotation( + degrees=random_rotation_degrees, + p=random_rotation_probability, + keys_to_rotate=(TransformKey.kspace, TransformKey.sensitivity_map), + ) ] if random_flip: - mri_transforms += [RandomFlip(flip=random_flip_type, p=random_flip_probability, kspace_key=KspaceKey.kspace)] + mri_transforms += [ + RandomFlip( + flip=random_flip_type, + p=random_flip_probability, + keys_to_flip=(TransformKey.kspace, TransformKey.sensitivity_map), + ) + ] if mask_func: mri_transforms += [ ComputeZeroPadding(KspaceKey.kspace, "padding", padding_eps), @@ -1385,7 +1557,7 @@ def build_post_mri_transforms( delete_acs_mask: bool = True, delete_kspace: bool = True, image_recon_type: ReconstructionType = ReconstructionType.rss, - scaling_key: KspaceKey = KspaceKey.masked_kspace, + scaling_key: TransformKey = TransformKey.masked_kspace, scale_percentile: Optional[float] = 0.99, ) -> object: """Build transforms for MRI. @@ -1423,8 +1595,8 @@ def build_post_mri_transforms( If True will delete key `kspace` (fully sampled k-space). Default: True. image_recon_type : ReconstructionType Type to reconstruct target image. Default: ReconstructionType.rss. - scaling_key : KspaceKey - Key in sample to scale scalable items in sample. Default: KspaceKey.masked_kspace. + scaling_key : TransformKey + Key in sample to scale scalable items in sample. Default: TransformKey.masked_kspace. scale_percentile : float, optional Data will be rescaled with the given percentile. If None, the division is done by the maximum. Default: 0.99 the same mask every time. Default: True. 
@@ -1471,9 +1643,9 @@ def build_post_mri_transforms( ComputeScalingFactorModule( normalize_key=scaling_key, percentile=scale_percentile, - scaling_factor_key="scaling_factor", + scaling_factor_key=TransformKey.scaling_factor, ), - NormalizeModule(scaling_factor_key="scaling_factor"), + NormalizeModule(scaling_factor_key=TransformKey.scaling_factor), ] if delete_kspace: mri_transforms += [DeleteKeysModule(keys=[KspaceKey.kspace])] @@ -1494,6 +1666,8 @@ def build_mri_transforms( random_flip: bool = False, random_flip_type: Optional[RandomFlipType] = RandomFlipType.random, random_flip_probability: Optional[float] = 0.5, + random_reverse: bool = False, + random_reverse_probability: float = 0.5, padding_eps: float = 0.0001, estimate_body_coil_image: bool = False, estimate_sensitivity_maps: bool = True, @@ -1507,7 +1681,7 @@ def build_mri_transforms( delete_kspace: bool = True, image_recon_type: ReconstructionType = ReconstructionType.rss, pad_coils: Optional[int] = None, - scaling_key: KspaceKey = KspaceKey.masked_kspace, + scaling_key: TransformKey = TransformKey.masked_kspace, scale_percentile: Optional[float] = 0.99, use_seed: bool = True, ) -> object: @@ -1552,6 +1726,10 @@ def build_mri_transforms( Default: RandomFlipType.random. random_flip_probability : float, optional Default: 0.5. + random_reverse : bool + If True will perform random reversion along the time or slice dimension (2). Default: False. + random_reverse_probability : float + Default: 0.5. padding_eps: float Padding epsilon. Default: 0.0001. estimate_body_coil_image : bool @@ -1579,8 +1757,8 @@ def build_mri_transforms( Type to reconstruct target image. Default: ReconstructionType.rss. pad_coils : int Number of coils to pad data to. - scaling_key : KspaceKey - Key in sample to scale scalable items in sample. Default: KspaceKey.masked_kspace. + scaling_key : TransformKey + Key in sample to scale scalable items in sample. Default: TransformKey.masked_kspace. scale_percentile : float, optional Data will be rescaled with the given percentile. If None, the division is done by the maximum. 
Default: 0.99 use_seed : bool @@ -1608,10 +1786,27 @@ def build_mri_transforms( ] if random_rotation: mri_transforms += [ - RandomRotation(degrees=random_rotation_degrees, p=random_rotation_probability, kspace_key=KspaceKey.kspace) + RandomRotation( + degrees=random_rotation_degrees, + p=random_rotation_probability, + keys_to_rotate=(TransformKey.kspace, TransformKey.sensitivity_map), + ) ] if random_flip: - mri_transforms += [RandomFlip(flip=random_flip_type, p=random_flip_probability, kspace_key=KspaceKey.kspace)] + mri_transforms += [ + RandomFlip( + flip=random_flip_type, + p=random_flip_probability, + keys_to_flip=(TransformKey.kspace, TransformKey.sensitivity_map), + ) + ] + if random_reverse: + mri_transforms += [ + RandomReverse( + p=random_reverse_probability, + keys_to_reverse=(TransformKey.kspace, TransformKey.sensitivity_map), + ) + ] if mask_func: mri_transforms += [ ComputeZeroPadding(KspaceKey.kspace, "padding", padding_eps), @@ -1648,12 +1843,6 @@ def build_mri_transforms( mri_transforms += [DeleteKeys(keys=["acs_mask"])] mri_transforms += [ - ComputeImage( - kspace_key=KspaceKey.kspace, - target_key="target", - backward_operator=backward_operator, - type_reconstruction=image_recon_type, - ), ApplyMask( sampling_mask_key="sampling_mask", input_kspace_key=KspaceKey.kspace, @@ -1663,9 +1852,18 @@ def build_mri_transforms( mri_transforms += [ ComputeScalingFactor( - normalize_key=scaling_key, percentile=scale_percentile, scaling_factor_key="scaling_factor" + normalize_key=scaling_key, percentile=scale_percentile, scaling_factor_key=TransformKey.scaling_factor ), - Normalize(scaling_factor_key="scaling_factor"), + Normalize(scaling_factor_key=TransformKey.scaling_factor), + ] + + mri_transforms += [ + ComputeImage( + kspace_key=KspaceKey.kspace, + target_key=TransformKey.target, + backward_operator=backward_operator, + type_reconstruction=image_recon_type, + ) ] if delete_kspace: diff --git a/direct/data/transforms.py b/direct/data/transforms.py index b5d26a6d..b589c5a8 100644 --- a/direct/data/transforms.py +++ b/direct/data/transforms.py @@ -678,13 +678,13 @@ def center_crop(data: torch.Tensor, shape: Union[List[int], Tuple[int, ...]]) -> torch.Tensor: The center cropped data. """ # TODO: Make dimension independent. - if not (0 < shape[0] <= data.shape[-2]) or not (0 < shape[1] <= data.shape[-1]): + if not (0 < shape[-2] <= data.shape[-2]) or not (0 < shape[-1] <= data.shape[-1]): raise ValueError(f"Crop shape should be smaller than data. Requested {shape}, got {data.shape}.") - width_lower = (data.shape[-2] - shape[0]) // 2 - width_upper = width_lower + shape[0] - height_lower = (data.shape[-1] - shape[1]) // 2 - height_upper = height_lower + shape[1] + width_lower = (data.shape[-2] - shape[-2]) // 2 + width_upper = width_lower + shape[-2] + height_lower = (data.shape[-1] - shape[-1]) // 2 + height_upper = height_lower + shape[-1] return data[..., width_lower:width_upper, height_lower:height_upper] diff --git a/direct/engine.py b/direct/engine.py index 68658097..2ac6bd31 100644 --- a/direct/engine.py +++ b/direct/engine.py @@ -167,6 +167,13 @@ def predict( self.ndim = dataset.ndim # type: ignore self.logger.info("Data dimensionality: %s.", self.ndim) + if self.ndim == 3 and batch_size > 1: + self.logger.warning( + f"Batch size for inference of 3D data must be 1. Received {batch_size}." + f"`batch_size` overwritten by 1." 
+ ) + batch_size = 1 + self.checkpointer = Checkpointer( save_directory=experiment_directory, save_to_disk=False, model=self.model, **self.models # type: ignore ) @@ -346,7 +353,11 @@ def training_loop( "This message will only be displayed once." ) parameters = list(filter(lambda p: p.grad is not None, self.model.parameters())) - gradient_norm = sum([parameter.grad.data**2 for parameter in parameters]).sqrt() # type: ignore + gradient_norm = 0.0 + for p in parameters: + param_norm = p.grad.data.norm(2) + gradient_norm += param_norm.item() ** 2 + gradient_norm = gradient_norm ** (1.0 / 2) storage.add_scalar("train/gradient_norm", gradient_norm) # Same as self.__optimizer.step() for mixed precision. @@ -418,9 +429,17 @@ def validation_loop( curr_dataset_name = curr_validation_dataset.text_description self.logger.info("Evaluating: %s...", curr_dataset_name) self.logger.info("Building dataloader for dataset: %s.", curr_dataset_name) + if self.ndim == 3 and self.cfg.validation.batch_size > 1: # type: ignore + self.logger.warning( + f"Batch size for inference of 3D data must be 1. " + f"Received `batch_size` = {self.cfg.validation.batch_size}. Overwriting with 1." # type: ignore + ) # type: ignore + batch_size = 1 + else: + batch_size = self.cfg.validation.batch_size # type: ignore curr_batch_sampler = self.build_batch_sampler( curr_validation_dataset, - batch_size=self.cfg.validation.batch_size, # type: ignore + batch_size=batch_size, sampler_type="sequential", limit_number_of_volumes=None, ) @@ -675,9 +694,9 @@ def log_first_training_example_and_model(self, data): if self.ndim == 3: first_sampling_mask = first_sampling_mask[0] - slice_dim = -4 - num_slices = first_target.shape[slice_dim] - first_target = first_target[num_slices // 2] + num_slices = first_target.shape[0] + first_target = first_target[: num_slices // 2] + first_target = torch.cat([first_target[_] for _ in range(first_target.shape[0])], dim=-1) elif self.ndim > 3: raise NotImplementedError diff --git a/direct/functionals/__init__.py b/direct/functionals/__init__.py index dcd13597..ba789986 100644 --- a/direct/functionals/__init__.py +++ b/direct/functionals/__init__.py @@ -3,7 +3,9 @@ from direct.functionals.challenges import * from direct.functionals.grad import * +from direct.functionals.hfen import * from direct.functionals.nmae import NMAELoss from direct.functionals.nmse import * from direct.functionals.psnr import * +from direct.functionals.snr import SNRLoss, snr from direct.functionals.ssim import * diff --git a/direct/functionals/hfen.py b/direct/functionals/hfen.py new file mode 100644 index 00000000..9348d9e7 --- /dev/null +++ b/direct/functionals/hfen.py @@ -0,0 +1,322 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors +# + +from __future__ import annotations + +import math +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn + +__all__ = ["hfen_l1", "hfen_l2", "HFENLoss", "HFENL1Loss", "HFENL2Loss"] + + +def get_log_kernel2d(kernel_size: int | list[int] = 5, sigma: Optional[float | list[float]] = None) -> torch.Tensor: + """Generates a 2D LoG (Laplacian of Gaussian) kernel. + + Parameters + ---------- + kernel_size : int or list of ints + Size of the kernel. Default: 5. + sigma : float or list of floats + Standard deviation(s) for the Gaussian distribution. Default: None. + + Returns + ------- + torch.Tensor: Generated LoG kernel. 
+ """ + dim = 2 + if not kernel_size and sigma: + kernel_size = np.ceil(sigma * 6) + kernel_size = [kernel_size] * dim + elif kernel_size and not sigma: + sigma = kernel_size / 6.0 + sigma = [sigma] * dim + + if isinstance(kernel_size, int): + kernel_size = [kernel_size - 1] * dim + if isinstance(sigma, float): + sigma = [sigma] * dim + + grids = torch.meshgrid([torch.arange(-size // 2, size // 2 + 1, 1) for size in kernel_size], indexing="ij") + + kernel = 1 + for size, std, mgrid in zip(kernel_size, sigma, grids): + kernel *= torch.exp(-(mgrid**2 / (2.0 * std**2))) + + final_kernel = ( + kernel + * ((grids[0] ** 2 + grids[1] ** 2) - (2 * sigma[0] * sigma[1])) + * (1 / ((2 * math.pi) * (sigma[0] ** 2) * (sigma[1] ** 2))) + ) + + final_kernel = -final_kernel / torch.sum(final_kernel) + + return final_kernel + + +def compute_padding(kernel_size: int | list[int] = 5) -> int | tuple[int, ...]: + """Computes padding tuple based on the kernel size. + + For square kernels, pad can be an int, else, a tuple with an element for each dimension. + + Parameters + ---------- + kernel_size : int or list of ints + Size(s) of the kernel. + + Returns + ------- + int or tuple of ints + Computed padding. + """ + if isinstance(kernel_size, int): + return kernel_size // 2 + elif isinstance(kernel_size, list): + computed = [k // 2 for k in kernel_size] + out_padding = [] + + for i in range(len(kernel_size)): + computed_tmp = computed[-(i + 1)] + if kernel_size[i] % 2 == 0: + padding = computed_tmp - 1 + else: + padding = computed_tmp + out_padding.append(padding) + out_padding.append(computed_tmp) + + return tuple(out_padding) + + +class HFENLoss(nn.Module): + """High Frequency Error Norm (HFEN) Loss as defined in _[1]. + + Calculates: + + .. math:: + + || \text{LoG}(x_\text{rec}) - \text{LoG}(x_\text{tar}) ||_C + + Where C can be any norm, LoG is the Laplacian of Gaussian filter, and :math:`x_\text{rec}), \text{LoG}(x_\text{tar}` + are the reconstructed input and target images. + If normalized it scales it by :math:`|| \text{LoG}(x_\text{tar}) ||_C`. + + Code was borrowed and adapted from _[2]. + + References + ---------- + .. [1] S. Ravishankar and Y. Bresler, "MR Image Reconstruction From Highly Undersampled k-Space Data by + Dictionary Learning," in IEEE Transactions on Medical Imaging, vol. 30, no. + 5, pp. 1028-1041, May 2011, doi: 10.1109/TMI.2010.2090538. + .. [2] https://github.com/styler00dollar/pytorch-loss-functions/blob/main/vic/loss.py + """ + + def __init__( + self, + criterion: nn.Module, + reduction: str = "mean", + kernel_size: int | list[int] = 5, + sigma: float | list[float] = 2.5, + norm: bool = False, + ): + """Inits :class:`HFENLoss`. + + Parameters + ---------- + criterion : nn.Module + Loss function to calculate the difference between log1 and log2. + reduction : str + Criterion reduction. Default: "mean". + kernel_size : int or list of ints + Size of the LoG filter kernel. Default: 15. + sigma : float or list of floats + Standard deviation of the LoG filter kernel. Default: 2.5. + norm : bool + Whether to normalize the loss. + """ + super().__init__() + self.criterion = criterion(reduction=reduction) + self.norm = norm + kernel = get_log_kernel2d(kernel_size, sigma) + self.filter = self._compute_filter(kernel, kernel_size) + + @staticmethod + def _compute_filter(kernel: torch.Tensor, kernel_size: int | list[int] = 15) -> nn.Module: + """Computes the LoG filter based on the kernel and kernel size. + + Parameters + ---------- + kernel : torch.Tensor + The kernel tensor. 
+ kernel_size : int or list of ints, optional + Size of the filter kernel. Default: 15. + + Returns + ------- + nn.Module + The computed filter. + """ + kernel = kernel.expand(1, 1, *kernel.size()).contiguous() + pad = compute_padding(kernel_size) + _filter = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=kernel_size, stride=1, padding=pad, bias=False) + _filter.weight.data = kernel + + return _filter + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Forward pass of the :class:`HFENLoss`. + + Parameters + ---------- + input : torch.Tensor + Input tensor. + target : torch.Tensor + Target tensor. + + Returns + ------- + torch.Tensor + HFEN loss value. + """ + self.filter.to(input.device) + log1 = self.filter(input) + log2 = self.filter(target) + hfen_loss = self.criterion(log1, log2) + if self.norm: + hfen_loss /= self.criterion(torch.zeros_like(target, dtype=target.dtype, device=target.device), target) + return hfen_loss + + +class HFENL1Loss(HFENLoss): + """High Frequency Error Norm (HFEN) Loss using L1Loss criterion. + + Calculates: + + .. math:: + + || \text{LoG}(x_\text{rec}) - \text{LoG}(x_\text{tar}) ||_1 + + Where LoG is the Laplacian of Gaussian filter, and :math:`x_\text{rec}), \text{LoG}(x_\text{tar}` + are the reconstructed input and target images. + If normalized it scales it by :math:`|| \text{LoG}(x_\text{tar}) ||_1`. + """ + + def __init__( + self, + reduction: str = "mean", + kernel_size: int | list[int] = 15, + sigma: float | list[float] = 2.5, + norm: bool = False, + ): + """Inits :class:`HFENL1Loss`. + + Parameters + ---------- + reduction : str + Criterion reduction. Default: "mean". + kernel_size : int or list of ints + Size of the LoG filter kernel. Default: 15. + sigma : float or list of floats + Standard deviation of the LoG filter kernel. Default: 2.5. + norm : bool + Whether to normalize the loss. + """ + super().__init__(nn.L1Loss, reduction, kernel_size, sigma, norm) + + +class HFENL2Loss(HFENLoss): + """High Frequency Error Norm (HFEN) Loss using L1Loss criterion. + + Calculates: + + .. math:: + + || \text{LoG}(x_\text{rec}) - \text{LoG}(x_\text{tar}) ||_2 + + Where LoG is the Laplacian of Gaussian filter, and :math:`x_\text{rec}), \text{LoG}(x_\text{tar}` + are the reconstructed input and target images. + If normalized it scales it by :math:`|| \text{LoG}(x_\text{tar}) ||_2`. + """ + + def __init__( + self, + reduction: str = "mean", + kernel_size: int | list[int] = 15, + sigma: float | list[float] = 2.5, + norm: bool = False, + ): + """Inits :class:`HFENL2Loss`. + + Parameters + ---------- + reduction : str + Criterion reduction. Default: "mean". + kernel_size : int or list of ints + Size of the LoG filter kernel. Default: 15. + sigma : float or list of floats + Standard deviation of the LoG filter kernel. Default: 2.5. + norm : bool + Whether to normalize the loss. + """ + super().__init__(nn.MSELoss, reduction, kernel_size, sigma, norm) + + +def hfen_l1( + input: torch.Tensor, + target: torch.Tensor, + reduction: str = "mean", + kernel_size: int | list[int] = 15, + sigma: float | list[float] = 2.5, + norm: bool = False, +) -> torch.Tensor: + """Calculates HFENL1 loss between input and target. + + Parameters + ---------- + input : torch.Tensor + Input tensor. + target : torch.Tensor + Target tensor. + reduction : str + Criterion reduction. Default: "mean". + kernel_size : int or list of ints + Size of the LoG filter kernel. Default: 15. + sigma : float or list of floats + Standard deviation of the LoG filter kernel. 
Default: 2.5. + norm : bool + Whether to normalize the loss. + """ + hfen_metric = HFENL1Loss(reduction, kernel_size, sigma, norm) + return hfen_metric(input, target) + + +def hfen_l2( + input: torch.Tensor, + target: torch.Tensor, + reduction: str = "mean", + kernel_size: int | list[int] = 15, + sigma: float | list[float] = 2.5, + norm: bool = False, +) -> torch.Tensor: + """Calculates HFENL2 loss between input and target. + + Parameters + ---------- + input : torch.Tensor + Input tensor. + target : torch.Tensor + Target tensor. + reduction : str + Criterion reduction. Default: "mean". + kernel_size : int or list of ints + Size of the LoG filter kernel. Default: 15. + sigma : float or list of floats + Standard deviation of the LoG filter kernel. Default: 2.5. + norm : bool + Whether to normalize the loss. + """ + hfen_metric = HFENL2Loss(reduction, kernel_size, sigma, norm) + return hfen_metric(input, target) diff --git a/direct/functionals/psnr.py b/direct/functionals/psnr.py index 93f199e6..9fdae511 100644 --- a/direct/functionals/psnr.py +++ b/direct/functionals/psnr.py @@ -37,10 +37,29 @@ def batch_psnr(input_data, target_data, reduction="mean"): class PSNRLoss(nn.Module): - __constants__ = ["reduction"] + """PSNR loss PyTorch implementation.""" - def __init__(self, reduction="mean"): + def __init__(self, reduction: str = "mean") -> None: + """Inits :class:`PSNRLoss`. + + Parameters + ---------- + reduction : str + Batch reduction. Default: str. + """ + super().__init__() self.reduction = reduction - def forward(self, input_data, target_data): + def forward(self, input_data: torch.Tensor, target_data: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`PSNRLoss`. + + Parameters + ---------- + input_data : torch.Tensor + target_data : torch.Tensor + + Returns + ------- + torch.Tensor + """ return batch_psnr(input_data, target_data, reduction=self.reduction) diff --git a/direct/functionals/snr.py b/direct/functionals/snr.py new file mode 100644 index 00000000..e6557619 --- /dev/null +++ b/direct/functionals/snr.py @@ -0,0 +1,66 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import torch +import torch.nn as nn + +__all__ = ("snr", "SNRLoss") + + +def snr(input_data: torch.Tensor, target_data: torch.Tensor, reduction: str = "mean") -> torch.Tensor: + """This function is a torch implementation of SNR metric for batches. + + Parameters + ---------- + input_data : torch.Tensor + target_data : torch.Tensor + reduction : str + + Returns + ------- + torch.Tensor + """ + batch_size = target_data.size(0) + input_view = input_data.view(batch_size, -1) + target_view = target_data.view(batch_size, -1) + + square_error = torch.sum(target_view**2, 1) + square_error_noise = torch.sum((input_view - target_view) ** 2, 1) + snr_metric = 10.0 * (torch.log10(square_error) - torch.log10(square_error_noise)) + + if reduction == "mean": + return snr_metric.mean() + if reduction == "sum": + return snr_metric.sum() + if reduction == "none": + return snr_metric + raise ValueError(f"Reduction is either `mean`, `sum` or `none`. Got {reduction}.") + + +class SNRLoss(nn.Module): + """SNR loss function PyTorch implementation.""" + + def __init__(self, reduction: str = "mean") -> None: + """Inits :class:`SNRLoss`. + + Parameters + ---------- + reduction : str + Batch reduction. Default: str. + """ + super().__init__() + self.reduction = reduction + + def forward(self, input_data: torch.Tensor, target_data: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`SNRLoss`. 
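# Minimal usage of the HFEN functionals added above, assuming this patch is installed
# so that they are re-exported from direct.functionals. Shapes are illustrative:
# single-channel magnitude images of shape (batch, 1, height, width).
import torch
from direct.functionals import HFENL1Loss, hfen_l1, hfen_l2

pred = torch.rand(2, 1, 64, 64)
target = torch.rand(2, 1, 64, 64)

loss_value = hfen_l1(pred, target)                                   # functional form
module_value = HFENL1Loss(kernel_size=15, sigma=2.5)(pred, target)   # module form
normalized = hfen_l2(pred, target, norm=True)                        # rescaled by a norm of the target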
+ + Parameters + ---------- + input_data : torch.Tensor + target_data : torch.Tensor + + Returns + ------- + torch.Tensor + """ + return snr(input_data, target_data, reduction=self.reduction) diff --git a/direct/functionals/ssim.py b/direct/functionals/ssim.py index 93208a6c..90a931ab 100644 --- a/direct/functionals/ssim.py +++ b/direct/functionals/ssim.py @@ -11,7 +11,7 @@ import torch.nn as nn import torch.nn.functional as F -__all__ = ("SSIMLoss",) +__all__ = ("SSIMLoss", "SSIM3DLoss") class SSIMLoss(nn.Module): @@ -60,3 +60,63 @@ def forward(self, X, Y, data_range): S = (A1 * A2) / D return 1 - S.mean() + + +class SSIM3DLoss(nn.Module): + """SSIM loss module for 3D data.""" + + def __init__(self, win_size=7, k1=0.01, k2=0.03): + """ + Parameters + ---------- + win_size: int + Window size for SSIM calculation. Default: 7. + k1: float + k1 parameter for SSIM calculation. Default: 0.1. + k2: float + k2 parameter for SSIM calculation. Default: 0.03. + """ + super().__init__() + self.win_size = win_size + self.k1, self.k2 = k1, k2 + + def forward(self, X: torch.Tensor, Y: torch.Tensor, data_range: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`SSIM3Dloss`. + + Parameters + ---------- + X : torch.Tensor + Y : torch.Tensor + data_range : torch.Tensor + """ + data_range = data_range[:, None, None, None, None] + C1 = (self.k1 * data_range) ** 2 + C2 = (self.k2 * data_range) ** 2 + + # window size across last dimension is chosen to be the last dimension size if smaller than given window size + win_size_z = min(self.win_size, X.size(2)) + + NP = win_size_z * self.win_size**2 + w = torch.ones(1, 1, win_size_z, self.win_size, self.win_size, device=X.device) / NP + cov_norm = NP / (NP - 1) + + ux = F.conv3d(X, w) + uy = F.conv3d(Y, w) + uxx = F.conv3d(X * X, w) + uyy = F.conv3d(Y * Y, w) + uxy = F.conv3d(X * Y, w) + + vx = cov_norm * (uxx - ux * ux) + vy = cov_norm * (uyy - uy * uy) + vxy = cov_norm * (uxy - ux * uy) + + A1, A2, B1, B2 = ( + 2 * ux * uy + C1, + 2 * vxy + C2, + ux**2 + uy**2 + C1, + vx + vy + C2, + ) + D = B1 * B2 + S = (A1 * A2) / D + + return 1 - S.mean() diff --git a/direct/nn/conv/conv.py b/direct/nn/conv/conv.py index 8c3ad298..b4b55740 100644 --- a/direct/nn/conv/conv.py +++ b/direct/nn/conv/conv.py @@ -5,6 +5,8 @@ import torch import torch.nn as nn +import torch.nn.functional as F +from torch.nn.parameter import Parameter class Conv2d(nn.Module): @@ -72,3 +74,268 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ out = self.conv(x) return out + + +# Centered Weight Normalization module +class CWNorm(nn.Module): + """Centered Weight Normalization module. + + This module performs centered weight normalization on the weight tensors of Conv2d layers. + """ + + def forward(self, weight: torch.Tensor) -> torch.Tensor: + """Forward pass of the centered weight normalization module. + + Parameters + ---------- + weight : torch.Tensor + The weight tensor of the Conv2d layer. + + Returns + ------- + torch.Tensor + he normalized weight tensor. + """ + weight_ = weight.view(weight.size(0), -1) + weight_mean = weight_.mean(dim=1, keepdim=True) + weight_ = weight_ - weight_mean + norm = weight_.norm(dim=1, keepdim=True) + 1e-5 + weight_CWN = weight_ / norm + return weight_CWN.view(weight.size()) + + +# Custom Conv2d layer with centered weight normalization +class CWN_Conv2d(nn.Conv2d): + """Convolutional layer with Centered Weight Normalization. 
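# Minimal usage of SSIM3DLoss defined above, assuming this patch is installed so it
# is re-exported from direct.functionals. Shapes are illustrative: single-channel
# volumes of shape (batch, 1, slice, height, width), with the data range taken from
# the target maximum.
import torch
from direct.functionals import SSIM3DLoss

pred = torch.rand(1, 1, 12, 64, 64)
target = torch.rand(1, 1, 12, 64, 64)
data_range = torch.tensor([target.max()], device=target.device)

loss_value = SSIM3DLoss(win_size=7)(pred, target, data_range=data_range)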
+ + This layer extends the functionality of the standard Conv2d layer in PyTorch by applying + centered weight normalization to its weight tensors. + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + NScale=1.414, + adjustScale=False, + *args, + **kwargs, + ): + """Inits :class:`CWN_Conv2d`. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : int or tuple + Size of the convolutional kernel. + stride : int or tuple, optional + Stride for the convolution operation. Default: 1. + padding : int or tuple, optional + Padding for the convolution operation. Default: 0. + dilation : int or tuple, optional + Dilation rate for the convolution operation. Default: 1. + groups : int, optional + Number of groups for grouped convolution. Default: 1. + bias : bool, optional + If True, the layer has a bias term. Default: True. + NScale : float, optional + The scale factor for the centered weight normalization. Default: 1.414. + adjustScale : bool, optional + If True, the scale factor is adjusted as a learnable parameter. Default: False. + """ + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, *args, **kwargs + ) + self.weight_normalization = CWNorm() + self.scale_ = torch.ones(out_channels, 1, 1, 1).fill_(NScale) + if adjustScale: + self.WNScale = Parameter(self.scale_) + else: + self.register_buffer("WNScale", self.scale_) + + def forward(self, input_f: torch.Tensor) -> torch.Tensor: + """Forward pass of the CWN_Conv2d layer. + + Parameters + ---------- + input_f : torch.Tensor + The input tensor to the convolutional layer. + + Returns + ------- + torch.Tensor + The output tensor after applying the convolution operation with centered weight normalization. + """ + weight_q = self.weight_normalization(self.weight) + weight_q = weight_q * self.WNScale + out = F.conv2d(input_f, weight_q, self.bias, self.stride, self.padding, self.dilation, self.groups) + return out + + +class CWN_ConvTranspose2d(nn.ConvTranspose2d): + """Transposed Convolutional layer with Centered Weight Normalization. + + This layer extends the functionality of the standard ConvTranspose2d layer in PyTorch by applying + centered weight normalization to its weight tensors. + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + NScale=1.414, + adjustScale=False, + *args, + **kwargs, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + *args, + **kwargs, + ) + """Inits :class:`CWN_ConvTranspose2d`. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : int or tuple + Size of the convolutional kernel. + stride : int or tuple, optional + Stride for the convolution operation. Default: 1. + padding : int or tuple, optional + Padding for the convolution operation. Default: 0. + dilation : int or tuple, optional + Dilation rate for the convolution operation. Default: 1. + groups : int, optional + Number of groups for grouped convolution. Default: 1. + bias : bool, optional + If True, the layer has a bias term. Default: True. + NScale : float, optional + The scale factor for the centered weight normalization. Default: 1.414. 
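# Illustrative check of the centered weight normalization used by the CWN_* layers
# above (import path assumes this patch is applied). CWNorm centers each output
# filter and rescales it to unit norm before the NScale factor is applied.
import torch
from direct.nn.conv.conv import CWN_Conv2d, CWNorm

layer = CWN_Conv2d(in_channels=2, out_channels=4, kernel_size=3, padding=1)
out = layer(torch.randn(1, 2, 16, 16))       # -> shape (1, 4, 16, 16)

flat = CWNorm()(layer.weight).view(4, -1)
print(flat.mean(dim=1))                      # approximately zero per output filter
print(flat.norm(dim=1))                      # approximately one per output filter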
+ adjustScale : bool, optional + If True, the scale factor is adjusted as a learnable parameter. Default: False. + """ + self.weight_normalization = CWNorm() + self.scale_ = torch.ones(in_channels, 1, 1, 1).fill_(NScale) + if adjustScale: + self.WNScale = Parameter(self.scale_) + else: + self.register_buffer("WNScale", self.scale_) + + def forward(self, input_f: torch.Tensor) -> torch.Tensor: + weight_q = self.weight_normalization(self.weight) + weight_q = weight_q * self.WNScale + out = F.conv_transpose2d( + input_f, weight_q, self.bias, self.stride, self.padding, self.output_padding, self.groups, self.dilation + ) + return out + + +class CWN_Conv3d(nn.Conv3d): + """Convolutional layer with Centered Weight Normalization for 3D data.""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + NScale=1.414, + adjustScale=False, + *args, + **kwargs, + ): + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, *args, **kwargs + ) + self.weight_normalization = CWNorm() + self.scale_ = torch.ones(out_channels, 1, 1, 1, 1).fill_(NScale) + if adjustScale: + self.WNScale = Parameter(self.scale_) + else: + self.register_buffer("WNScale", self.scale_) + + def forward(self, input_f: torch.Tensor) -> torch.Tensor: + weight_q = self.weight_normalization(self.weight) + weight_q = weight_q * self.WNScale + out = F.conv3d(input_f, weight_q, self.bias, self.stride, self.padding, self.dilation, self.groups) + return out + + +class CWN_ConvTranspose3d(nn.ConvTranspose3d): + """Transposed Convolutional layer with Centered Weight Normalization for 3D data.""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + NScale=1.414, + adjustScale=False, + *args, + **kwargs, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + *args, + **kwargs, + ) + self.weight_normalization = CWNorm() + self.scale_ = torch.ones(in_channels, 1, 1, 1, 1).fill_(NScale) + if adjustScale: + self.WNScale = Parameter(self.scale_) + else: + self.register_buffer("WNScale", self.scale_) + + def forward(self, input_f: torch.Tensor) -> torch.Tensor: + weight_q = self.weight_normalization(self.weight) + weight_q = weight_q * self.WNScale + out = F.conv_transpose3d( + input_f, weight_q, self.bias, self.stride, self.padding, self.output_padding, self.groups, self.dilation + ) + return out diff --git a/direct/nn/get_nn_model_config.py b/direct/nn/get_nn_model_config.py index f260cfdf..d201fb4e 100644 --- a/direct/nn/get_nn_model_config.py +++ b/direct/nn/get_nn_model_config.py @@ -7,10 +7,22 @@ from direct.nn.conv.conv import Conv2d from direct.nn.didn.didn import DIDN from direct.nn.resnet.resnet import ResNet +from direct.nn.transformers.uformer import AttentionTokenProjectionType, LeWinTransformerMLPTokenType, UFormerModel +from direct.nn.transformers.vision_transformers import VisionTransformerModel from direct.nn.types import ActivationType, ModelName from direct.nn.unet.unet_2d import NormUnetModel2d, UnetModel2d +def _get_activation(activation: ActivationType): + return ( + nn.PReLU() + if activation == ActivationType.prelu + else nn.ReLU() + if activation == ActivationType.relu + else nn.LeakyReLU() + ) + + def _get_model_config( model_architecture_name: ModelName, in_channels: int = COMPLEX_SIZE, out_channels: int = 
COMPLEX_SIZE, **kwargs ) -> nn.Module: @@ -22,6 +34,7 @@ def _get_model_config( "num_filters": kwargs.get("unet_num_filters", 32), "num_pool_layers": kwargs.get("unet_num_pool_layers", 4), "dropout_probability": kwargs.get("unet_dropout", 0.0), + "cwn_conv": kwargs.get("cwn_conv", False), } ) elif model_architecture_name == "resnet": @@ -44,17 +57,60 @@ def _get_model_config( "num_convs_recon": kwargs.get("didn_num_convs_recon", 9), } ) + elif model_architecture_name == "uformer": + model_architecture = UFormerModel + model_kwargs.update( + { + "patch_size": kwargs.get("patch_size", 256), + "embedding_dim": kwargs.get("embedding_dim", 32), + "encoder_depths": kwargs.get("encoder_depths", (2, 2, 2, 2)), + "encoder_num_heads": kwargs.get("encoder_num_heads", (1, 2, 4, 8)), + "bottleneck_depth": kwargs.get("bottleneck_depth", 2), + "bottleneck_num_heads": kwargs.get("bottleneck_num_heads", 16), + "win_size": kwargs.get("win_size", 8), + "mlp_ratio": kwargs.get("mlp_ratio", 4.0), + "qkv_bias": kwargs.get("qkv_bias", True), + "qk_scale": kwargs.get("qk_scale", None), + "drop_rate": kwargs.get("drop_rate", 0.0), + "attn_drop_rate": kwargs.get("attn_drop_rate", 0.0), + "drop_path_rate": kwargs.get("drop_path_rate", 0.1), + "patch_norm": kwargs.get("patch_norm", True), + "token_projection": kwargs.get("token_projection", AttentionTokenProjectionType.linear), + "token_mlp": kwargs.get("token_mlp", LeWinTransformerMLPTokenType.leff), + "shift_flag": kwargs.get("shift_flag", True), + "modulator": kwargs.get("modulator", False), + "cross_modulator": kwargs.get("cross_modulator", False), + "normalized": kwargs.get("normalized", True), + } + ) + elif model_architecture_name == "vision_transformer": + model_architecture = VisionTransformerModel + model_kwargs.update( + { + "average_img_size": kwargs.get("average_img_size", 320), + "patch_size": kwargs.get("patch_size", 10), + "embedding_dim": kwargs.get("embedding_dim", 64), + "depth": kwargs.get("depth", 8), + "num_heads": kwargs.get("num_heads", 9), + "mlp_ratio": kwargs.get("mlp_ratio", 4.0), + "qkv_bias": kwargs.get("qkv_bias", True), + "qk_scale": kwargs.get("qk_scale", None), + "drop_rate": kwargs.get("drop_rate", 0.0), + "normalized": kwargs.get("normalized", True), + "attn_drop_rate": kwargs.get("attn_drop_rate", 0.0), + "dropout_path_rate": kwargs.get("dropout_path_rate", 0.0), + "gpsa_interval": kwargs.get("gpsa_interval", (-1, -1)), + "locality_strength": kwargs.get("locality_strength", 1.0), + "use_pos_embedding": kwargs.get("use_pos_embedding", True), + } + ) else: model_architecture = Conv2d model_kwargs.update( { "hidden_channels": kwargs.get("conv_hidden_channels", 64), "n_convs": kwargs.get("conv_n_convs", 15), - "activation": nn.PReLU() - if kwargs.get("conv_activation", "prelu") == ActivationType.prelu - else nn.ReLU() - if kwargs.get("conv_activation", "relu") == ActivationType.relu - else nn.LeakyReLU(), + "activation": _get_activation(kwargs.get("conv_activation", ActivationType.relu)), "batchnorm": kwargs.get("conv_batchnorm", False), } ) diff --git a/direct/nn/lpd/config.py b/direct/nn/lpd/config.py index 0171c384..4f6a9871 100644 --- a/direct/nn/lpd/config.py +++ b/direct/nn/lpd/config.py @@ -19,6 +19,13 @@ class LPDNetConfig(ModelConfig): primal_unet_num_filters: int = 8 primal_unet_num_pool_layers: int = 4 primal_unet_dropout_probability: float = 0.0 + primal_uformer_patch_size: int = 128 + primal_uformer_embedding_dim: int = 8 + primal_uformer_encoder_depths: list[int, ...] 
= (2, 2, 2) + primal_uformer_encoder_num_heads: list[int, ...] = (2, 4, 8) + primal_uformer_bottleneck_depth: int = 2 + primal_uformer_bottleneck_num_heads: int = 16 + primal_uformer_win_size: int = 8 dual_conv_hidden_channels: int = 16 dual_conv_n_convs: int = 4 dual_conv_batchnorm: bool = False @@ -28,3 +35,10 @@ class LPDNetConfig(ModelConfig): dual_unet_num_filters: int = 8 dual_unet_num_pool_layers: int = 4 dual_unet_dropout_probability: float = 0.0 + dual_uformer_patch_size: int = 256 + dual_uformer_embedding_dim: int = 32 + dual_uformer_encoder_depths: list[int, ...] = (2, 2, 2) + dual_uformer_encoder_num_heads: list[int, ...] = (2, 4, 8) + dual_uformer_bottleneck_depth: int = 2 + dual_uformer_bottleneck_num_heads: int = 16 + dual_uformer_win_size: int = 8 diff --git a/direct/nn/lpd/lpd.py b/direct/nn/lpd/lpd.py index 04683b49..5c5deb3e 100644 --- a/direct/nn/lpd/lpd.py +++ b/direct/nn/lpd/lpd.py @@ -10,6 +10,7 @@ from direct.nn.conv.conv import Conv2d from direct.nn.didn.didn import DIDN from direct.nn.mwcnn.mwcnn import MWCNN +from direct.nn.transformers.uformer import UFormerModel from direct.nn.unet.unet_2d import NormUnetModel2d, UnetModel2d @@ -177,10 +178,23 @@ def __init__( num_pool_layers=kwargs.get("primal_unet_num_pool_layers", 4), dropout_probability=kwargs.get("primal_unet_dropout_probability", 0.0), ) + elif primal_model_architecture == "UFORMER": + uformer = UFormerModel + primal_model = uformer( + in_channels=2 * (num_primal + 1), + out_channels=2 * num_primal, + patch_size=kwargs.get("primal_uformer_patch_size", 64), + win_size=kwargs.get("primal_uformer_win_size", 5), + embedding_dim=kwargs.get("primal_uformer_embedding_dim", 8), + encoder_depths=kwargs.get("primal_uformer_encoder_depths", [2, 2, 2]), + encoder_num_heads=kwargs.get("primal_uformer_encoder_num_heads", [2, 4, 8]), + bottleneck_depth=kwargs.get("primal_uformer_bottleneck_depth", 2), + bottleneck_num_heads=kwargs.get("primal_uformer_bottleneck_num_heads", 16), + ) else: raise NotImplementedError( - f"XPDNet is currently implemented only with primal_model_architecture == 'MWCNN', 'UNET' or 'NORMUNET." - f"Got {primal_model_architecture}." + f"XPDNet is currently implemented only with primal_model_architecture == 'MWCNN', 'UNET', 'NORMUNET " + f"or 'UFORMER'. Got {primal_model_architecture}." ) dual_model: nn.Module if dual_model_architecture == "CONV": @@ -208,10 +222,23 @@ def __init__( num_pool_layers=kwargs.get("dual_unet_num_pool_layers", 4), dropout_probability=kwargs.get("dual_unet_dropout_probability", 0.0), ) + elif dual_model_architecture == "UFORMER": + uformer = UFormerModel + dual_model = uformer( + in_channels=2 * (num_dual + 2), + out_channels=2 * num_dual, + patch_size=kwargs.get("dual_uformer_patch_size", 64), + win_size=kwargs.get("dual_uformer_win_size", 5), + embedding_dim=kwargs.get("dual_uformer_embedding_dim", 8), + encoder_depths=kwargs.get("dual_uformer_encoder_depths", [2, 2, 2]), + encoder_num_heads=kwargs.get("dual_uformer_encoder_num_heads", [2, 4, 8]), + bottleneck_depth=kwargs.get("dual_uformer_bottleneck_depth", 2), + bottleneck_num_heads=kwargs.get("dual_uformer_bottleneck_num_heads", 16), + ) else: raise NotImplementedError( f"XPDNet is currently implemented for dual_model_architecture == 'CONV', 'DIDN'," - f" 'UNET' or 'NORMUNET'. Got dual_model_architecture == {dual_model_architecture}." + f" 'UNET', 'NORMUNET' or 'UFORMER'. Got dual_model_architecture == {dual_model_architecture}." 
) self._coil_dim = 1 diff --git a/direct/nn/mri_models.py b/direct/nn/mri_models.py index 37dae3f4..064aad78 100644 --- a/direct/nn/mri_models.py +++ b/direct/nn/mri_models.py @@ -20,7 +20,19 @@ import direct.data.transforms as T from direct.config import BaseConfig from direct.engine import DoIterationOutput, Engine -from direct.functionals import NMAELoss, NMSELoss, NRMSELoss, SobelGradL1Loss, SobelGradL2Loss, SSIMLoss +from direct.functionals import ( + HFENL1Loss, + HFENL2Loss, + NMAELoss, + NMSELoss, + NRMSELoss, + PSNRLoss, + SNRLoss, + SobelGradL1Loss, + SobelGradL2Loss, + SSIM3DLoss, + SSIMLoss, +) from direct.types import TensorOrNone from direct.utils import ( communication, @@ -317,9 +329,9 @@ def ssim_loss( Parameters ---------- source: torch.Tensor - Source tensor of shape (batch, height, width, [complex=2]). + Source tensor of shape (batch, [slice], height, width, [complex=2]). target: torch.Tensor - Target tensor of shape (batch, height, width, [complex=2]). + Target tensor of shape (batch, [slice], height, width, [complex=2]). reduction: str Reduction type. Can be "sum" or "mean". reconstruction_size: Optional[Tuple] @@ -335,7 +347,8 @@ def ssim_loss( raise AssertionError( f"SSIM loss can only be computed with reduction == 'mean'." f" Got reduction == {reduction}." ) - + if self.ndim == 3: + source, target = _reduce_slice_dim(source, target) source_abs, target_abs = _crop_volume(source, target, resolution) data_range = torch.tensor([target_abs.max()], device=target_abs.device) @@ -343,6 +356,44 @@ def ssim_loss( return ssim_loss + def ssim_3d_loss( + source: torch.Tensor, + target: torch.Tensor, + reduction: str = "mean", + reconstruction_size: Optional[Tuple] = None, + ) -> torch.Tensor: + """Calculate SSIM3D loss given source image and target image. + + Parameters + ---------- + source: torch.Tensor + Source tensor of shape (batch, slice, height, width, [complex=2]). + target: torch.Tensor + Target tensor of shape (batch, slice, height, width, [complex=2]). + reduction: str + Reduction type. Can be "sum" or "mean". + reconstruction_size: Optional[Tuple] + Reconstruction size to center crop. Default: None. + + Returns + ------- + ssim_loss: torch.Tensor + SSIM loss. + """ + resolution = get_resolution(reconstruction_size) + if reduction != "mean": + raise AssertionError( + f"SSIM loss can only be computed with reduction == 'mean'." f" Got reduction == {reduction}." + ) + if self.ndim != 3: + raise AssertionError(f"SSIM3D loss is only implemented for 3D data.") + source_abs, target_abs = _crop_volume(source, target, resolution) + data_range = torch.tensor([target_abs.max()], device=target_abs.device) + + ssim_loss = SSIM3DLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range) + + return ssim_loss + def grad_l1_loss( source: torch.Tensor, target: torch.Tensor, @@ -354,9 +405,9 @@ def grad_l1_loss( Parameters ---------- source: torch.Tensor - Source tensor of shape (batch, height, width, [complex=2]). + Source tensor of shape (batch, [slice], height, width, [complex=2]). target: torch.Tensor - Target tensor of shape (batch, height, width, [complex=2]). + Target tensor of shape (batch, [slice], height, width, [complex=2]). reduction: str Reduction type. Can be "sum" or "mean". reconstruction_size: Optional[Tuple] @@ -368,6 +419,8 @@ def grad_l1_loss( Sobel grad L1 loss. 
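# Sketch of how the SSIM term above is assembled: magnitude images are center-cropped
# to the reconstruction resolution and the data range is taken from the target maximum.
# T.center_crop and SSIMLoss are existing DIRECT components; shapes are illustrative.
import torch
import direct.data.transforms as T
from direct.functionals import SSIMLoss

source_abs = torch.rand(2, 1, 320, 320)
target_abs = torch.rand(2, 1, 320, 320)
resolution = (256, 256)

source_abs = T.center_crop(source_abs, resolution)
target_abs = T.center_crop(target_abs, resolution)
data_range = torch.tensor([target_abs.max()], device=target_abs.device)
ssim_term = SSIMLoss()(source_abs, target_abs, data_range=data_range)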
""" resolution = get_resolution(reconstruction_size) + if self.ndim == 3: + source, target = _reduce_slice_dim(source, target) source_abs, target_abs = _crop_volume(source, target, resolution) grad_l1_loss = SobelGradL1Loss(reduction).to(source_abs.device).forward(source_abs, target_abs) @@ -384,9 +437,9 @@ def grad_l2_loss( Parameters ---------- source: torch.Tensor - Source tensor of shape (batch, height, width, [complex=2]). + Source tensor of shape (batch, [slice], height, width, [complex=2]). target: torch.Tensor - Target tensor of shape (batch, height, width, [complex=2]). + Target tensor of shape (batch, [slice], height, width, [complex=2]). reduction: str Reduction type. Can be "sum" or "mean". reconstruction_size: Optional[Tuple] @@ -398,11 +451,201 @@ def grad_l2_loss( Sobel grad L1 loss. """ resolution = get_resolution(reconstruction_size) + if self.ndim == 3: + source, target = _reduce_slice_dim(source, target) source_abs, target_abs = _crop_volume(source, target, resolution) grad_l2_loss = SobelGradL2Loss(reduction).to(source_abs.device).forward(source_abs, target_abs) return grad_l2_loss + def psnr_loss( + source: torch.Tensor, + target: torch.Tensor, + reduction: str = "mean", + reconstruction_size: Optional[Tuple] = None, + ) -> torch.Tensor: + """Calculate peak signal-to-noise ratio loss given source image and target image. + + Parameters + ---------- + source: torch.Tensor + Source tensor of shape (batch, [slice], height, width, [complex=2]). + target: torch.Tensor + Target tensor of shape (batch, [slice], height, width, [complex=2]). + reduction: str + Reduction type. Can be "sum" or "mean". + reconstruction_size: Optional[Tuple] + Reconstruction size to center crop. Default: None. + + Returns + ------- + psnr_loss: torch.Tensor + PSNR loss. + """ + resolution = get_resolution(reconstruction_size) + if self.ndim == 3: + source, target = _reduce_slice_dim(source, target) + source_abs, target_abs = _crop_volume(source, target, resolution) + psnr_loss = -PSNRLoss(reduction).to(source_abs.device).forward(source_abs, target_abs) + + return psnr_loss + + def snr_loss( + source: torch.Tensor, + target: torch.Tensor, + reduction: str = "mean", + reconstruction_size: Optional[Tuple] = None, + ) -> torch.Tensor: + """Calculate signal-to-noise loss given source image and target image. + + Parameters + ---------- + source: torch.Tensor + Source tensor of shape (batch, [slice], height, width, [complex=2]). + target: torch.Tensor + Target tensor of shape (batch, [slice], height, width, [complex=2]). + reduction: str + Reduction type. Can be "sum" or "mean". + reconstruction_size: Optional[Tuple] + Reconstruction size to center crop. Default: None. + + Returns + ------- + snr_loss: torch.Tensor + SNR loss. + """ + resolution = get_resolution(reconstruction_size) + if self.ndim == 3: + source, target = _reduce_slice_dim(source, target) + source_abs, target_abs = _crop_volume(source, target, resolution) + snr_loss = -SNRLoss(reduction).to(source_abs.device).forward(source_abs, target_abs) + + return snr_loss + + def hfen_l1_loss( + source: torch.Tensor, + target: torch.Tensor, + reduction: str = "mean", + reconstruction_size: Optional[Tuple] = None, + ) -> torch.Tensor: + """Calculate normalized HFEN L1 loss given source image and target image. + + Parameters + ---------- + source: torch.Tensor + Source tensor of shape (batch, [slice], height, width, [complex=2]). + target: torch.Tensor + Target tensor of shape (batch, [slice], height, width, [complex=2]). 
+ reduction: str + Reduction type. Can be "sum" or "mean". + reconstruction_size: Optional[Tuple] + Reconstruction size to center crop. Default: None. + + Returns + ------- + torch.Tensor + HFEN l1 loss. + """ + resolution = get_resolution(reconstruction_size) + if self.ndim == 3: + source, target = _reduce_slice_dim(source, target) + source_abs, target_abs = _crop_volume(source, target, resolution) + + return HFENL1Loss(reduction=reduction, norm=False).to(source_abs.device).forward(source_abs, target_abs) + + def hfen_l2_loss( + source: torch.Tensor, + target: torch.Tensor, + reduction: str = "mean", + reconstruction_size: Optional[Tuple] = None, + ) -> torch.Tensor: + """Calculate normalized HFEN L2 loss given source image and target image. + + Parameters + ---------- + source: torch.Tensor + Source tensor of shape (batch, [slice], height, width, [complex=2]). + target: torch.Tensor + Target tensor of shape (batch, [slice], height, width, [complex=2]). + reduction: str + Reduction type. Can be "sum" or "mean". + reconstruction_size: Optional[Tuple] + Reconstruction size to center crop. Default: None. + + Returns + ------- + torch.Tensor + HFEN l2 loss. + """ + resolution = get_resolution(reconstruction_size) + if self.ndim == 3: + source, target = _reduce_slice_dim(source, target) + source_abs, target_abs = _crop_volume(source, target, resolution) + + return HFENL2Loss(reduction=reduction, norm=False).to(source_abs.device).forward(source_abs, target_abs) + + def hfen_l1_norm_loss( + source: torch.Tensor, + target: torch.Tensor, + reduction: str = "mean", + reconstruction_size: Optional[Tuple] = None, + ) -> torch.Tensor: + """Calculate normalized HFEN L1 loss given source image and target image. + + Parameters + ---------- + source: torch.Tensor + Source tensor of shape (batch, [slice], height, width, [complex=2]). + target: torch.Tensor + Target tensor of shape (batch, [slice], height, width, [complex=2]). + reduction: str + Reduction type. Can be "sum" or "mean". + reconstruction_size: Optional[Tuple] + Reconstruction size to center crop. Default: None. + + Returns + ------- + torch.Tensor + Normalized HFEN l1 loss. + """ + resolution = get_resolution(reconstruction_size) + if self.ndim == 3: + source, target = _reduce_slice_dim(source, target) + source_abs, target_abs = _crop_volume(source, target, resolution) + + return HFENL1Loss(reduction=reduction, norm=True).to(source_abs.device).forward(source_abs, target_abs) + + def hfen_l2_norm_loss( + source: torch.Tensor, + target: torch.Tensor, + reduction: str = "mean", + reconstruction_size: Optional[Tuple] = None, + ) -> torch.Tensor: + """Calculate normalized HFEN L2 loss given source image and target image. + + Parameters + ---------- + source: torch.Tensor + Source tensor of shape (batch, [slice], height, width, [complex=2]). + target: torch.Tensor + Target tensor of shape (batch, [slice], height, width, [complex=2]). + reduction: str + Reduction type. Can be "sum" or "mean". + reconstruction_size: Optional[Tuple] + Reconstruction size to center crop. Default: None. + + Returns + ------- + torch.Tensor + Normalized HFEN l2 loss. 
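# The PSNR and SNR terms above are quality metrics (higher is better), so the engine
# negates them to obtain minimizable losses. Minimal illustration with random
# magnitude images; the imports assume this patch is applied.
import torch
from direct.functionals import PSNRLoss, SNRLoss

pred = torch.rand(2, 1, 32, 32)
target = torch.rand(2, 1, 32, 32)

psnr_term = -PSNRLoss(reduction="mean")(pred, target)  # used as "psnr_loss"
snr_term = -SNRLoss(reduction="mean")(pred, target)    # used as "snr_loss"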
+ """ + resolution = get_resolution(reconstruction_size) + if self.ndim == 3: + source, target = _reduce_slice_dim(source, target) + source_abs, target_abs = _crop_volume(source, target, resolution) + + return HFENL2Loss(reduction=reduction, norm=True).to(source_abs.device).forward(source_abs, target_abs) + # Build losses loss_dict = {} for curr_loss in self.cfg.training.loss.losses: # type: ignore @@ -413,6 +656,8 @@ def grad_l2_loss( loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss) elif loss_fn == "ssim_loss": loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss) + elif loss_fn == "ssim_3d_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_3d_loss) elif loss_fn == "grad_l1_loss": loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, grad_l1_loss) elif loss_fn == "grad_l2_loss": @@ -423,6 +668,18 @@ def grad_l2_loss( loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, nrmse_loss) elif loss_fn in ["nmae_loss", "kspace_nmae_loss"]: loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, nmae_loss) + elif loss_fn in ["snr_loss", "psnr_loss"]: + loss_dict[loss_fn] = multiply_function( + curr_loss.multiplier, (snr_loss if loss_fn == "snr" else psnr_loss) + ) + elif loss_fn == "hfen_l1_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, hfen_l1_loss) + elif loss_fn == "hfen_l2_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, hfen_l2_loss) + elif loss_fn == "hfen_l1_norm_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, hfen_l1_norm_loss) + elif loss_fn == "hfen_l2_norm_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, hfen_l2_norm_loss) else: raise ValueError(f"{loss_fn} not permissible.") @@ -446,16 +703,32 @@ def compute_sensitivity_map(self, sensitivity_map: torch.Tensor) -> torch.Tensor sensitivity_map: torch.Tensor Normalized and refined sensitivity maps of shape (batch, coil, height, width, complex=2). """ - # Some things can be done with the sensitivity map here, e.g. 
apply a u-net - if "sensitivity_model" in self.models: + + multicoil = sensitivity_map.shape[self._coil_dim] > 1 + + # Pass to sensitivity model only if multiple coils + if multicoil and ("sensitivity_model" in self.models or "sensitivity_model_3d" in self.models): # Move channels to first axis sensitivity_map = sensitivity_map.permute( - (0, 1, 4, 2, 3) + (0, 1, 4, 2, 3) if self.ndim == 2 else (0, 1, 5, 2, 3, 4) ) # shape (batch, coil, complex=2, height, width) - sensitivity_map = self.compute_model_per_coil("sensitivity_model", sensitivity_map).permute( - (0, 1, 3, 4, 2) - ) # has channel last: shape (batch, coil, height, width, complex=2) + if self.ndim == 2: + sensitivity_map = self.compute_model_per_coil("sensitivity_model", sensitivity_map) + else: + if "sensitivity_model_3d" in self.models: + sensitivity_map = self.compute_model_per_coil("sensitivity_model_3d", sensitivity_map) + else: + sensitivity_map = torch.stack( + [ + self.compute_model_per_coil("sensitivity_model", sensitivity_map[:, :, :, _]) + for _ in range(sensitivity_map.shape[3]) + ], + dim=3, + ) + sensitivity_map = sensitivity_map.permute( + (0, 1, 3, 4, 2) if self.ndim == 2 else (0, 1, 3, 4, 5, 2) + ) # has channel last: shape (batch, coil, [slice], height, width, complex=2) # The sensitivity map needs to be normalized such that # So \sum_{i \in \text{coils}} S_i S_i^* = 1 @@ -633,9 +906,21 @@ def evaluate( # type: ignore ) ): volume, target, volume_loss_dict, filename = output + if self.ndim == 3: + # Put slice and time data together + sc, c, z, x, y = volume.shape + volume_for_eval = volume.clone().transpose(1, 2).reshape(sc * z, c, x, y) + target_for_eval = target.clone().transpose(1, 2).reshape(sc * z, c, x, y) + else: + volume_for_eval = volume.clone() + target_for_eval = target.clone() + curr_metrics = { - metric_name: metric_fn(target, volume).clone() for metric_name, metric_fn in volume_metrics.items() + metric_name: metric_fn(target_for_eval, volume_for_eval).clone() + for metric_name, metric_fn in volume_metrics.items() } + del volume_for_eval, target_for_eval + curr_metrics_string = ", ".join([f"{x}: {float(y)}" for x, y in curr_metrics.items()]) self.logger.info("Metrics for %s: %s", filename, curr_metrics_string) # TODO: Path can be tricky if it is not unique (e.g. 
image.h5) @@ -644,6 +929,10 @@ def evaluate( # type: ignore # Log the center slice of the volume if len(visualize_slices) < self.cfg.logging.tensorboard.num_images: # type: ignore + if self.ndim == 3: + # If 3D data get every third slice + volume = torch.cat([volume[:, :, _] for _ in range(0, z, 3)], dim=-1) + target = torch.cat([target[:, :, _] for _ in range(0, z, 3)], dim=-1) visualize_slices.append(volume[volume.shape[0] // 2]) visualize_target.append(target[target.shape[0] // 2]) @@ -687,6 +976,7 @@ def compute_loss_on_data( data: Dict[str, Any], output_image: Optional[torch.Tensor] = None, output_kspace: Optional[torch.Tensor] = None, + weight: float = 1.0, ) -> Dict[str, torch.Tensor]: if output_image is None and output_kspace is None: raise ValueError("Inputs for `output_image` and `output_kspace` cannot be both None.") @@ -695,7 +985,7 @@ def compute_loss_on_data( if output_kspace is not None: output, target, reconstruction_size = output_kspace, data["kspace"], None else: - raise ValueError(f"Requested to compute `{key}` loss but received None for `output_kspace`.") + continue else: if output_image is not None: output, target, reconstruction_size = ( @@ -704,8 +994,8 @@ def compute_loss_on_data( data.get("reconstruction_size", None), ) else: - raise ValueError(f"Requested to compute `{key}` loss but received None for `output_image`.") - loss_dict[key] = value + loss_fns[key](output, target, "mean", reconstruction_size) + continue + loss_dict[key] = value + weight * loss_fns[key](output, target, "mean", reconstruction_size) return loss_dict def _forward_operator(self, image, sensitivity_map, sampling_mask): @@ -754,6 +1044,31 @@ def _crop_volume( return source_abs, target_abs +def _reduce_slice_dim(source: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """This will combine batch and slice dims, for source and target tensors. + + Batch and slice dimensions are assumed to be on first and second axes: `b, c = source.shape[:2]`. + + Parameters + ---------- + source: torch.Tensor + Has shape (batch, slice, *). + target: torch.Tensor + Has shape (batch, slice, *). + + Returns + ------- + (torch.Tensor, torch.Tensor) + Have shape (batch * slice, *). + """ + assert source.shape == target.shape + shape = source.shape + b, s = shape[:2] + source = source.reshape(b * s, *shape[2:]) + target = target.reshape(b * s, *shape[2:]) + return source, target + + def _process_output( data: torch.Tensor, scaling_factors: Optional[torch.Tensor] = None, @@ -782,7 +1097,7 @@ def _process_output( data = T.modulus_if_complex(data, complex_axis=complex_axis) - if len(data.shape) == 3: # (batch, height, width) + if len(data.shape) in [3, 4]: # (batch, height, width) data = data.unsqueeze(1) # Added channel dimension. 
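# Sketch of the slice flattening performed by _reduce_slice_dim above: batch and
# slice axes are merged so that 2D criteria can be applied slice-wise to 3D data.
# Shapes are illustrative.
import torch

source = torch.rand(2, 5, 64, 64)   # (batch, slice, height, width)
target = torch.rand(2, 5, 64, 64)

b, s = source.shape[:2]
source_2d = source.reshape(b * s, *source.shape[2:])   # -> (10, 64, 64)
target_2d = target.reshape(b * s, *target.shape[2:])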
if resolution is not None: diff --git a/direct/nn/transformers/__init__.py b/direct/nn/transformers/__init__.py new file mode 100644 index 00000000..c9fdc05d --- /dev/null +++ b/direct/nn/transformers/__init__.py @@ -0,0 +1,4 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +"""Transformers models direct module.""" diff --git a/direct/nn/transformers/config.py b/direct/nn/transformers/config.py new file mode 100644 index 00000000..daf6efdd --- /dev/null +++ b/direct/nn/transformers/config.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from dataclasses import dataclass +from typing import Optional, Tuple + +from direct.config.defaults import ModelConfig + + +@dataclass +class UFormerModelConfig(ModelConfig): + in_channels: int = 2 + out_channels: Optional[int] = None + patch_size: int = 128 + embedding_dim: int = 16 + encoder_depths: tuple[int, ...] = (2, 2, 2) + encoder_num_heads: tuple[int, ...] = (1, 2, 4) + bottleneck_depth: int = 2 + bottleneck_num_heads: int = 8 + win_size: int = 8 + mlp_ratio: float = 4.0 + qkv_bias: bool = True + qk_scale: Optional[float] = None + drop_rate: float = 0.0 + attn_drop_rate: float = 0.0 + drop_path_rate: float = 0.1 + patch_norm: bool = True + token_projection: str = "linear" + token_mlp: str = "leff" + shift_flag: bool = True + modulator: bool = False + cross_modulator: bool = False + normalized: bool = True + + +@dataclass +class VariationalUFormerConfig(ModelConfig): + num_steps: int = 5 + patch_size: int = 128 + embedding_dim: int = 16 + encoder_depths: tuple[int, ...] = (2, 2, 2) + encoder_num_heads: tuple[int, ...] = (1, 2, 4) + bottleneck_depth: int = 2 + bottleneck_num_heads: int = 8 + win_size: int = 8 + mlp_ratio: float = 4.0 + qkv_bias: bool = True + qk_scale: Optional[float] = None + drop_rate: float = 0.0 + attn_drop_rate: float = 0.0 + drop_path_rate: float = 0.1 + patch_norm: bool = True + token_projection: str = "linear" + token_mlp: str = "leff" + shift_flag: bool = True + modulator: bool = False + cross_modulator: bool = False + no_weight_sharing: bool = True + + +@dataclass +class MRIUFormerConfig(ModelConfig): + patch_size: int = 128 + embedding_dim: int = 16 + encoder_depths: tuple[int, ...] = (2, 2, 2, 2) + encoder_num_heads: tuple[int, ...] = (1, 2, 4, 8) + bottleneck_depth: int = 2 + bottleneck_num_heads: int = 16 + win_size: int = 8 + mlp_ratio: float = 4.0 + qkv_bias: bool = True + qk_scale: Optional[float] = None + drop_rate: float = 0.0 + attn_drop_rate: float = 0.0 + drop_path_rate: float = 0.1 + patch_norm: bool = True + token_projection: str = "linear" + token_mlp: str = "leff" + shift_flag: bool = True + modulator: bool = False + cross_modulator: bool = False + + +@dataclass +class ImageDomainUFormerConfig(MRIUFormerConfig): + pass + + +@dataclass +class KSpaceDomainUFormerConfig(MRIUFormerConfig): + multicoil_input_mode: str = "sense_sum" + patch_size: int = 64 + embedding_dim: int = 16 + encoder_depths: tuple[int, ...] = (2, 2, 2) + encoder_num_heads: tuple[int, ...] 
= (1, 2, 4) + bottleneck_depth: int = 2 + bottleneck_num_heads: int = 8 + + +@dataclass +class MRITransformerConfig(ModelConfig): + num_gradient_descent_steps: int = 8 + average_img_size: int = 320 + patch_size: int = 10 + embedding_dim: int = 64 + depth: int = 8 + num_heads: int = 9 + mlp_ratio: float = 4.0 + qkv_bias: bool = False + qk_scale: Optional[float] = None + drop_rate: float = 0.0 + attn_drop_rate: float = 0.0 + dropout_path_rate: float = 0.0 + gpsa_interval: Tuple[int, int] = (-1, -1) + locality_strength: float = 1.0 + use_pos_embedding: bool = True + + +@dataclass +class ImageDomainVisionTransformerConfig(ModelConfig): + use_mask: bool = True + average_img_size: int = 320 + patch_size: int = 10 + embedding_dim: int = 64 + depth: int = 8 + num_heads: int = 9 + mlp_ratio: float = 4.0 + qkv_bias: bool = False + qk_scale: Optional[float] = None + drop_rate: float = 0.0 + attn_drop_rate: float = 0.0 + dropout_path_rate: float = 0.0 + gpsa_interval: Tuple[int, int] = (-1, -1) + locality_strength: float = 1.0 + use_pos_embedding: bool = True diff --git a/direct/nn/transformers/transformers.py b/direct/nn/transformers/transformers.py new file mode 100644 index 00000000..574f5b25 --- /dev/null +++ b/direct/nn/transformers/transformers.py @@ -0,0 +1,968 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from __future__ import annotations + +from typing import Callable, Optional + +import torch +import torch.nn as nn + +from direct.data.transforms import apply_mask, apply_padding, expand_operator, reduce_operator +from direct.nn.transformers.uformer import * +from direct.nn.transformers.utils import norm, pad, pad_to_square, unnorm, unpad +from direct.nn.transformers.vision_transformers import VisionTransformer +from direct.types import DirectEnum + +__all__ = [ + "MRITransformer", + "ImageDomainVisionTransformer", + "ImageDomainUFormer", + "KSpaceDomainUFormerMultiCoilInputMode", + "KSpaceDomainUFormer", + "VariationalUFormer", +] + + +class VariationalUFormerBlock(nn.Module): + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + patch_size: int = 256, + embedding_dim: int = 32, + encoder_depths: tuple[int, ...] = (2, 2, 2, 2), + encoder_num_heads: tuple[int, ...] 
= (1, 2, 4, 8), + bottleneck_depth: int = 2, + bottleneck_num_heads: int = 16, + win_size: int = 8, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.1, + patch_norm: bool = True, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.linear, + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.leff, + shift_flag: bool = True, + modulator: bool = False, + cross_modulator: bool = False, + normalized: bool = True, + ): + super().__init__() + self.forward_operator = forward_operator + self.backward_operator = backward_operator + + self.uformer = UFormerModel( + patch_size=patch_size, + in_channels=2, + out_channels=2, + embedding_dim=embedding_dim, + encoder_depths=encoder_depths, + encoder_num_heads=encoder_num_heads, + bottleneck_depth=bottleneck_depth, + bottleneck_num_heads=bottleneck_num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + patch_norm=patch_norm, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + normalized=normalized, + ) + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + def forward( + self, + current_kspace: torch.Tensor, + masked_kspace: torch.Tensor, + sampling_mask: torch.Tensor, + sensitivity_map: torch.Tensor, + learning_rate: torch.Tensor, + ) -> torch.Tensor: + """Performs the forward pass of :class:`VariationalUFormerBlock`. + + Parameters + ---------- + current_kspace: torch.Tensor + Current k-space prediction of shape (N, coil, height, width, complex=2). + masked_kspace: torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sampling_mask: torch.Tensor + Sampling mask of shape (N, 1, height, width, 1). + sensitivity_map: torch.Tensor + Sensitivity map of shape (N, coil, height, width, complex=2). + learning_rate : torch.Tensor + (Trainable) Learning rate parameter of shape (1,). + + Returns + ------- + torch.Tensor + Next k-space prediction of shape (N, coil, height, width, complex=2). + """ + kspace_error = apply_mask(current_kspace - masked_kspace, sampling_mask, return_mask=False) + + regularization_term = reduce_operator( + self.backward_operator(current_kspace, dim=self._spatial_dims), sensitivity_map, dim=self._coil_dim + ).permute(0, 3, 1, 2) + + regularization_term = self.uformer(regularization_term).permute(0, 2, 3, 1) + + regularization_term = self.forward_operator( + expand_operator(regularization_term, sensitivity_map, dim=self._coil_dim), dim=self._spatial_dims + ) + + return current_kspace - learning_rate * kspace_error + regularization_term + + +class VariationalUFormer(nn.Module): + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + num_steps: int = 8, + no_weight_sharing: bool = True, + patch_size: int = 256, + embedding_dim: int = 32, + encoder_depths: tuple[int, ...] = (2, 2, 2, 2), + encoder_num_heads: tuple[int, ...] 
= (1, 2, 4, 8), + bottleneck_depth: int = 2, + bottleneck_num_heads: int = 16, + win_size: int = 8, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.1, + patch_norm: bool = True, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.linear, + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.leff, + shift_flag: bool = True, + modulator: bool = False, + cross_modulator: bool = False, + **kwargs, + ): + super().__init__() + self.forward_operator = forward_operator + self.backward_operator = backward_operator + + self.blocks = nn.ModuleList( + [ + VariationalUFormerBlock( + forward_operator=forward_operator, + backward_operator=backward_operator, + patch_size=patch_size, + embedding_dim=embedding_dim, + encoder_depths=encoder_depths, + encoder_num_heads=encoder_num_heads, + bottleneck_depth=bottleneck_depth, + bottleneck_num_heads=bottleneck_num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + patch_norm=patch_norm, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + for _ in range((num_steps if no_weight_sharing else 1)) + ] + ) + self.lr = nn.Parameter(torch.tensor([1.0] * num_steps)) + self.num_steps = num_steps + self.no_weight_sharing = no_weight_sharing + + self.padding_factor = win_size * (2 ** len(encoder_depths)) + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + def forward( + self, + masked_kspace: torch.Tensor, + sensitivity_map: torch.Tensor, + sampling_mask: torch.Tensor, + padding: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Performs the forward pass of :class:`VariationalUFormer`. + + Parameters + ---------- + masked_kspace : torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sampling_mask : torch.Tensor + Sampling mask of shape (N, 1, height, width, 1). + sensitivity_map : torch.Tensor + Sensitivity map of shape (N, coil, height, width, complex=2). + padding : torch.Tensor, optional + Padding of shape (N, 1, height, width, 1). Default: None. + + Returns + ------- + torch.Tensor + k-space prediction of shape (N, coil, height, width, complex=2). + """ + kspace_prediction = masked_kspace.clone() + for step_idx in range(self.num_steps): + kspace_prediction = self.blocks[step_idx if self.no_weight_sharing else 0]( + kspace_prediction, masked_kspace, sampling_mask, sensitivity_map, self.lr[step_idx] + ) + kspace_prediction = masked_kspace + apply_mask(kspace_prediction, ~sampling_mask, return_mask=False) + if padding is not None: + kspace_prediction = apply_padding(kspace_prediction, padding) + return kspace_prediction + + +class MRIUFormer(nn.Module): + """A PyTorch module that implements MRI image reconstruction using an image domain UFormer.""" + + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + patch_size: int = 256, + embedding_dim: int = 32, + encoder_depths: tuple[int, ...] = (2, 2, 2, 2), + encoder_num_heads: tuple[int, ...] 
= (1, 2, 4, 8),
+        bottleneck_depth: int = 2,
+        bottleneck_num_heads: int = 16,
+        win_size: int = 8,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = True,
+        qk_scale: Optional[float] = None,
+        drop_rate: float = 0.0,
+        attn_drop_rate: float = 0.0,
+        drop_path_rate: float = 0.1,
+        patch_norm: bool = True,
+        token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.linear,
+        token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.leff,
+        shift_flag: bool = True,
+        modulator: bool = False,
+        cross_modulator: bool = False,
+        **kwargs,
+    ):
+        """Inits :class:`MRIUFormer`.
+
+        Parameters
+        ----------
+        forward_operator: Callable
+            Forward Operator.
+        backward_operator: Callable
+            Backward Operator.
+        patch_size : int
+            Size of the patch. Default: 256.
+        embedding_dim : int
+            Size of the feature embedding. Default: 32.
+        encoder_depths : tuple
+            Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2).
+        encoder_num_heads : tuple
+            Number of attention heads for each layer of the encoder of the U-former, from top to bottom.
+            Default: (1, 2, 4, 8).
+        bottleneck_depth : int
+            Number of layers in the bottleneck of the U-former. Default: 2.
+        bottleneck_num_heads : int
+            Number of attention heads in the bottleneck of the U-former. Default: 16.
+        win_size : int
+            Window size for the attention mechanism. Default: 8.
+        mlp_ratio : float
+            Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0.
+        qkv_bias : bool
+            Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True.
+        qk_scale : float
+            Scale factor for the query and key projection vectors.
+            If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None.
+        drop_rate : float
+            Dropout rate for the token-level dropout layer. Default: 0.0.
+        attn_drop_rate : float
+            Dropout rate for the attention score matrix. Default: 0.0.
+        drop_path_rate : float
+            Dropout rate for the stochastic depth regularization. Default: 0.1.
+        patch_norm : bool
+            Whether to use normalization for the patch embeddings. Default: True.
+        token_projection : AttentionTokenProjectionType
+            Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.linear.
+        token_mlp : LeWinTransformerMLPTokenType
+            Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.leff.
+        shift_flag : bool
+            Whether to use shift operation in the local attention mechanism. Default: True.
+        modulator : bool
+            Whether to use a modulator in the attention mechanism. Default: False.
+        cross_modulator : bool
+            Whether to use cross-modulation in the attention mechanism. Default: False.
+        **kwargs: Other keyword arguments to pass to the parent constructor.
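Each encoder stage of the U-Former halves the spatial resolution, so inputs are padded to a multiple of win_size * 2**len(encoder_depths) (the padding_factor set just after this docstring; the actual padding is done by pad_to_square from direct/nn/transformers/utils.py, which is not shown in this hunk). A minimal arithmetic sketch of that constraint, assuming only the Python standard library and a hypothetical 218 x 180 input:

    import math

    # Defaults from the signature above: four encoder stages, 8 x 8 attention windows.
    win_size = 8
    encoder_depths = (2, 2, 2, 2)

    # Same expression as self.padding_factor below.
    padding_factor = win_size * (2 ** len(encoder_depths))  # 8 * 2**4 = 128

    # A hypothetical 218 x 180 image is padded up to the next multiple of 128 per axis.
    height, width = 218, 180
    padded = tuple(math.ceil(s / padding_factor) * padding_factor for s in (height, width))
    print(padding_factor, padded)  # 128 (256, 256)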
+ """ + super().__init__() + self.forward_operator = forward_operator + self.backward_operator = backward_operator + self.uformer = UFormer( + patch_size=patch_size, + in_channels=2, + out_channels=2, + embedding_dim=embedding_dim, + encoder_depths=encoder_depths, + encoder_num_heads=encoder_num_heads, + bottleneck_depth=bottleneck_depth, + bottleneck_num_heads=bottleneck_num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + patch_norm=patch_norm, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + self.padding_factor = win_size * (2 ** len(encoder_depths)) + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + +class KSpaceDomainUFormerMultiCoilInputMode(DirectEnum): + sense_sum = "sense_sum" + compute_per_coil = "compute_per_coil" + + +class KSpaceDomainUFormer(MRIUFormer): + """A PyTorch module that implements MRI image reconstruction using a k-space domain UFormer.""" + + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + multicoil_input_mode: KSpaceDomainUFormerMultiCoilInputMode = KSpaceDomainUFormerMultiCoilInputMode.sense_sum, + patch_size: int = 128, + embedding_dim: int = 16, + encoder_depths: tuple[int, ...] = (2, 2, 2), + encoder_num_heads: tuple[int, ...] = (1, 2, 4), + bottleneck_depth: int = 2, + bottleneck_num_heads: int = 8, + win_size: int = 8, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.1, + patch_norm: bool = True, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.linear, + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.leff, + shift_flag: bool = True, + modulator: bool = False, + cross_modulator: bool = False, + **kwargs, + ): + """Inits :class:`KSpaceDomainUFormer`. + + Parameters + ---------- + forward_operator: Callable + Forward Operator. + backward_operator: Callable + Backward Operator. + multicoil_input_mode: KSpaceDomainUFormerMultiCoilInputMode + Set to "sense_sum" to aggregate all coil data, or "compute_per_coil" to pass each coil data in + a different pass to the same model. Default: KSpaceDomainUFormerMultiCoilInputMode.sense_sum. + patch_size : int + Size of the patch. Default: 128. + embedding_dim : int + Size of the feature embedding. Default: 16. + encoder_depths : tuple + Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2). + encoder_num_heads : tuple + Number of attention heads for each layer of the encoder of the U-former, from top to bottom. + Default: (1, 2, 4). + bottleneck_depth : int + Default: 8. + bottleneck_num_heads : int + Default: 2. + win_size : int + Window size for the attention mechanism. Default: 8. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float + Scale factor for the query and key projection vectors. + If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None. + drop_rate : float + Dropout rate for the token-level dropout layer. Default: 0.0. 
+ attn_drop_rate : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path_rate : float + Dropout rate for the stochastic depth regularization. Default: 0.1. + patch_norm : bool + Whether to use normalization for the patch embeddings. Default: True. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.linear. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.leff. + shift_flag : bool + Whether to use shift operation in the local attention mechanism. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + **kwargs: Other keyword arguments to pass to the parent constructor. + """ + super().__init__( + forward_operator=forward_operator, + backward_operator=backward_operator, + patch_size=patch_size, + in_channels=2, + out_channels=2, + embedding_dim=embedding_dim, + encoder_depths=encoder_depths, + encoder_num_heads=encoder_num_heads, + bottleneck_depth=bottleneck_depth, + bottleneck_num_heads=bottleneck_num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + patch_norm=patch_norm, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + self.multicoil_input_mode = multicoil_input_mode + + def forward( + self, + masked_kspace: torch.Tensor, + sensitivity_map: torch.Tensor, + sampling_mask: Optional[torch.Tensor] = None, + padding: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r"""Performs forward pass of :class:`KSpaceDomainUFormer`. + + Parameters + ---------- + masked_kspace : torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sensitivity_map : torch.Tensor + Coil sensitivities of shape (N, coil, height, width, complex=2). + sampling_mask : torch.Tensor, optional + Sampling mask of shape (N, 1, height, width, 1). If not None, it will use + :math:`y_{inp} + (1-M) * f_\theta(y_{inp})` as output, else :math:`f_\theta(y_{inp})`. + padding : torch.Tensor, optional + Zero-padding (i.e. not sampled locations) that may be present in original k-space + of shape (N, 1, height, width, 1). If not None, padding will be applied to output. + + Returns + ------- + out : torch.Tensor + Prediction of output image of shape (N, height, width, complex=2). 
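The sampling_mask behaviour described above is a plain data-consistency step: measured k-space entries are kept and unmeasured ones are filled from the network output (the code below this docstring does it with apply_mask and the inverted mask). A self-contained sketch of that step alone, with random tensors standing in for real data:

    import torch

    N, coil, H, W = 1, 4, 64, 64
    masked_kspace = torch.randn(N, coil, H, W, 2)        # y_inp (zero outside the mask in practice)
    prediction = torch.randn(N, coil, H, W, 2)           # stand-in for f_theta(y_inp)
    sampling_mask = torch.rand(N, 1, H, W, 1) < 0.25     # True where k-space was measured

    # y_inp + (1 - M) * f_theta(y_inp): keep measured entries, take the prediction elsewhere.
    out = masked_kspace + (~sampling_mask) * prediction

    # Measured locations are untouched.
    keep = sampling_mask.expand_as(out)
    assert torch.equal(out[keep], masked_kspace[keep])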
+ """ + + # Pad to square in image domain + inp = self.backward_operator(masked_kspace, dim=self._spatial_dims).permute(0, 1, 4, 2, 3) + inp, _, wpad, hpad = pad_to_square(inp, self.padding_factor) + padded_sensitivity_map, _, _, _ = pad_to_square(sensitivity_map.permute(0, 1, 4, 2, 3), self.padding_factor) + padded_sensitivity_map = padded_sensitivity_map.permute(0, 1, 3, 4, 2) + + # Project back to k-space + inp = self.forward_operator(inp.permute(0, 1, 3, 4, 2).contiguous(), dim=self._spatial_dims) + if self.multicoil_input_mode == "sense_sum": + # Construct SENSE reconstruction + # \sum_{k=1}^{n_c} S^k * \mathcal{F}^{-1} (y^k) + inp = reduce_operator( + coil_data=self.backward_operator(inp, dim=self._spatial_dims), + sensitivity_map=padded_sensitivity_map, + dim=self._coil_dim, + ) + # Project the SENSE reconstruction to k-space domain and use as input to model + inp = self.forward_operator(inp, dim=[d - 1 for d in self._spatial_dims]) + inp = inp.permute(0, 3, 1, 2) + + inp, mean, std = norm(inp) + + out = self.uformer(inp) + + out = unnorm(out, mean, std) + + # Project k-space to image domain and unpad + out = self.backward_operator(out.permute(0, 2, 3, 1), dim=[d - 1 for d in self._spatial_dims]) + else: + # Pass each coil k-space to model + out = [] + for coil_idx in range(masked_kspace.shape[self._coil_dim]): + coil_data = inp[:, coil_idx].permute(0, 3, 1, 2) + + coil_data, mean, std = norm(coil_data) + + coil_data = self.uformer(coil_data) + + coil_data = unnorm(coil_data, mean, std).permute(0, 2, 3, 1) + + out.append(coil_data) + out = torch.stack(out, dim=self._coil_dim) + + out = reduce_operator( + coil_data=self.backward_operator(out, dim=self._spatial_dims), + sensitivity_map=padded_sensitivity_map, + dim=self._coil_dim, + ) + out = unpad(out.permute(0, 3, 1, 2), wpad, hpad).permute(0, 2, 3, 1) + + out = self.forward_operator(expand_operator(out, sensitivity_map, self._coil_dim), dim=self._spatial_dims) + if sampling_mask is not None: + out = masked_kspace + apply_mask(out, ~sampling_mask, return_mask=False) + if padding is not None: + out = apply_padding(out, padding) + return out + + +class ImageDomainUFormer(MRIUFormer): + """A PyTorch module that implements MRI image reconstruction using an image domain UFormer.""" + + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + patch_size: int = 256, + embedding_dim: int = 32, + encoder_depths: tuple[int, ...] = (2, 2, 2, 2), + encoder_num_heads: tuple[int, ...] = (1, 2, 4, 8), + bottleneck_depth: int = 2, + bottleneck_num_heads: int = 16, + win_size: int = 8, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.1, + patch_norm: bool = True, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.linear, + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.leff, + shift_flag: bool = True, + modulator: bool = False, + cross_modulator: bool = False, + **kwargs, + ): + """Inits :class:`ImageDomainUFormer`. + + Parameters + ---------- + forward_operator: Callable + Forward Operator. + backward_operator: Callable + Backward Operator. + patch_size : int + Size of the patch. Default: 256. + embedding_dim : int + Size of the feature embedding. Default: 32. + encoder_depths : tuple + Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2). 
+ encoder_num_heads : tuple + Number of attention heads for each layer of the encoder of the U-former, from top to bottom. + Default: (1, 2, 4, 8). + bottleneck_depth : int + Default: 16. + bottleneck_num_heads : int + Default: 2. + win_size : int + Window size for the attention mechanism. Default: 8. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float + Scale factor for the query and key projection vectors. + If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None. + drop_rate : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop_rate : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path_rate : float + Dropout rate for the stochastic depth regularization. Default: 0.1. + patch_norm : bool + Whether to use normalization for the patch embeddings. Default: True. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.linear. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.leff. + shift_flag : bool + Whether to use shift operation in the local attention mechanism. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + **kwargs: Other keyword arguments to pass to the parent constructor. + """ + super().__init__( + forward_operator=forward_operator, + backward_operator=backward_operator, + patch_size=patch_size, + in_channels=2, + out_channels=2, + embedding_dim=embedding_dim, + encoder_depths=encoder_depths, + encoder_num_heads=encoder_num_heads, + bottleneck_depth=bottleneck_depth, + bottleneck_num_heads=bottleneck_num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + patch_norm=patch_norm, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + def forward( + self, + masked_kspace: torch.Tensor, + sensitivity_map: torch.Tensor, + sampling_mask: Optional[torch.Tensor] = None, + padding: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Performs forward pass of :class:`ImageDomainUFormer`. + + Parameters + ---------- + masked_kspace: torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sensitivity_map: torch.Tensor + Coil sensitivities of shape (N, coil, height, width, complex=2). + sampling_mask : torch.Tensor, optional + Sampling mask of shape (N, 1, height, width, 1). If not None, it will use + :math:`y_{inp} + (1-M) * f_\theta(y_{inp})` as output, else :math:`f_\theta(y_{inp})`. + padding : torch.Tensor, optional + Zero-padding (i.e. not sampled locations) that may be present in original k-space + of shape (N, 1, height, width, 1). If not None, padding will be applied to output. + + Returns + ------- + out : torch.Tensor + Prediction of output image of shape (N, height, width, complex=2). 
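The forward pass that follows combines the coil images with reduce_operator and re-expands the network output with expand_operator; per the SENSE comment elsewhere in this file, these are the conjugate-weighted coil sum and its adjoint. A standalone sketch of that pair, assuming only PyTorch (the real helpers live in direct.data.transforms and may differ in detail):

    import torch

    def sense_reduce(coil_images: torch.Tensor, smaps: torch.Tensor, coil_dim: int = 1) -> torch.Tensor:
        # sum_k conj(S_k) * x_k over the coil dimension; last axis holds (real, imag).
        x = torch.view_as_complex(coil_images.contiguous())
        s = torch.view_as_complex(smaps.contiguous())
        return torch.view_as_real((s.conj() * x).sum(coil_dim))

    def sense_expand(image: torch.Tensor, smaps: torch.Tensor, coil_dim: int = 1) -> torch.Tensor:
        # S_k * x for every coil k.
        x = torch.view_as_complex(image.contiguous())
        s = torch.view_as_complex(smaps.contiguous())
        return torch.view_as_real(s * x.unsqueeze(coil_dim))

    N, coil, H, W = 1, 4, 32, 32
    coil_images = torch.randn(N, coil, H, W, 2)
    smaps = torch.randn(N, coil, H, W, 2)

    combined = sense_reduce(coil_images, smaps)  # (N, H, W, 2)
    per_coil = sense_expand(combined, smaps)     # (N, coil, H, W, 2)
    print(combined.shape, per_coil.shape)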
+ """ + inp = reduce_operator( + coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), + sensitivity_map=sensitivity_map, + dim=self._coil_dim, + ) + inp = inp.permute(0, 3, 1, 2) + + inp, padding_mask, wpad, hpad = pad_to_square(inp, factor=self.padding_factor) + inp, mean, std = norm(inp) + + out = self.uformer(inp, padding_mask) + + out = unnorm(out, mean, std) + out = unpad(out, wpad, hpad).permute(0, 2, 3, 1) + + out = self.forward_operator( + expand_operator(out, sensitivity_map, dim=self._coil_dim), + dim=self._spatial_dims, + ) + + if sampling_mask is not None: + out = masked_kspace + apply_mask(out, ~sampling_mask, return_mask=False) + if padding is not None: + out = apply_padding(out, padding) + + return out + + +class MRITransformer(nn.Module): + """A PyTorch module that implements MRI image reconstruction using VisionTransformer.""" + + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + num_gradient_descent_steps: int, + average_img_size: int | tuple[int, int] = 320, + patch_size: int | tuple[int, int] = 10, + embedding_dim: int = 64, + depth: int = 8, + num_heads: int = 9, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: float = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + gpsa_interval: tuple[int, int] = (-1, -1), + locality_strength: float = 1.0, + use_pos_embedding: bool = True, + **kwargs, + ): + """Inits :class:`MRITransformer`. + + Parameters + ---------- + forward_operator : Callable + Forward operator function. + backward_operator : Callable + Backward operator function. + num_gradient_descent_steps : int + Number of gradient descent steps to perform. + average_img_size : int or tuple[int, int], optional + Size to which the input image is rescaled before processing. + patch_size : int or tuple[int, int], optional + Patch size used in VisionTransformer. + embedding_dim : int, optional + The number of embedding dimensions in the VisionTransformer. + depth : int, optional + The number of layers in the VisionTransformer. + num_heads : int, optional + The number of attention heads in the VisionTransformer. + mlp_ratio : float, optional + The ratio of MLP hidden size to embedding size in the VisionTransformer. + qkv_bias : bool, optional + Whether to include bias terms in the projection matrices in the VisionTransformer. + qk_scale : float, optional + Scale factor for query and key in the attention calculation in the VisionTransformer. + drop_rate : float, optional + Dropout probability for the VisionTransformer. + attn_drop_rate : float, optional + Dropout probability for the attention layer in the VisionTransformer. + dropout_path_rate : float, optional + Dropout probability for the intermediate skip connections in the VisionTransformer. + norm_layer : nn.Module, optional + Normalization layer used in the VisionTransformer. + gpsa_interval : tuple[int, int], optional + Interval for performing Generalized Positional Self-Attention (GPSA) in the VisionTransformer. + locality_strength : float, optional + The strength of locality in the GPSA in the VisionTransformer. + use_pos_embedding : bool, optional + Whether to use positional embedding in the VisionTransformer. 
+ """ + super().__init__() + self.transformers = nn.ModuleList( + [ + VisionTransformer( + average_img_size=average_img_size, + patch_size=patch_size, + in_channels=2, + out_channels=2, + embedding_dim=embedding_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + dropout_path_rate=dropout_path_rate, + norm_layer=norm_layer, + gpsa_interval=gpsa_interval, + locality_strength=locality_strength, + use_pos_embedding=use_pos_embedding, + ) + for _ in range(num_gradient_descent_steps) + ] + ) + self.learning_rate = nn.Parameter(torch.ones(num_gradient_descent_steps)) + self.forward_operator = forward_operator + self.backward_operator = backward_operator + self.num_gradient_descent_steps = num_gradient_descent_steps + + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + def _forward_operator( + self, image: torch.Tensor, sampling_mask: torch.Tensor, sensitivity_map: torch.Tensor + ) -> torch.Tensor: + forward = apply_mask( + self.forward_operator(expand_operator(image, sensitivity_map, self._coil_dim), dim=self._spatial_dims), + sampling_mask, + return_mask=False, + ) + return forward + + def _backward_operator( + self, kspace: torch.Tensor, sampling_mask: torch.Tensor, sensitivity_map: torch.Tensor + ) -> torch.Tensor: + backward = reduce_operator( + self.backward_operator(apply_mask(kspace, sampling_mask, return_mask=False), self._spatial_dims), + sensitivity_map, + self._coil_dim, + ) + return backward + + def forward( + self, masked_kspace: torch.Tensor, sensitivity_map: torch.Tensor, sampling_mask: torch.Tensor = None + ) -> torch.Tensor: + """Performs forward pass of :class:`ImageDomainVisionTransformer`. + + Parameters + ---------- + masked_kspace: torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sensitivity_map: torch.Tensor + Coil sensitivities of shape (N, coil, height, width, complex=2). + sampling_mask: torch.Tensor + Sampling mask of shape (N, 1, height, width, 1). + + Returns + ------- + out : torch.Tensor + Prediction of output image of shape (N, height, width, complex=2). 
+ """ + x = reduce_operator( + coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), + sensitivity_map=sensitivity_map, + dim=self._coil_dim, + ) + for _ in range(self.num_gradient_descent_steps): + x_trans, wpad, hpad = pad(x.permute(0, 3, 1, 2), self.transformers[0].patch_size) + x_trans, mean, std = norm(x_trans) + + x_trans = x_trans + self.transformers[_](x_trans) + + x_trans = unnorm(x_trans, mean, std) + x_trans = unpad(x_trans, wpad, hpad).permute(0, 2, 3, 1) + + x = x - self.learning_rate[_] * ( + self._backward_operator( + self._forward_operator(x, sampling_mask, sensitivity_map) - masked_kspace, + sampling_mask, + sensitivity_map, + ) + + x_trans + ) + + return x + + +class ImageDomainVisionTransformer(nn.Module): + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + use_mask: bool = True, + average_img_size: int | tuple[int, int] = 320, + patch_size: int | tuple[int, int] = 10, + embedding_dim: int = 64, + depth: int = 8, + num_heads: int = 9, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: float = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + gpsa_interval: tuple[int, int] = (-1, -1), + locality_strength: float = 1.0, + use_pos_embedding: bool = True, + **kwargs, + ): + super().__init__() + self.tranformer = VisionTransformer( + average_img_size=average_img_size, + patch_size=patch_size, + in_channels=4 if use_mask else 2, + out_channels=2, + embedding_dim=embedding_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + dropout_path_rate=dropout_path_rate, + norm_layer=norm_layer, + gpsa_interval=gpsa_interval, + locality_strength=locality_strength, + use_pos_embedding=use_pos_embedding, + ) + self.forward_operator = forward_operator + self.backward_operator = backward_operator + self.use_mask = use_mask + + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + def forward( + self, masked_kspace: torch.Tensor, sensitivity_map: torch.Tensor, sampling_mask: torch.Tensor = None + ) -> torch.Tensor: + """Performs forward pass of :class:`ImageDomainVisionTransformer`. + + Parameters + ---------- + masked_kspace: torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sensitivity_map: torch.Tensor + Coil sensitivities of shape (N, coil, height, width, complex=2). + sampling_mask: torch.Tensor + Sampling mask of shape (N, 1, height, width, 1). + + Returns + ------- + out : torch.Tensor + Prediction of output image of shape (N, height, width, complex=2). 
+ """ + inp = reduce_operator( + coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), + sensitivity_map=sensitivity_map, + dim=self._coil_dim, + ) + + if self.use_mask and sampling_mask is not None: + sampling_mask_inp = torch.cat( + [ + sampling_mask, + torch.zeros(*sampling_mask.shape, device=sampling_mask.device), + ], + dim=self._complex_dim, + ).to(inp.dtype) + # project it in image domain + sampling_mask_inp = self.backward_operator(sampling_mask_inp, dim=self._spatial_dims).squeeze( + self._coil_dim + ) + inp = torch.cat([inp, sampling_mask_inp], dim=self._complex_dim) + + inp = inp.permute(0, 3, 1, 2) + + inp, wpad, hpad = pad(inp, self.transformer.patch_size) + inp, mean, std = norm(inp) + + out = self.transformer(inp) + + out = unnorm(out, mean, std) + out = unpad(out, wpad, hpad) + + return out.permute(0, 2, 3, 1) diff --git a/direct/nn/transformers/transformers_engine.py b/direct/nn/transformers/transformers_engine.py new file mode 100644 index 00000000..fd9135c5 --- /dev/null +++ b/direct/nn/transformers/transformers_engine.py @@ -0,0 +1,221 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from typing import Any, Callable, Dict, Optional, Tuple + +import torch +from torch import nn + +import direct.data.transforms as T +from direct.config import BaseConfig +from direct.nn.mri_models import MRIModelEngine + + +class VariationalUFormerEngine(MRIModelEngine): + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + """Inits :class:`VariationalUFormerEngine.""" + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]: + output_kspace = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], + sensitivity_map=data["sensitivity_map"], + padding=data.get("padding", None), + ) + output_image = T.root_sum_of_squares( + self.backward_operator(output_kspace, dim=self._spatial_dims), dim=self._coil_dim + ) + return output_image, output_kspace + + +class MRIUFormerEngine(MRIModelEngine): + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + """Inits :class:`MRIUFormerEngine.""" + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]: + output_kspace = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], + sensitivity_map=data["sensitivity_map"], + padding=data.get("padding", None), + ) + output_image = T.root_sum_of_squares( + self.backward_operator(output_kspace, dim=self._spatial_dims), dim=self._coil_dim + ) + return output_image, output_kspace + + +class ImageDomainUFormerEngine(MRIUFormerEngine): + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): 
+ """Inits :class:`ImageDomainUFormerEngine.""" + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + +class KSpaceDomainUFormerEngine(MRIUFormerEngine): + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + """Inits :class:`KSpaceDomainUFormerEngine.""" + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + +class MRITransformerEngine(MRIModelEngine): + """MRI Transformer Engine.""" + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + """Inits :class:`ImageDomainTransformerEngine.""" + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]: + output_image = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], + sensitivity_map=data["sensitivity_map"], + ) + output_kspace = T.apply_padding( + self.backward_operator( + T.expand_operator(data=output_image, sensitivity_map=data["sensitivity_map"], dim=self._coil_dim), + dim=self._spatial_dims, + ), + data.get("padding", None), + ) + output_kspace = torch.where(data["sampling_mask"] == 0, output_kspace, data["masked_kspace"]) + + output_image = T.root_sum_of_squares( + self.backward_operator(output_kspace, dim=self._spatial_dims), + dim=self._coil_dim, + ) + + return output_image, output_kspace + + +class ImageDomainVisionTransformerEngine(MRIModelEngine): + """Image Domain Transformer Engine.""" + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + """Inits :class:`ImageDomainTransformerEngine.""" + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]: + output_image = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], + sensitivity_map=data["sensitivity_map"], + ) + output_kspace = T.apply_padding( + self.backward_operator( + T.expand_operator(data=output_image, sensitivity_map=data["sensitivity_map"], dim=self._coil_dim), + dim=self._spatial_dims, + ), + data.get("padding", None), + ) + output_kspace = torch.where(data["sampling_mask"] == 0, output_kspace, data["masked_kspace"]) + + output_image = T.root_sum_of_squares( + self.backward_operator(output_kspace, dim=self._spatial_dims), + dim=self._coil_dim, + ) + + return output_image, output_kspace diff --git a/direct/nn/transformers/uformer.py b/direct/nn/transformers/uformer.py new file mode 100644 index 00000000..72ce22a1 --- /dev/null +++ b/direct/nn/transformers/uformer.py @@ -0,0 +1,1961 @@ +# coding=utf-8 +# Copyright (c) 
DIRECT Contributors + +from __future__ import annotations + +import math +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from torch.nn.init import trunc_normal_ + +from direct.nn.transformers.utils import DropoutPath, init_weights, norm, pad_to_square, unnorm, unpad +from direct.types import DirectEnum + +__all__ = ["AttentionTokenProjectionType", "LeWinTransformerMLPTokenType", "UFormer", "UFormerModel"] + + +class ECALayer1d(nn.Module): + """Efficient Channel Attention (ECA) module for 1D data.""" + + def __init__(self, channel: int, k_size: int = 3): + """Inits :class:`ECALayer1d`. + + Parameters + ---------- + channel : int + Number of channels of the input feature map. + k_size : int + Adaptive selection of kernel size. Default: 3. + """ + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool1d(1) + self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) + self.sigmoid = nn.Sigmoid() + self.channel = channel + self.k_size = k_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Computes the output of the ECA layer. + + Parameters + ---------- + x : torch.Tensor + Input feature map. + + Returns + ------- + y : torch.Tensor + Output of the ECA layer. + """ + # feature descriptor on the global spatial information + y = self.avg_pool(x.transpose(-1, -2)) + + # Two different branches of ECA module + y = self.conv(y.transpose(-1, -2)) + + # Multi-scale information fusion + y = self.sigmoid(y) + + return x * y.expand_as(x) + + def flops(self) -> int: + """Computes the number of floating point operations in :class:`ECA`. + + Returns + ------- + flops : int + Number of floating point operations. + """ + flops = 0 + flops += self.channel * self.channel * self.k_size + + return flops + + +class SepConv2d(torch.nn.Module): + """A 2D Separable Convolutional layer. + + Applies a depthwise convolution followed by a pointwise convolution. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + stride: int | tuple[int, int] = 1, + padding: int | tuple[int, int] = 0, + dilation: int | tuple[int, int] = 1, + act_layer: nn.Module = nn.ReLU, + ): + """Inits :class:`SepConv2d`. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : int or tuple of ints + Size of the convolution kernel. + stride : int or tuple of ints + Stride of the convolution. Default: 1. + padding : int or tuple of ints + Padding added to all four sides of the input. Default: 0. + dilation : int or tuple of ints + Spacing between kernel elements. Default: 1. + act_layer : torch.nn.Module + Activation layer applied after depthwise convolution. Default: nn.ReLU. + """ + super().__init__() + self.depthwise = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + ) + self.pointwise = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1) + self.act_layer = act_layer() if act_layer is not None else nn.Identity() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`SepConv2d`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. 
+ + Returns + ------- + torch.Tensor + Output tensor after applying depthwise and pointwise convolutions with activation. + """ + x = self.depthwise(x) + x = self.act_layer(x) + x = self.pointwise(x) + return x + + def flops(self, HW: int) -> int: + """Calculate the number of floating point operations in :class:`SepConv2d`. + + Parameters + ---------- + HW : int + Size of the spatial dimension of the input tensor. + + Returns + ------- + int : Number of floating point operations. + """ + flops = 0 + flops += HW * self.in_channels * self.kernel_size**2 / self.stride**2 + flops += HW * self.in_channels * self.out_channels + return int(flops) + + +######## Embedding for q,k,v ######## +class ConvProjectionModule(nn.Module): + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + kernel_size: int = 3, + q_stride: int = 1, + k_stride: int = 1, + v_stride: int = 1, + bias: bool = True, + ): + """Inits :class:`ConvProjectionModule`. + + Parameters + ---------- + dim : int + Number of channels in the input tensor. + heads : int + Number of heads in multi-head attention. Default: 8. + dim_head : int + Dimension of each head. Default: 64. + kernel_size : int + Size of convolutional kernel. Default: 3. + q_stride : int + Stride of the convolutional kernel for queries. Default: 1. + k_stride : int + Stride of the convolutional kernel for keys. Default: 1. + v_stride : int + Stride of the convolutional kernel for values. Default: 1. + bias : bool + Whether to include a bias term. Default: True. + """ + super().__init__() + + inner_dim = dim_head * heads + self.heads = heads + pad = (kernel_size - q_stride) // 2 + self.to_q = SepConv2d(dim, inner_dim, kernel_size, q_stride, pad, bias) + self.to_k = SepConv2d(dim, inner_dim, kernel_size, k_stride, pad, bias) + self.to_v = SepConv2d(dim, inner_dim, kernel_size, v_stride, pad, bias) + + def forward( + self, x: torch.Tensor, attn_kv: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass of :class:`ConvProjectionModule`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + attn_kv : torch.Tensor, optional + Attention key/value tensor. Default None. + + Returns + ------- + q : torch.Tensor + Query tensor. + k : torch.Tensor + Key tensor. + v : torch.Tensor + Value tensor. + """ + b, n, c, h = *x.shape, self.heads + l = int(math.sqrt(n)) + w = int(math.sqrt(n)) + + attn_kv = x if attn_kv is None else attn_kv + x = rearrange(x, "b (l w) c -> b c l w", l=l, w=w) + attn_kv = rearrange(attn_kv, "b (l w) c -> b c l w", l=l, w=w) + q = self.to_q(x) + q = rearrange(q, "b (h d) l w -> b h (l w) d", h=h) + + k = self.to_k(attn_kv) + v = self.to_v(attn_kv) + k = rearrange(k, "b (h d) l w -> b h (l w) d", h=h) + v = rearrange(v, "b (h d) l w -> b h (l w) d", h=h) + return q, k, v + + def flops(self, q_L: int, kv_L: Optional[int] = None) -> int: + """Calculate the number of floating point operations in :class:`ConvProjectionModule`. + + Parameters + ---------- + q_L : int + Size of input patches. + kv_L : int, optional + Size of key/value patches. Default None. + + Returns + ------- + flops : int + Number of floating point operations. 
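SepConv2d above is a standard depthwise-separable convolution: a per-channel (grouped) k x k convolution followed by a 1 x 1 pointwise convolution, which is exactly what its flops estimate counts. A minimal standalone equivalent and a rough parameter comparison with a dense convolution (names here are illustrative, not part of the patch):

    import torch
    import torch.nn as nn

    class SeparableConv2d(nn.Module):
        # Same structure as SepConv2d in this file, minus the activation option.
        def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3, padding: int = 1):
            super().__init__()
            self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size, padding=padding, groups=in_ch)
            self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.pointwise(self.depthwise(x))

    sep, dense = SeparableConv2d(32, 64), nn.Conv2d(32, 64, kernel_size=3, padding=1)
    n_params = lambda m: sum(p.numel() for p in m.parameters())
    print(n_params(sep), n_params(dense))         # 2432 vs 18496: far fewer parameters
    print(sep(torch.randn(1, 32, 16, 16)).shape)  # torch.Size([1, 64, 16, 16])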
+ """ + kv_L = kv_L or q_L + flops = 0 + flops += self.to_q.flops(q_L) + flops += self.to_k.flops(kv_L) + flops += self.to_v.flops(kv_L) + return flops + + +class LinearProjectionModule(nn.Module): + """Linear projection layer used in the window attention mechanism of the Transformer model.""" + + def __init__(self, dim: int, heads: int = 8, dim_head: int = 64, bias: bool = True): + """Inits :class:LinearProjectionModule`. + + Parameters + ---------- + dim : int + The input feature dimension. + heads : int + The number of heads in the multi-head attention mechanism. Default: 8. + dim_head : int, optional + The feature dimension of each head. Default: 64. + bias : bool, optional + Whether to use bias in the linear projections. Default: True. + """ + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.to_q = nn.Linear(dim, inner_dim, bias=bias) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=bias) + self.dim = dim + self.inner_dim = inner_dim + + def forward( + self, x: torch.Tensor, attn_kv: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Performs forward pass of :class:`LinearProjectionModule`. + + Parameters + ---------- + x : torch.Tensor of shape (batch_size, seq_length, dim) + The input tensor. + attn_kv : torch.Tensor of shape (batch_size, seq_length, dim), optional + The tensor to be used for computing the attention scores. If None, the input tensor is used. Default: None. + + Returns + ------- + q : torch.Tensor of shape (batch_size, seq_length, heads, dim_head) + The tensor resulting from the linear projection of x used for computing the queries. + k : torch.Tensor of shape (batch_size, seq_length, heads, dim_head) + The tensor resulting from the linear projection of attn_kv used for computing the keys. + v : torch.Tensor of shape (batch_size, seq_length, heads, dim_head) + The tensor resulting from the linear projection of attn_kv used for computing the values. + + """ + B_, N, C = x.shape + if attn_kv is not None: + attn_kv = attn_kv.unsqueeze(0).repeat(B_, 1, 1) + else: + attn_kv = x + N_kv = attn_kv.size(1) + q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) + kv = self.to_kv(attn_kv).reshape(B_, N_kv, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) + q = q[0] + k, v = kv[0], kv[1] + return q, k, v + + def flops(self, q_L: int, kv_L: Optional[int] = None) -> int: + """Calculate the number of floating point operations in :class:`LinearProjectionModule`. + + Parameters + ---------- + q_L : int + Size of input patches. + kv_L : int, optional + Size of key/value patches. Default None. + + Returns + ------- + flops : int + Number of floating point operations. + """ + kv_L = kv_L or q_L + flops = q_L * self.dim * self.inner_dim + kv_L * self.dim * self.inner_dim * 2 + return flops + + +########### window-based self-attention ############# +class AttentionTokenProjectionType(DirectEnum): + conv = "conv" + linear = "linear" + + +class WindowAttentionModule(nn.Module): + """A window-based multi-head self-attention module.""" + + def __init__( + self, + dim: int, + win_size: tuple[int, int], + num_heads: int, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.linear, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + """Inits :class:`WindowAttentionModule`. + + Parameters + ---------- + dim : int + Input feature dimension. + win_size : tuple[int, int] + The window size (height and width). 
+ num_heads : int + Number of heads for multi-head self-attention. + token_projection : AttentionTokenProjectionType + Type of projection for token-level queries, keys, and values. Either "conv" or "linear". + qkv_bias : bool + Whether to use bias in the linear projection layer for queries, keys, and values. + qk_scale : float + Scale factor for query and key. + attn_drop : float + Dropout rate for attention weights. + proj_drop : float + Dropout rate for the output of the last linear projection layer. + + """ + super().__init__() + self.dim = dim + self.win_size = win_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1] + coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1] + coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.win_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + trunc_normal_(self.relative_position_bias_table, std=0.02) + + if token_projection == "conv": + self.qkv = ConvProjectionModule(dim, num_heads, dim // num_heads, bias=qkv_bias) + elif token_projection == "linear": + self.qkv = LinearProjectionModule(dim, num_heads, dim // num_heads, bias=qkv_bias) + else: + raise Exception("Projection error!") + + self.token_projection = token_projection + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward( + self, x: torch.Tensor, attn_kv: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Performs forward pass of :class:`WindowAttentionModule`. + + Parameters + ---------- + x : torch.Tensor + A tensor of shape `(B, N, C)` representing the input features, where `B` is the batch size, `N` is the + sequence length, and `C` is the input feature dimension. + attn_kv : torch.Tensor, optional + An optional tensor of shape `(B, N, C)` representing the key-value pairs used for attention computation. + If `None`, the key-value pairs are computed from `x` itself. Default: None. + mask : torch.Tensor, optional + An optional tensor of shape representing the binary mask for the input sequence. + If `None`, no masking is applied. Default: None. + + Returns + ------- + torch.Tensor + A tensor of shape `(B, N, C)` representing the output features after attention computation. 
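The relative_position_index buffer built in __init__ above maps every pair of positions inside a window to a row of the learned bias table, so the bias added to the attention logits depends only on the relative offset of the two tokens. A standalone sketch of that construction for a small 3 x 3 window, following the same steps:

    import torch

    win_size = (3, 3)
    coords = torch.stack(
        torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])], indexing="ij")
    )                                                  # (2, Wh, Ww)
    coords_flatten = torch.flatten(coords, 1)          # (2, Wh*Ww)
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
    relative_coords[:, :, 0] += win_size[0] - 1        # shift offsets to start from 0
    relative_coords[:, :, 1] += win_size[1] - 1
    relative_coords[:, :, 0] *= 2 * win_size[1] - 1
    relative_position_index = relative_coords.sum(-1)  # (Wh*Ww, Wh*Ww)

    table_rows = (2 * win_size[0] - 1) * (2 * win_size[1] - 1)
    print(relative_position_index.shape)                        # torch.Size([9, 9])
    print(int(relative_position_index.max()), table_rows - 1)   # 24 24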
+ """ + B_, N, C = x.shape + q, k, v = self.qkv(x, attn_kv) + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1 + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + ratio = attn.size(-1) // relative_position_bias.size(-1) + relative_position_bias = repeat(relative_position_bias, "nH l c -> nH l (c d)", d=ratio) + + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + mask = repeat(mask, "nW m n -> nW m (n d)", d=ratio) + attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N * ratio) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}" + + def flops(self, H: int, W: int) -> int: + """Calculate the number of floating point operations in :class:`LinearProjectionModule` for 1 window + with token length of N. + + Parameters + ---------- + H : int + Height. + W : int + Width. + + Returns + ------- + flops : int + Number of floating point operations. + """ + flops = 0 + N = self.win_size[0] * self.win_size[1] + nW = H * W / N + + flops += self.qkv.flops(H * W, H * W) + + flops += nW * self.num_heads * N * (self.dim // self.num_heads) * N + flops += nW * self.num_heads * N * N * (self.dim // self.num_heads) + + flops += nW * N * self.dim * self.dim + return int(flops) + + +########### self-attention ############# +class AttentionModule(nn.Module): + """Self-attention module.""" + + def __init__( + self, + dim: int, + num_heads: int, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ): + """Inits :class:`AttentionModule`. + + Parameters + ---------- + dim : int + The input feature dimension. + num_heads : int + The number of attention heads. + qkv_bias : bool + Whether to include biases in the query, key, and value projections. Default: True. + qk_scale : float, optional + Scaling factor for the query and key projections. Default: None. + attn_drop : float + Dropout probability for the attention weights. Default: 0.0. + proj_drop : float + Dropout probability for the output of the attention module. Default: 0.0. + """ + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = LinearProjectionModule(dim, num_heads, dim // num_heads, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward( + self, x: torch.Tensor, attn_kv: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Performs the forward pass of :class:`AttentionModule`. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + attn_kv : torch.Tensor, optional + The attention key/value tensor. 
+ mask : torch.Tensor, optional + + Returns + ------- + torch.Tensor + """ + B_, N, C = x.shape + q, k, v = self.qkv(x, attn_kv) + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}" + + def flops(self, q_num: int, kv_num: int) -> int: + """Calculate the number of floating point operations in :class:`LinearProjectionModule`. + + Parameters + ---------- + q_num : int + Size of input patches. + kv_num : int, optional + Size of key/value patches. Default None. + + Returns + ------- + flops : int + Number of floating point operations. + """ + flops = 0 + + flops += self.qkv.flops(q_num, kv_num) + + flops += self.num_heads * q_num * (self.dim // self.num_heads) * kv_num + flops += self.num_heads * q_num * (self.dim // self.num_heads) * kv_num + + flops += q_num * self.dim * self.dim + return flops + + +######################################### +########### feed-forward network ############# +class MLP(nn.Module): + """Multi-layer perceptron with optional dropout regularization.""" + + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: nn.Module = nn.GELU, + drop: float = 0.0, + ): + """Inits :class:`MLP`. + + Parameters: + ----------- + in_features : int + Number of input features. + hidden_features : int, optional + Number of output features in the hidden layer. If not specified, `in_features` is used. + out_features : int, optional + Number of output features. If not specified, `in_features` is used. + act_layer : nn.Module + Activation layer. Default: GeLU. + drop : float + Dropout probability. Default: 0.0. + """ + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.in_features = in_features + self.hidden_features = hidden_features + self.out_features = out_features + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of the :class:`MLP`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + output : torch.Tensor + """ + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + def flops(self, H: int, W: int) -> int: + """Calculate the number of floating point operations in :class:`MLP`. + + Parameters + ---------- + H : int + Height. + W : int + Width. + + Returns + ------- + flops : int + Number of floating point operations. + """ + flops = 0 + # fc1 + flops += H * W * self.in_features * self.hidden_features + # fc2 + flops += H * W * self.hidden_features * self.out_features + return flops + + +class LeFF(nn.Module): + """Locally-enhanced Feed-Forward Network module.""" + + def __init__(self, dim: int = 32, hidden_dim: int = 128, act_layer: nn.Module = nn.GELU, use_eca: bool = False): + """Inits :class:`LeFF`. + + Parameters + ---------- + dim : int + Dimension of the input and output features. Default: 32. 
+ hidden_dim : int + Dimension of the hidden features. Default: 128. + act_layer : nn.Module + Activation layer to apply after the first linear layer and the depthwise convolution. Default: GELU. + use_eca : bool + If True, adds a 1D ECA layer after the second linear layer. Default: False. + """ + super().__init__() + self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim), act_layer()) + self.dwconv = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=3, stride=1, padding=1), act_layer() + ) + self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim)) + self.dim = dim + self.hidden_dim = hidden_dim + self.eca = ECALayer1d(dim) if use_eca else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`LeFF`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + """ + # bs x hw x c + bs, hw, c = x.size() + hh = int(math.sqrt(hw)) + + x = self.linear1(x) + + # spatial restore + x = rearrange(x, " b (h w) (c) -> b c h w ", h=hh, w=hh) + # bs,hidden_dim,32x32 + + x = self.dwconv(x) + + # flatten + x = rearrange(x, " b c h w -> b (h w) c", h=hh, w=hh) + + x = self.linear2(x) + x = self.eca(x) + + return x + + def flops(self, H: int, W: int) -> int: + """Calculate the number of floating point operations in :class:`LeFF`. + + Parameters + ---------- + H : int + Height. + W : int + Width. + + Returns + ------- + flops : int + Number of floating point operations. + """ + flops = 0 + # fc1 + flops += H * W * self.dim * self.hidden_dim + # dwconv + flops += H * W * self.hidden_dim * 3 * 3 + # fc2 + flops += H * W * self.hidden_dim * self.dim + # eca + if hasattr(self.eca, "flops"): + flops += self.eca.flops() + return flops + + +######################################### +########### window operation############# +def window_partition(x: torch.Tensor, win_size: int, dilation_rate: int = 1) -> torch.Tensor: + """Partition the input tensor into windows of specified size. + + Parameters + ---------- + x : torch.Tensor + The input tensor to be partitioned into windows. + win_size : int + The size of the square windows to partition the tensor into. + dilation_rate : int + The dilation rate for convolution. Default: 1. + + Returns + ------- + windows : torch.Tensor + The tensor representing windows partitioned from input tensor. + """ + B, H, W, C = x.shape + if dilation_rate != 1: + x = x.permute(0, 3, 1, 2) # B, C, H, W + assert type(dilation_rate) is int, "dilation_rate should be a int" + x = F.unfold( + x, kernel_size=win_size, dilation=dilation_rate, padding=4 * (dilation_rate - 1), stride=win_size + ) # B, C*Wh*Ww, H/Wh*W/Ww + windows = x.permute(0, 2, 1).contiguous().view(-1, C, win_size, win_size) # B' ,C ,Wh ,Ww + windows = windows.permute(0, 2, 3, 1).contiguous() # B' ,Wh ,Ww ,C + else: + x = x.view(B, H // win_size, win_size, W // win_size, win_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) # B' ,Wh ,Ww ,C + return windows + + +def window_reverse(windows: torch.Tensor, win_size: int, H: int, W: int, dilation_rate: int = 1) -> torch.Tensor: + """Rearrange the partitioned tensor back to the original tensor. + + Parameters + ---------- + windows : torch.Tensor + The tensor representing windows partitioned from input tensor. + win_size : int + The size of the square windows used to partition the tensor. + H : int + The height of the original tensor before partitioning. 
+ W : int + The width of the original tensor before partitioning. + dilation_rate : int + The dilation rate for convolution. Default 1. + + Returns + ------- + x: torch.Tensor + The original tensor rearranged from the partitioned tensor. + + """ + # B' ,Wh ,Ww ,C + B = int(windows.shape[0] / (H * W / win_size / win_size)) + x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1) + if dilation_rate != 1: + x = windows.permute(0, 5, 3, 4, 1, 2).contiguous() # B, C*Wh*Ww, H/Wh*W/Ww + x = F.fold( + x, (H, W), kernel_size=win_size, dilation=dilation_rate, padding=4 * (dilation_rate - 1), stride=win_size + ) + else: + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class DownSampleBlock(nn.Module): + """Convolution based downsample block.""" + + def __init__(self, in_channels: int, out_channels: int): + """Inits :class:`DownSampleBlock`. + + Parameters + ---------- + in_channels : int + Number of channels in the input tensor. + out_channels : int + Number of channels produced by the convolution. + """ + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1), + ) + self.in_channels = in_channels + self.out_channels = out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`DownSampleBlock`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Downsampled output. + """ + B, L, C = x.shape + H = int(math.sqrt(L)) + W = int(math.sqrt(L)) + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + out = self.conv(x).flatten(2).transpose(1, 2).contiguous() # B H*W C + return out + + def flops(self, H: int, W: int) -> int: + """Calculate the number of floating point operations in :class:`DownSampleBlock`. + + Parameters + ---------- + H : int + Height. + W : int + Width. + + Returns + ------- + flops : int + Number of floating point operations. + """ + flops = 0 + # conv + flops += H / 2 * W / 2 * self.in_channels * self.out_channels * 4 * 4 + return int(flops) + + +class UpSampleBlock(nn.Module): + """Convolution based upsample block.""" + + def __init__(self, in_channels: int, out_channels: int): + """Inits :class:`UpSampleBlock`. + + Parameters + ---------- + in_channels : int + Number of channels in the input tensor. + out_channels : int + Number of channels produced by the convolution. + """ + super().__init__() + self.deconv = nn.Sequential( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2), + ) + self.in_channels = in_channels + self.out_channels = out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`UpSampleBlock`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Upsampled output. + """ + B, L, C = x.shape + H = int(math.sqrt(L)) + W = int(math.sqrt(L)) + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + out = self.deconv(x).flatten(2).transpose(1, 2).contiguous() # B H*W C + return out + + def flops(self, H: int, W: int) -> int: + """Calculate the number of floating point operations in :class:`UpSampleBlock`. + + Parameters + ---------- + H : int + Height. + W : int + Width. + + Returns + ------- + flops : int + Number of floating point operations. 
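For dilation_rate=1, window_partition and window_reverse above are exact inverses: the non-overlapping windows are just a reshape of the feature map and can be folded back without loss. A quick round-trip sketch, assuming the module path direct/nn/transformers/uformer.py introduced by this patch is importable:

    import torch
    from direct.nn.transformers.uformer import window_partition, window_reverse

    B, H, W, C, win_size = 2, 32, 32, 16, 8
    x = torch.randn(B, H, W, C)

    windows = window_partition(x, win_size)           # (B * H/win * W/win, win, win, C)
    print(windows.shape)                              # torch.Size([32, 8, 8, 16])

    x_back = window_reverse(windows, win_size, H, W)  # back to (B, H, W, C)
    assert torch.equal(x_back, x)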
+ """ + flops = 0 + # conv + flops += H * 2 * W * 2 * self.in_channels * self.out_channels * 2 * 2 + return flops + + +class InputProjection(nn.Module): + """Input convolutional projection used in the U-Former model.""" + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 64, + kernel_size: int | tuple[int, int] = 3, + stride: int | tuple[int, int] = 1, + norm_layer: Optional[nn.Module] = None, + act_layer: nn.Module = nn.LeakyReLU, + ): + """Inits :class:`InputProjection`. + + Parameters + ---------- + in_channels : int + Number of input channels. Default: 3. + out_channels : int + Number of output channels after the projection. Default: 64. + kernel_size : int or tuple of ints + Convolution kernel size. Default: 3. + stride : int or tuple of ints + Stride of the convolution. Default: 1. + norm_layer : nn.Module, optional + Normalization layer to apply after the projection. Default: None. + act_layer : nn.Module + Activation layer to apply after the projection. Default: nn.LeakyReLU. + """ + super().__init__() + self.proj = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=kernel_size // 2), + act_layer(inplace=True), + ) + if norm_layer is not None: + self.norm = norm_layer(out_channels) + else: + self.norm = None + self.in_channels = in_channels + self.out_channels = out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`InputProjection`. + + Parameters + ---------- + x : torch.Tensor + + Returns + ------- + torch.Tensor + """ + x = self.proj(x).flatten(2).transpose(1, 2).contiguous() # B H*W C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self, H: int, W: int) -> int: + """Calculate the number of floating point operations in :class:`InputProjection`. + + Parameters + ---------- + H : int + Height. + W : int + Width. + + Returns + ------- + flops : int + Number of floating point operations. + """ + flops = 0 + # conv + flops += H * W * self.in_channels * self.out_channels * 3 * 3 + + if self.norm is not None: + flops += H * W * self.out_channels + return flops + + +class OutputProjection(nn.Module): + """Output convolutional projection used in the U-Former model.""" + + def __init__( + self, + in_channels: int = 64, + out_channels: int = 3, + kernel_size: int | tuple[int, int] = 3, + stride: int | tuple[int, int] = 1, + norm_layer: Optional[nn.Module] = None, + act_layer: Optional[nn.Module] = None, + ): + """Inits :class:`InputProjection`. + + Parameters + ---------- + in_channels : int + Number of input channels. Default: 64. + out_channels : int + Number of output channels after the projection. Default: 3. + kernel_size : int or tuple of ints + Convolution kernel size. Default: 3. + stride : int or tuple of ints + Stride of the convolution. Default: 1. + norm_layer : nn.Module, optional + Normalization layer to apply after the projection. Default: None. + act_layer : nn.Module, optional + Activation layer to apply after the projection. Default: None. 
+ """ + super().__init__() + self.proj = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=kernel_size // 2), + ) + if act_layer is not None: + self.proj.add_module(act_layer(inplace=True)) + if norm_layer is not None: + self.norm = norm_layer(out_channels) + else: + self.norm = None + self.in_channels = in_channels + self.out_channels = out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`OutputProjection`. + + Parameters + ---------- + x : torch.Tensor + + Returns + ------- + torch.Tensor + """ + B, L, C = x.shape + H = int(math.sqrt(L)) + W = int(math.sqrt(L)) + x = x.transpose(1, 2).view(B, C, H, W) + x = self.proj(x) + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self, H: int, W: int) -> int: + """Calculate the number of floating point operations in :class:`InputProjection`. + + Parameters + ---------- + H : int + Height. + W : int + Width. + + Returns + ------- + flops : int + Number of floating point operations. + """ + flops = 0 + # conv + flops += H * W * self.in_channels * self.out_channels * 3 * 3 + + if self.norm is not None: + flops += H * W * self.out_channels + return flops + + +class LeWinTransformerMLPTokenType(DirectEnum): + mlp = "mlp" + ffn = "ffn" + leff = "leff" + + +class LeWinTransformerBlock(nn.Module): + """Applies a window-based multi-head self-attention and MLP or LeFF on the input tensor.""" + + def __init__( + self, + dim: int, + input_resolution: tuple[int, int], + num_heads: int, + win_size: int = 8, + shift_size: int = 0, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.linear, + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.leff, + modulator: bool = False, + cross_modulator: bool = False, + ): + r"""Inits :class:`LeWinTransformerBlock`. + + Parameters + ---------- + dim : int + Number of input channels. + input_resolution : tuple of ints + Input resolution. + num_heads : int + Number of attention heads. + win_size : int + Window size for the attention mechanism. Default: 8. + shift_size : int + The number of pixels to shift the window. Default: 0. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float, optional + Scale factor for the query and key projection vectors. + If set to None, will use the default value of :math`1 / \sqrt(dim)`. Default: None. + drop : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path : float + Dropout rate for the stochastic depth regularization. Default: 0.0. + act_layer : nn.Module + The activation function to use. Default: nn.GELU. + norm_layer : nn.Module + The normalization layer to use. Default: nn.LayerNorm. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.linear. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. 
Default: LeWinTransformerMLPTokenType.leff. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + """ + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.win_size = win_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.token_mlp = token_mlp + if min(self.input_resolution) <= self.win_size: + self.shift_size = 0 + self.win_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-win_size" + + if modulator: + self.modulator = nn.Embedding(win_size * win_size, dim) # modulator + else: + self.modulator = None + + if cross_modulator: + self.cross_modulator = nn.Embedding(win_size * win_size, dim) # cross_modulator + self.cross_attn = AttentionModule( + dim, + num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.norm_cross = norm_layer(dim) + else: + self.cross_modulator = None + + self.norm1 = norm_layer(dim) + self.attn = WindowAttentionModule( + dim, + win_size=(self.win_size, self.win_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + token_projection=token_projection, + ) + + self.drop_path = DropoutPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + if token_mlp in ["ffn", "mlp"]: + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + elif token_mlp == "leff": + self.mlp = LeFF(dim, mlp_hidden_dim, act_layer=act_layer) + else: + raise Exception("FFN error!") + + def with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def extra_repr(self) -> str: + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio},modulator={self.modulator}" + ) + + def forward(self, x, mask=None): + """Performs the forward pass of :class:`LeWinTransformerBlock`. + + Parameters + ---------- + x : torch.Tensor + The input tensor. 
+ mask : torch.Tensor, optional + + Returns + ------- + torch.Tensor + """ + B, L, C = x.shape + H = int(math.sqrt(L)) + W = int(math.sqrt(L)) + + ## input mask + if mask != None: + input_mask = F.interpolate(mask, size=(H, W)).permute(0, 2, 3, 1) + input_mask_windows = window_partition(input_mask, self.win_size) # nW, win_size, win_size, 1 + attn_mask = input_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size + attn_mask = attn_mask.unsqueeze(2) * attn_mask.unsqueeze(1) # nW, win_size*win_size, win_size*win_size + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + ## shift mask + if self.shift_size > 0: + # calculate attention mask for SW-MSA + shift_mask = torch.zeros((1, H, W, 1)).type_as(x) + h_slices = ( + slice(0, -self.win_size), + slice(-self.win_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.win_size), + slice(-self.win_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + shift_mask[:, h, w, :] = cnt + cnt += 1 + shift_mask_windows = window_partition(shift_mask, self.win_size) # nW, win_size, win_size, 1 + shift_mask_windows = shift_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size + shift_attn_mask = shift_mask_windows.unsqueeze(1) - shift_mask_windows.unsqueeze( + 2 + ) # nW, win_size*win_size, win_size*win_size + shift_attn_mask = shift_attn_mask.masked_fill(shift_attn_mask != 0, float(-100.0)).masked_fill( + shift_attn_mask == 0, float(0.0) + ) + attn_mask = attn_mask + shift_attn_mask if attn_mask is not None else shift_attn_mask + if self.cross_modulator is not None: + shortcut = x + x_cross = self.norm_cross(x) + x_cross = self.cross_attn(x, self.cross_modulator.weight) + x = shortcut + x_cross + shortcut = x + + x = self.norm1(x) + x = x.view(B, H, W, C) + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + # partition windows + x_windows = window_partition(shifted_x, self.win_size) # nW*B, win_size, win_size, C N*C->C + x_windows = x_windows.view(-1, self.win_size * self.win_size, C) # nW*B, win_size*win_size, C + # with_modulator + if self.modulator is not None: + wmsa_in = self.with_pos_embed(x_windows, self.modulator.weight) + else: + wmsa_in = x_windows + + # W-MSA/SW-MSA + attn_windows = self.attn(wmsa_in, mask=attn_mask) # nW*B, win_size*win_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C) + shifted_x = window_reverse(attn_windows, self.win_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + del attn_mask + return x + + def flops(self) -> int: + """Calculate the number of floating point operations in :class:`LeWinTransformerBlock`. + + Returns + ------- + flops : int + Number of floating point operations. 
+ """ + flops = 0 + H, W = self.input_resolution + + if self.cross_modulator is not None: + flops += self.dim * H * W + flops += self.cross_attn.flops(H * W, self.win_size * self.win_size) + + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + flops += self.attn.flops(H, W) + # norm2 + flops += self.dim * H * W + # mlp + flops += self.mlp.flops(H, W) + return flops + + +class BasicUFormerLayer(nn.Module): + """Basic layer of U-Former.""" + + def __init__( + self, + dim: int, + input_resolution: tuple[int, int], + depth: int, + num_heads: int, + win_size: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[bool] = None, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: List[float] | float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.linear, + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.ffn, + shift_flag: bool = True, + modulator: bool = False, + cross_modulator: bool = False, + ): + r"""Inits :class:`BasicUFormerLayer`. + + Parameters + ---------- + dim : int + Number of input channels. + input_resolution : tuple of ints + Input resolution. + num_heads : int + Number of attention heads. + win_size : int + Window size for the attention mechanism. Default: 8. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float, optional + Scale factor for the query and key projection vectors. + If set to None, will use the default value of :math`1 / \sqrt(dim)`. Default: None. + drop : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path : float + Dropout rate for the stochastic depth regularization. Default: 0.0. + norm_layer : nn.Module + The normalization layer to use. Default: nn.LayerNorm. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.linear. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.leff. + shift_flag : bool + Whether to use shift in the attention sliding windows or not. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. 
+ """ + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + + # build blocks + self.blocks = nn.ModuleList( + [ + LeWinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + win_size=win_size, + shift_size=(0 if (i % 2 == 0) else win_size // 2) if shift_flag else 0, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + token_projection=token_projection, + token_mlp=token_mlp, + modulator=modulator, + cross_modulator=cross_modulator, + ) + for i in range(depth) + ] + ) + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Performs forward pass of :class:`BasicUFormerLayer`. + + Parameters + ---------- + x : torch.Tensor + mask : torch.Tensor, optional + + Returns + ------- + torch.Tensor + """ + for blk in self.blocks: + x = blk(x, mask) + return x + + def flops(self) -> int: + """Calculate the number of floating point operations in :class:`BasicUFormerLayer`. + + Returns + ------- + flops : int + Number of floating point operations. + """ + flops = 0 + for blk in self.blocks: + flops += blk.flops() + return flops + + +class UFormer(nn.Module): + """U-Former is a transformer-based architecture that can process high-resolution images.""" + + def __init__( + self, + patch_size: int = 256, + in_channels: int = 2, + out_channels: Optional[int] = None, + embedding_dim: int = 32, + encoder_depths: tuple[int, ...] = (2, 2, 2, 2), + encoder_num_heads: tuple[int, ...] = (1, 2, 4, 8), + bottleneck_depth: int = 2, + bottleneck_num_heads: int = 16, + win_size: int = 8, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.1, + patch_norm: bool = True, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.linear, + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.leff, + shift_flag: bool = True, + modulator: bool = False, + cross_modulator: bool = False, + ): + """Inits :class:`UFormer`. + + Parameters + ---------- + patch_size : int + Size of the patch. Default: 256. + in_channels : int + Number of input channels. Default: 2. + out_channels : int, optional + Number of output channels. Default: None. + embedding_dim : int + Size of the feature embedding. Default: 32. + encoder_depths : tuple + Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2). + encoder_num_heads : tuple + Number of attention heads for each layer of the encoder of the U-former, from top to bottom. + Default: (1, 2, 4, 8). + bottleneck_depth : int + Default: 16. + bottleneck_num_heads : int + Default: 2. + win_size : int + Window size for the attention mechanism. Default: 8. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float + Scale factor for the query and key projection vectors. + If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None. 
+ drop_rate : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop_rate : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path_rate : float + Dropout rate for the stochastic depth regularization. Default: 0.1. + patch_norm : bool + Whether to use normalization for the patch embeddings. Default: True. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.linear. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.leff. + shift_flag : bool + Whether to use shift operation in the local attention mechanism. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + **kwargs: Other keyword arguments to pass to the parent constructor. + """ + super().__init__() + if len(encoder_num_heads) != len(encoder_depths): + raise ValueError( + f"The number of heads for each layer should be the same as the number of layers. " + f"Got {len(encoder_num_heads)} for {len(encoder_depths)} layers." + ) + if patch_size < (2 ** len(encoder_depths) * win_size): + raise ValueError( + f"Patch size must be greater or equal than 2 ** number of scales * window size." + f" Received: patch_size={patch_size}, number of scales=={len(encoder_depths)}," + f" and window_size={win_size}." + ) + self.num_enc_layers = len(encoder_num_heads) + self.num_dec_layers = len(encoder_num_heads) + depths = (*encoder_depths, bottleneck_depth, *encoder_depths[::-1]) + num_heads = (*encoder_num_heads, bottleneck_num_heads, bottleneck_num_heads, *encoder_num_heads[::-1][:-1]) + self.embedding_dim = embedding_dim + self.patch_norm = patch_norm + self.mlp_ratio = mlp_ratio + self.token_projection = token_projection + self.mlp = token_mlp + self.win_size = win_size + self.reso = patch_size + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[: self.num_enc_layers]))] + conv_dpr = [drop_path_rate] * depths[self.num_enc_layers + 1] + dec_dpr = enc_dpr[::-1] + + # Build layers + + # Input + self.input_proj = InputProjection( + in_channels=in_channels, out_channels=embedding_dim, kernel_size=3, stride=1, act_layer=nn.LeakyReLU + ) + out_channels = out_channels if out_channels else in_channels + # Output + self.output_proj = OutputProjection( + in_channels=2 * embedding_dim, out_channels=out_channels, kernel_size=3, stride=1 + ) + if in_channels != out_channels: + self.conv_out = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding=0) + self.in_channels = in_channels + self.out_channels = out_channels + + # Encoder + self.encoder_layers = nn.ModuleList() + self.downsamples = nn.ModuleList() + for i in range(self.num_enc_layers): + layer_name = f"encoderlayer_{i}" + layer_input_resolution = (patch_size // (2**i), patch_size // (2**i)) + layer_dim = embedding_dim * (2**i) + layer_depth = depths[i] + layer_drop_path = enc_dpr[sum(depths[:i]) : sum(depths[: i + 1])] + layer = BasicUFormerLayer( + dim=layer_dim, + input_resolution=layer_input_resolution, + depth=layer_depth, + num_heads=num_heads[i], + win_size=win_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + 
drop_path=layer_drop_path, + norm_layer=nn.LayerNorm, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + ) + self.encoder_layers.add_module(layer_name, layer) + + downsample_layer_name = f"downsample_{i}" + downsample_layer = DownSampleBlock(layer_dim, embedding_dim * (2 ** (i + 1))) + self.downsamples.add_module(downsample_layer_name, downsample_layer) + # Bottleneck + self.bottleneck = BasicUFormerLayer( + dim=embedding_dim * (2**self.num_enc_layers), + input_resolution=(patch_size // (2**self.num_enc_layers), patch_size // (2**self.num_enc_layers)), + depth=depths[self.num_enc_layers], + num_heads=num_heads[self.num_enc_layers], + win_size=win_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=conv_dpr, + norm_layer=nn.LayerNorm, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + ) + # Decoder + self.upsamples = nn.ModuleList() + self.decoder_layers = nn.ModuleList() + for i in range(self.num_dec_layers, 0, -1): + upsample_layer_name = f"upsample_{self.num_dec_layers - i}" + if i == self.num_dec_layers: + upsample_in_channels = embedding_dim * (2**i) + else: + upsample_in_channels = embedding_dim * (2 ** (i + 1)) + upsample_out_channels = embedding_dim * (2 ** (i - 1)) + upsample_layer = UpSampleBlock(upsample_in_channels, upsample_out_channels) + self.upsamples.add_module(upsample_layer_name, upsample_layer) + + layer_name = f"decoderlayer_{self.num_dec_layers - i}" + layer_input_resolution = (patch_size // (2 ** (i - 1)), patch_size // (2 ** (i - 1))) + layer_dim = embedding_dim * (2**i) + layer_num = self.num_enc_layers + self.num_dec_layers - i + 1 + layer_depth = depths[layer_num] + if i == self.num_dec_layers: + layer_drop_path = dec_dpr[: depths[layer_num]] + else: + start = self.num_enc_layers + 1 + layer_drop_path = dec_dpr[sum(depths[start:layer_num]) : sum(depths[start : layer_num + 1])] + layer = BasicUFormerLayer( + dim=layer_dim, + input_resolution=layer_input_resolution, + depth=layer_depth, + num_heads=num_heads[layer_num], + win_size=win_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=layer_drop_path, + norm_layer=nn.LayerNorm, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + self.decoder_layers.add_module(layer_name, layer) + + self.apply(init_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return {"absolute_pos_embed"} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {"relative_position_bias_table"} + + def extra_repr(self) -> str: + return f"embedding_dim={self.embedding_dim}, token_projection={self.token_projection}, token_mlp={self.mlp},win_size={self.win_size}" + + def forward(self, input: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Performs forward pass of :class:`UFormer`. 
+ + Parameters + ---------- + input : torch.Tensor + mask : torch.Tensor, optional + + Returns + ------- + torch.Tensor + """ + # Input Projection + output = self.input_proj(input) + output = self.pos_drop(output) + + # Encoder + stack = [] + for encoder_layer, downsample in zip(self.encoder_layers, self.downsamples): + output = encoder_layer(output, mask=mask) + stack.append(output) + output = downsample(output) + # Bottleneck + output = self.bottleneck(output, mask=mask) + + # Decoder + for decoder_layer, upsample in zip(self.decoder_layers, self.upsamples): + downsampled_output = stack.pop() + output = upsample(output) + + output = torch.cat([output, downsampled_output], -1) + output = decoder_layer(output, mask=mask) + + # Output Projection + output = self.output_proj(output) + if self.in_channels != self.out_channels: + input = self.conv_out(input) + return input + output + + def flops(self) -> int: + """Calculate the number of floating point operations in :class:`UFormer`. + + Returns + ------- + flops : int + Number of floating point operations. + """ + flops = 0 + # Input Projection + flops += self.input_proj.flops(self.reso, self.reso) + # Encoder + for i, (encoder_layer, downsample) in enumerate(zip(self.encoder_layers, self.downsamples)): + resolution = self.reso // (2**i) + flops += encoder_layer.flops() + downsample.flops(resolution, resolution) + + # Bottleneck + flops += self.bottleneck.flops() + + # Decoder + for i, upsample, decoder_layer in zip(range(self.num_dec_layers, 0, -1), self.upsamples, self.decoder_layers): + resolution = self.reso // (2**i) + flops += upsample.flops(resolution, resolution) + decoder_layer.flops() + # Output Projection + flops += self.output_proj.flops(self.reso, self.reso) + return flops + + +class UFormerModel(nn.Module): + """U-Former model.""" + + def __init__( + self, + patch_size: int = 256, + in_channels: int = 2, + out_channels: Optional[int] = None, + embedding_dim: int = 32, + encoder_depths: tuple[int, ...] = (2, 2, 2, 2), + encoder_num_heads: tuple[int, ...] = (1, 2, 4, 8), + bottleneck_depth: int = 2, + bottleneck_num_heads: int = 16, + win_size: int = 8, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.1, + patch_norm: bool = True, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.linear, + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.leff, + shift_flag: bool = True, + modulator: bool = False, + cross_modulator: bool = False, + normalized: bool = True, + ): + """Inits :class:`UFormer`. + + Parameters + ---------- + patch_size : int + Size of the patch. Default: 256. + in_channels : int + Number of input channels. Default: 2. + out_channels : int, optional + Number of output channels. Default: None. + embedding_dim : int + Size of the feature embedding. Default: 32. + encoder_depths : tuple + Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2). + encoder_num_heads : tuple + Number of attention heads for each layer of the encoder of the U-former, from top to bottom. + Default: (1, 2, 4, 8). + bottleneck_depth : int + Default: 16. + bottleneck_num_heads : int + Default: 2. + win_size : int + Window size for the attention mechanism. Default: 8. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. 
+ qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float + Scale factor for the query and key projection vectors. + If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None. + drop_rate : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop_rate : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path_rate : float + Dropout rate for the stochastic depth regularization. Default: 0.1. + patch_norm : bool + Whether to use normalization for the patch embeddings. Default: True. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.linear. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.leff. + shift_flag : bool + Whether to use shift operation in the local attention mechanism. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + normalized : bool + Whether to apply normalization before and denormalization after the forward pass. Default: True. + **kwargs: Other keyword arguments to pass to the parent constructor. + """ + super().__init__() + + self.uformer = UFormer( + patch_size, + in_channels, + out_channels, + embedding_dim, + encoder_depths, + encoder_num_heads, + bottleneck_depth, + bottleneck_num_heads, + win_size, + mlp_ratio, + qkv_bias, + qk_scale, + drop_rate, + attn_drop_rate, + drop_path_rate, + patch_norm, + token_projection, + token_mlp, + shift_flag, + modulator, + cross_modulator, + ) + + self.normalized = normalized + + self.padding_factor = win_size * (2 ** len(encoder_depths)) + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Performs forward pass of :class:`UFormer`. + + Parameters + ---------- + x : torch.Tensor + mask : torch.Tensor, optional + + Returns + ------- + torch.Tensor + """ + + x, _, wpad, hpad = pad_to_square(x, self.padding_factor) + if self.normalized: + x, mean, std = norm(x) + + x = self.uformer(x, mask) + + if self.normalized: + x = unnorm(x, mean, std) + x = unpad(x, wpad, hpad) + + return x + + def flops(self): + return self.uformer.flops() diff --git a/direct/nn/transformers/utils.py b/direct/nn/transformers/utils.py new file mode 100644 index 00000000..897b8f61 --- /dev/null +++ b/direct/nn/transformers/utils.py @@ -0,0 +1,206 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from __future__ import annotations + +from math import ceil, floor + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + +__all__ = ["init_weights", "norm", "pad", "pad_to_square", "unnorm", "unpad", "DropoutPath"] + + +def pad(x: torch.Tensor, pad_size: tuple[int, int]) -> tuple[torch.Tensor, tuple[int, int], tuple[int, int]]: + """Pad the input tensor with zeros to make its spatial dimensions divisible by the pad size. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (\*, height, width). + pad_size : tuple[int, int] + Patch size to make dimensions divisible with, as a tuple of integers (pad_height, pad_width). 
+ + Returns + ------- + tuple containing the padded tensor, and the number of pixels padded in the width and height dimensions respectively. + """ + h, w = x.shape[-2:] + hp, wp = pad_size + f1 = ((wp - w % wp) % wp) / 2 + f2 = ((hp - h % hp) % hp) / 2 + wpad = (floor(f1), ceil(f1)) + hpad = (floor(f2), ceil(f2)) + x = F.pad(x, wpad + hpad) + + return x, wpad, hpad + + +def pad_to_square( + inp: torch.Tensor, factor: float +) -> tuple[torch.Tensor, torch.Tensor, tuple[int, int], tuple[int, int]]: + """Pad a tensor to a square shape with a given factor. + + Parameters + ---------- + inp : torch.Tensor + The input tensor to pad to square shape. Expected shape is (\*, height, width). + factor : float + The factor to which the input tensor will be padded. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor, tuple[int, int], tuple[int, int]] + A tuple of two tensors, the first is the input tensor padded to a square shape, and the + second is the corresponding mask for the padded tensor. + + Examples + -------- + 1. + >>> x = torch.rand(1, 3, 224, 192) + >>> padded_x, mask, wpad, hpad = pad_to_square(x, factor=16.0) + >>> padded_x.shape, mask.shape + (torch.Size([1, 3, 224, 224]), torch.Size([1, 1, 224, 224])) + 2. + >>> x = torch.rand(3, 13, 2, 234, 180) + >>> padded_x, mask, wpad, hpad = pad_to_square(x, factor=16.0) + >>> padded_x.shape, wpad, hpad + (torch.Size([3, 13, 2, 240, 240]), (30, 30), (3, 3)) + """ + channels, h, w = inp.shape[-3:] + + # Calculate the maximum size and pad to the next multiple of the factor + x = int(ceil(max(h, w) / float(factor)) * factor) + + # Create a tensor of zeros with the maximum size and copy the input tensor into the center + img = torch.zeros(*inp.shape[:-3], channels, x, x, device=inp.device).type_as(inp) + mask = torch.zeros(*((1,) * (img.ndim - 3)), 1, x, x, device=inp.device).type_as(inp) + + # Compute the offset and copy the input tensor into the center of the zero tensor + offset_h = (x - h) // 2 + offset_w = (x - w) // 2 + hpad = (offset_h, offset_h + h) + wpad = (offset_w, offset_w + w) + img[..., hpad[0] : hpad[1], wpad[0] : wpad[1]] = inp.clone() + mask[..., hpad[0] : hpad[1], wpad[0] : wpad[1]].fill_(1.0) + # Return the padded tensor and the corresponding mask, and padding in spatial dimensions + return ( + img, + 1 - mask, + (wpad[0], wpad[1] - w + (1 if w % 2 != 0 else 0)), + (hpad[0], hpad[1] - h + (1 if h % 2 != 0 else 0)), + ) + + +def unpad(x: torch.Tensor, wpad: tuple[int, int], hpad: tuple[int, int]) -> torch.Tensor: + """Remove the padding added to the input tensor by _pad method. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, C, H_pad, W_pad). + wpad : tuple[int, int] + Number of pixels padded in the width dimension as a tuple of integers (left_pad, right_pad). + hpad : tuple[int, int] + Number of pixels padded in the height dimension as a tuple of integers (top_pad, bottom_pad). + + Returns + ------- + Tensor with the same shape as the original input tensor, but without the added padding. + """ + return x[..., hpad[0] : x.shape[-2] - hpad[1], wpad[0] : x.shape[-1] - wpad[1]] + + +def norm(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Normalize the input tensor by subtracting the mean and dividing by the standard deviation across each channel and pixel. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, C, H, W). + + Returns + ------- + tuple containing the normalized tensor, mean tensor and standard deviation tensor. 
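+
+ Examples
+ --------
+ Shapes for a per-sample normalization (statistics are taken over all channels and pixels of
+ each batch element; values are random):
+
+ >>> x = torch.rand(2, 2, 64, 64)
+ >>> x_norm, mean, std = norm(x)
+ >>> x_norm.shape, mean.shape, std.shape
+ (torch.Size([2, 2, 64, 64]), torch.Size([2, 1, 1, 1]), torch.Size([2, 1, 1, 1]))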
+ """ + mean = x.reshape(x.shape[0], 1, 1, -1).mean(-1, keepdim=True) + std = x.reshape(x.shape[0], 1, 1, -1).std(-1, keepdim=True) + x = (x - mean) / std + + return x, mean, std + + +def unnorm(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor: + """Denormalize the input tensor by multiplying with the standard deviation and adding + the mean across each channel and pixel. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, C, H, W). + mean : torch.Tensor + Mean tensor obtained during normalization. + std : torch.Tensor + Standard deviation tensor obtained during normalization. + + Returns + ------- + Tensor with the same shape as the original input tensor, but denormalized. + """ + return x * std + mean + + +def init_weights(m: nn.Module) -> None: + """Initializes the weights of the network using a truncated normal distribution. + + Parameters + ---------- + m : nn.Module + A module of the network whose weights need to be initialized. + """ + + if isinstance(m, nn.Linear): + init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + init.constant_(m.bias, 0) + init.constant_(m.weight, 1.0) + + +class DropoutPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + """Inits :class:`DropoutPath`. + + Parameters + ---------- + drop_prob : float + Probability of dropping a residual connection. Default: 0.0. + scale_by_keep : bool + Whether to scale the remaining activations by 1 / (1 - drop_prob) to maintain the expected value of + the activations. Default: True. + """ + super(DropoutPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + @staticmethod + def _dropout_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + def forward(self, x): + return self._dropout_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"dropout_prob={round(self.drop_prob, 3):0.3f}" diff --git a/direct/nn/transformers/vision_transformers.py b/direct/nn/transformers/vision_transformers.py new file mode 100644 index 00000000..696f0336 --- /dev/null +++ b/direct/nn/transformers/vision_transformers.py @@ -0,0 +1,701 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +# Code borrowed from https://github.com/facebookresearch/convit which uses code from +# timm: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + +from direct.nn.transformers.utils import DropoutPath, init_weights, norm, pad, unnorm, unpad + +__all__ = ["VisionTransformer", "VisionTransformerModel"] + + +class MLP(nn.Module): + """MLP layer with dropout and activation. + + + Parameters + ---------- + in_features : int + Size of the input feature. + hidden_features : int, optional + Size of the hidden layer feature. 
If None, then hidden_features = in_features. (Default: None) + out_features : int, optional + Size of the output feature. If None, then out_features = in_features. (Default: None) + act_layer : nn.Module, optional + Activation layer to be used. (Default: nn.GELU) + drop : float, optional + Dropout probability. (Default: 0.) + + """ + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.apply(init_weights) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`MLP`. + + Parameters + ---------- + x : torch.Tensor + Input tensor to the network. + + Returns + ------- + torch.Tensor + Output tensor of the network. + + """ + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class GPSA(nn.Module): + """Gated Positional Self-Attention module.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + locality_strength: float = 1.0, + use_local_init: bool = True, + grid_size=None, + ): + """Inits :class:`GPSA`. + + Parameters + ---------- + dim : int + Dimensionality of the input embeddings. + num_heads : int + Number of attention heads. + qkv_bias : bool + If True, include bias terms in the query, key, and value projections. + qk_scale : float + Scale factor for query and key. + attn_drop : float + Dropout probability for attention weights. + proj_drop : float + Dropout probability for output tensor. + locality_strength : float + Strength of locality assumption in initialization. + use_local_init : bool + If True, use the locality-based initialization. + grid_size : tuple[int,int], optional + The size of the grid (height, width) for relative position encoding. + """ + super().__init__() + self.num_heads = num_heads + self.dim = dim + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.k = nn.Linear(dim, dim, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.pos_proj = nn.Linear(3, num_heads) + self.proj_drop = nn.Dropout(proj_drop) + self.locality_strength = locality_strength + self.gating_param = nn.Parameter(torch.ones(self.num_heads)) + self.apply(init_weights) + if use_local_init: + self.local_init(locality_strength=locality_strength) + self.current_grid_size = grid_size + + def get_attention(self, x: torch.Tensor) -> torch.Tensor: + """Compute the attention scores for each patch in x. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, N, C). + + Returns + ------- + torch.Tensor + Attention scores for each patch in x. 
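+
+ Examples
+ --------
+ Illustrative shapes for an 8 x 8 token grid; the relative-position cache is filled
+ explicitly here before calling this method (values are random):
+
+ >>> gpsa = GPSA(dim=32, num_heads=4, grid_size=(8, 8))
+ >>> gpsa.rel_indices = gpsa.get_rel_indices()
+ >>> gpsa.get_attention(torch.rand(2, 8 * 8, 32)).shape
+ torch.Size([2, 4, 64, 64])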
+ """ + B, N, C = x.shape + + k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + pos_score = self.pos_proj(self.rel_indices).expand(B, -1, -1, -1).permute(0, 3, 1, 2) + patch_score = (q @ k.transpose(-2, -1)) * self.scale + patch_score = patch_score.softmax(dim=-1) + pos_score = pos_score.softmax(dim=-1) + + gating = self.gating_param.view(1, -1, 1, 1) + attn = (1.0 - torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score + attn = attn / attn.sum(dim=-1).unsqueeze(-1) + attn = self.attn_drop(attn) + return attn + + def get_attention_map(self, x: torch.Tensor, return_map: Optional[bool] = False): + """Compute the attention map for the input tensor x. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, N, C). + return_map : bool, optional + Whether to return the attention map. Default: False. + + Returns + ------- + torch.Tensor + A scalar value representing the average attention distance between patches in the input tensor x. + If `return_map` is True, the method also returns the attention map tensor. + """ + + attn_map = self.get_attention(x).mean(0) # average over batch + distances = self.rel_indices.squeeze()[:, :, -1] ** 0.5 + dist = torch.einsum("nm,hnm->h", (distances, attn_map)) + dist /= distances.size(0) + if return_map: + return dist, attn_map + else: + return dist + + def local_init(self, locality_strength: Optional[float] = 1.0) -> None: + """Initializes the parameters for a locally connected attention mechanism. + + Parameters + ---------- + locality_strength : float, optional + A scalar multiplier for the locality distance. Default: 1.0. + + Returns + ------- + None + """ + self.v.weight.data.copy_(torch.eye(self.dim)) + locality_distance = 1 # max(1,1/locality_strength**.5) + + kernel_size = int(self.num_heads**0.5) + center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2 + + # compute the positional projection weights with locality distance + for h1 in range(kernel_size): + for h2 in range(kernel_size): + position = h1 + kernel_size * h2 + self.pos_proj.weight.data[position, 2] = -1 + self.pos_proj.weight.data[position, 1] = 2 * (h1 - center) * locality_distance + self.pos_proj.weight.data[position, 0] = 2 * (h2 - center) * locality_distance + self.pos_proj.weight.data *= locality_strength + + def get_rel_indices(self) -> None: + """Generates relative positional indices for each patch in the input. + + Returns + ------- + None + """ + H, W = self.current_grid_size + N = H * W + rel_indices = torch.zeros(1, N, N, 3) + indx = torch.arange(W).view(1, -1) - torch.arange(W).view(-1, 1) + indx = indx.repeat(H, H) + indy = torch.arange(H).view(1, -1) - torch.arange(H).view(-1, 1) + indy = indy.repeat_interleave(W, dim=0).repeat_interleave(W, dim=1) + indd = indx**2 + indy**2 + rel_indices[:, :, :, 2] = indd.unsqueeze(0) + rel_indices[:, :, :, 1] = indy.unsqueeze(0) + rel_indices[:, :, :, 0] = indx.unsqueeze(0) + + return rel_indices.to(self.v.weight.device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`GPSA`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. 
+ + Returns + ------- + torch.Tensor: + """ + B, N, C = x.shape + if not hasattr(self, "rel_indices") or self.rel_indices.size(1) != N: + self.get_rel_indices() + + attn = self.get_attention(x) + v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MHSA(nn.Module): + """Multi-Head Self-Attention (MHSA) module.""" + + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, grid_size=None): + """Inits :class:`MHSA`. + + Parameters + ---------- + dim : int + Number of input features. + num_heads : int + Number of heads in the attention mechanism. Default is 8. + qkv_bias : bool + If True, bias is added to the query, key and value projections. Default is False. + qk_scale : float or None + Scaling factor for the query-key dot product. If None, it is set to + head_dim ** -0.5 where head_dim = dim // num_heads. Default is None. + attn_drop : float + Dropout rate for the attention weights. Default is 0. + proj_drop : float + Dropout rate for the output of the module. Default is 0. + grid_size : Tuple[int, int] or None + If not None, the module is designed to work with a grid of + patches. grid_size is a tuple of the form (H, W) where H and W are the number of patches in + the vertical and horizontal directions respectively. Default is None. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.apply(init_weights) + self.current_grid_size = grid_size + + def get_attention_map(self, x, return_map=False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Compute the attention map of the input tensor. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, N, C) where B is the batch size, N is the number of + patches and C is the number of input features. + return_map : bool + If True, return the attention map along with the distance. Default is False. + + Returns + ------- + A torch.Tensor of shape (num_heads,) containing the distances between patches if return_map is + False. Otherwise, return a tuple containing the distance and the attention map. The attention + map is a torch.Tensor of shape (num_heads, N, N). 
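+
+ Examples
+ --------
+ Illustrative shapes for an 8 x 8 token grid; the returned vector holds one average attention
+ distance per head (values are random):
+
+ >>> mhsa = MHSA(dim=32, num_heads=4, grid_size=(8, 8))
+ >>> mhsa.get_attention_map(torch.rand(2, 8 * 8, 32)).shape
+ torch.Size([4])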
+ """ + rel_indices = self.get_rel_indices() + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + attn_map = (q @ k.transpose(-2, -1)) * self.scale + attn_map = attn_map.softmax(dim=-1).mean(0) # average over batch + distances = rel_indices.squeeze()[:, :, -1] ** 0.5 + dist = torch.einsum("nm,hnm->h", (distances, attn_map)) + dist /= distances.size(0) + if return_map: + return dist, attn_map + else: + return dist + + def get_rel_indices(self) -> torch.Tensor: + H, W = self.current_grid_size + N = H * W + rel_indices = torch.zeros(1, N, N, 3) + indx = torch.arange(W).view(1, -1) - torch.arange(W).view(-1, 1) + indx = indx.repeat(H, H) + indy = torch.arange(H).view(1, -1) - torch.arange(H).view(-1, 1) + indy = indy.repeat_interleave(W, dim=0).repeat_interleave(W, dim=1) + indd = indx**2 + indy**2 + rel_indices[:, :, :, 2] = indd.unsqueeze(0) + rel_indices[:, :, :, 1] = indy.unsqueeze(0) + rel_indices[:, :, :, 0] = indx.unsqueeze(0) + + return rel_indices.to(self.qkv.weight.device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class VisionTransformerBlock(nn.Module): + """A single transformer block used in the VisionTransformer model.""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: float = None, + drop: float = 0.0, + attn_drop: float = 0.0, + dropout_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + use_gpsa: bool = True, + **kwargs, + ): + """Inits :class:`VisionTransformerBlock`. + + Parameters + ---------- + dim : int + The feature dimension. + num_heads : int + The number of attention heads. + mlp_ratio : float, optional + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool, optional + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float, optional + The scale factor for the query-key dot product. Default: None. + drop : float, optional + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop : float, optional + The dropout probability for the attention layer. Default: 0.0. + dropout_path : float, optional + The dropout probability for the dropout path. Default: 0.0. + act_layer : nn.Module, optional + The activation layer used in the MLP. Default: nn.GELU. + norm_layer : nn.Module, optional + The normalization layer used in the block. Default: nn.LayerNorm. + use_gpsa : bool, optional + Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True. + **kwargs: Additional arguments for the attention layer. 
+ """ + super().__init__() + self.norm1 = norm_layer(dim) + self.use_gpsa = use_gpsa + if self.use_gpsa: + self.attn = GPSA( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + **kwargs, + ) + else: + self.attn = MHSA( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + **kwargs, + ) + self.dropout_path = DropoutPath(dropout_path) if dropout_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x: torch.Tensor, grid_size: Tuple[int, int]) -> torch.Tensor: + """Forward pass for the :class:`VisionTransformerBlock`. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + grid_size : Tuple[int, int] + The size of the grid used by the attention layer. + + Returns + ------- + torch.Tensor: The output tensor. + """ + self.attn.current_grid_size = grid_size + x = x + self.dropout_path(self.attn(self.norm1(x))) + x = x + self.dropout_path(self.mlp(self.norm2(x))) + + return x + + +class PatchEmbedding(nn.Module): + """Image to Patch Embedding.""" + + def __init__(self, patch_size, in_channels, embedding_dim): + """Inits :class:`PatchEmbedding`. + + Parameters + ---------- + patch_size : int or Tuple[int, int] + The patch size. If an int is provided, the patch will be a square. + in_channels : int + Number of input channels. + embedding_dim : int + Dimension of the output embedding. + """ + super().__init__() + self.proj = nn.Conv2d(in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size) + self.apply(init_weights) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`PatchEmbedding`. + + Parameters + ---------- + x : torch.Tensor + + Returns + ------- + torch.Tensor + Patch embedding. 
+ """ + x = self.proj(x) + return x + + +class VisionTransformer(nn.Module): + """Vision Transformer""" + + def __init__( + self, + average_img_size: Union[int, Tuple[int, int]] = 320, + patch_size: Union[int, Tuple[int, int]] = 10, + in_channels: int = 1, + out_channels: int = None, + embedding_dim: int = 64, + depth: int = 8, + num_heads: int = 9, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: float = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + gpsa_interval: Tuple[int, int] = (-1, -1), + locality_strength: float = 1.0, + use_pos_embedding: bool = True, + ): + super().__init__() + + self.depth = depth + embedding_dim *= num_heads + self.num_features = embedding_dim # num_features for consistency with other models + self.locality_strength = locality_strength + self.use_pos_embedding = use_pos_embedding + + if isinstance(average_img_size, int): + img_size = (average_img_size, average_img_size) + else: + img_size = average_img_size + + if isinstance(patch_size, int): + self.patch_size = (patch_size, patch_size) + else: + self.patch_size = patch_size + + self.in_channels = in_channels + self.out_channels = out_channels if out_channels else in_channels + + self.patch_embed = PatchEmbedding( + patch_size=self.patch_size, in_channels=in_channels, embedding_dim=embedding_dim + ) + + self.pos_drop = nn.Dropout(p=drop_rate) + + if self.use_pos_embedding: + self.pos_embed = nn.Parameter( + torch.zeros(1, embedding_dim, img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]) + ) + + init.trunc_normal_(self.pos_embed, std=0.02) + + dpr = [x.item() for x in torch.linspace(0, dropout_path_rate, depth)] # stochastic depth decay rule + + self.blocks = nn.ModuleList( + [ + VisionTransformerBlock( + dim=embedding_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + dropout_path=dpr[i], + norm_layer=norm_layer, + use_gpsa=gpsa_interval[0] - 1 <= i < gpsa_interval[1], + **( + {"locality_strength": locality_strength} + if gpsa_interval[0] - 1 <= i < gpsa_interval[1] + else {} + ), + ) + for i in range(depth) + ] + ) + + self.norm = norm_layer(embedding_dim) + # head + self.feature_info = [dict(num_chs=embedding_dim, reduction=0, module="head")] + self.head = nn.Linear(self.num_features, self.out_channels * self.patch_size[0] * self.patch_size[1]) + + self.head.apply(init_weights) + + def seq2img(self, x: torch.Tensor, img_size: Tuple[int, ...]): + x = x.view(x.shape[0], x.shape[1], self.out_channels, self.patch_size[0], self.patch_size[1]) + x = x.chunk(x.shape[1], dim=1) + x = torch.cat(x, dim=4).permute(0, 1, 2, 4, 3) + x = x.chunk(img_size[0] // self.patch_size[0], dim=3) + x = torch.cat(x, dim=4).permute(0, 1, 2, 4, 3).squeeze(1) + + return x + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed"} + + def get_head(self) -> nn.Module: + return self.head + + def reset_head(self) -> None: + self.head = nn.Linear(self.num_features, self.out_channels * self.patch_size[0] * self.patch_size[1]) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + _, _, H, W = x.shape + + if self.use_pos_embedding: + pos_embed = F.interpolate(self.pos_embed, size=[H, W], mode="bilinear", align_corners=False) + x = x + pos_embed + + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + for _, block in enumerate(self.blocks): + x = block(x, (H, W)) + + x = 
self.norm(x) + + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`VisionTransformer`. + + Parameters + ---------- + x : torch.Tensor + + Returns + ------- + torch.Tensor + """ + _, _, H, W = x.shape + x = self.forward_features(x) + x = self.head(x) + x = self.seq2img(x, (H, W)) + + return x + + +class VisionTransformerModel(VisionTransformer): + def __init__( + self, + average_img_size: Union[int, Tuple[int, int]] = 320, + patch_size: Union[int, Tuple[int, int]] = 10, + in_channels: int = 1, + out_channels: int = None, + embedding_dim: int = 64, + depth: int = 8, + num_heads: int = 9, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: float = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + gpsa_interval: Tuple[int, int] = (-1, -1), + locality_strength: float = 1.0, + use_pos_embedding: bool = True, + normalized: bool = True, + ): + super().__init__( + average_img_size, + patch_size, + in_channels, + out_channels, + embedding_dim, + depth, + num_heads, + mlp_ratio, + qkv_bias, + qk_scale, + drop_rate, + attn_drop_rate, + dropout_path_rate, + norm_layer, + gpsa_interval, + locality_strength, + use_pos_embedding, + ) + self.normalized = normalized + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`VisionTransformerModel`. + + Parameters + ---------- + x : torch.Tensor + + Returns + ------- + torch.Tensor + """ + _, _, H, W = x.shape + x, wpad, hpad = pad(x, self.patch_size) + + if self.normalized: + x, mean, std = norm(x) + + x = self.forward_features(x) + x = self.head(x) + x = self.seq2img(x, (H, W)) + + if self.normalized: + x = unnorm(x, mean, std) + + x = unpad(x, wpad, hpad) + + return x diff --git a/direct/nn/types.py b/direct/nn/types.py index 8eaf0d90..569ab1d7 100644 --- a/direct/nn/types.py +++ b/direct/nn/types.py @@ -15,6 +15,8 @@ class ModelName(DirectEnum): normunet = "normunet" resnet = "resnet" didn = "didn" + uformer = "uformer" + vision_transformer = "vision_transformer" conv = "conv" diff --git a/direct/nn/unet/config.py b/direct/nn/unet/config.py index 19ccdc95..e7309d84 100644 --- a/direct/nn/unet/config.py +++ b/direct/nn/unet/config.py @@ -12,6 +12,7 @@ class UnetModel2dConfig(ModelConfig): num_filters: int = 16 num_pool_layers: int = 4 dropout_probability: float = 0.0 + cwn_conv: bool = False class NormUnetModel2dConfig(ModelConfig): @@ -21,6 +22,27 @@ class NormUnetModel2dConfig(ModelConfig): num_pool_layers: int = 4 dropout_probability: float = 0.0 norm_groups: int = 2 + cwn_conv: bool = False + + +@dataclass +class UnetModel3dConfig(ModelConfig): + in_channels: int = 2 + out_channels: int = 2 + num_filters: int = 16 + num_pool_layers: int = 4 + dropout_probability: float = 0.0 + cwn_conv: bool = False + + +class NormUnetModel3dConfig(ModelConfig): + in_channels: int = 2 + out_channels: int = 2 + num_filters: int = 16 + num_pool_layers: int = 4 + dropout_probability: float = 0.0 + norm_groups: int = 2 + cwn_conv: bool = False @dataclass @@ -28,6 +50,7 @@ class Unet2dConfig(ModelConfig): num_filters: int = 16 num_pool_layers: int = 4 dropout_probability: float = 0.0 + cwn_conv: bool = False skip_connection: bool = False normalized: bool = False image_initialization: str = "zero_filled" diff --git a/direct/nn/unet/unet_2d.py b/direct/nn/unet/unet_2d.py index 5fd9853c..6f7bd69b 100644 --- a/direct/nn/unet/unet_2d.py +++ b/direct/nn/unet/unet_2d.py @@ -10,6 +10,7 @@ from 
torch.nn import functional as F from direct.data import transforms as T +from direct.nn.conv.conv import CWN_Conv2d, CWN_ConvTranspose2d class ConvBlock(nn.Module): @@ -113,6 +114,107 @@ def __repr__(self): return f"ConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels})" +class CWNConvBlock(nn.Module): + """U-Net convolutional block. + + It consists of two convolution layers each followed by instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_channels: int, out_channels: int, dropout_probability: float): + """Inits ConvBlock. + + Parameters + ---------- + in_channels: int + Number of input channels. + out_channels: int + Number of output channels. + dropout_probability: float + Dropout probability. + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.dropout_probability = dropout_probability + + self.layers = nn.Sequential( + CWN_Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_channels), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(dropout_probability), + CWN_Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_channels), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(dropout_probability), + ) + + def forward(self, input_data: torch.Tensor) -> torch.Tensor: + """Performs the forward pass of :class:`ConvBlock`. + + Parameters + ---------- + input_data: torch.Tensor + + Returns + ------- + torch.Tensor + """ + return self.layers(input_data) + + def __repr__(self): + """Representation of :class:`ConvBlock`.""" + return ( + f"CWNConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels}, " + f"dropout_probability={self.dropout_probability})" + ) + + +class CWNTransposeConvBlock(nn.Module): + """U-Net Transpose Convolutional Block. + + It consists of one convolution transpose layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_channels: int, out_channels: int): + """Inits :class:`TransposeConvBlock`. + + Parameters + ---------- + in_channels: int + Number of input channels. + out_channels: int + Number of output channels. + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + self.layers = nn.Sequential( + CWN_ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_channels), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, input_data: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`TransposeConvBlock`. + + Parameters + ---------- + input_data: torch.Tensor + + Returns + ------- + torch.Tensor + """ + return self.layers(input_data) + + def __repr__(self): + """Representation of "class:`TransposeConvBlock`.""" + return f"CWNConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels})" + + class UnetModel2d(nn.Module): """PyTorch implementation of a U-Net model based on [1]_. @@ -129,6 +231,7 @@ def __init__( num_filters: int, num_pool_layers: int, dropout_probability: float, + cwn_conv: bool = False, ): """Inits :class:`UnetModel2d`. @@ -144,6 +247,8 @@ def __init__( Number of down-sampling and up-sampling layers (depth). dropout_probability: float Dropout probability. + cwn_conv : bool + Apply centered weigh normalization to convolutions. Default: False. 
""" super().__init__() @@ -153,25 +258,32 @@ def __init__( self.num_pool_layers = num_pool_layers self.dropout_probability = dropout_probability - self.down_sample_layers = nn.ModuleList([ConvBlock(in_channels, num_filters, dropout_probability)]) + if cwn_conv: + conv_block = CWNConvBlock + transpose_conv_block = CWNTransposeConvBlock + else: + conv_block = ConvBlock + transpose_conv_block = TransposeConvBlock + + self.down_sample_layers = nn.ModuleList([conv_block(in_channels, num_filters, dropout_probability)]) ch = num_filters for _ in range(num_pool_layers - 1): - self.down_sample_layers += [ConvBlock(ch, ch * 2, dropout_probability)] + self.down_sample_layers += [conv_block(ch, ch * 2, dropout_probability)] ch *= 2 - self.conv = ConvBlock(ch, ch * 2, dropout_probability) + self.conv = conv_block(ch, ch * 2, dropout_probability) self.up_conv = nn.ModuleList() self.up_transpose_conv = nn.ModuleList() for _ in range(num_pool_layers - 1): - self.up_transpose_conv += [TransposeConvBlock(ch * 2, ch)] - self.up_conv += [ConvBlock(ch * 2, ch, dropout_probability)] + self.up_transpose_conv += [transpose_conv_block(ch * 2, ch)] + self.up_conv += [conv_block(ch * 2, ch, dropout_probability)] ch //= 2 - self.up_transpose_conv += [TransposeConvBlock(ch * 2, ch)] + self.up_transpose_conv += [transpose_conv_block(ch * 2, ch)] self.up_conv += [ nn.Sequential( - ConvBlock(ch * 2, ch, dropout_probability), - nn.Conv2d(ch, self.out_channels, kernel_size=1, stride=1), + conv_block(ch * 2, ch, dropout_probability), + (CWN_Conv2d if cwn_conv else nn.Conv2d)(ch, self.out_channels, kernel_size=1, stride=1), ) ] @@ -228,6 +340,7 @@ def __init__( num_pool_layers: int, dropout_probability: float, norm_groups: int = 2, + cwn_conv: bool = False, ): """Inits :class:`NormUnetModel2d`. @@ -243,6 +356,8 @@ def __init__( Number of down-sampling and up-sampling layers (depth). dropout_probability: float Dropout probability. + cwn_conv : bool + Apply centered weigh normalization to convolutions. Default: False. norm_groups: int, Number of normalization groups. """ @@ -254,6 +369,7 @@ def __init__( num_filters=num_filters, num_pool_layers=num_pool_layers, dropout_probability=dropout_probability, + cwn_conv=cwn_conv, ) self.norm_groups = norm_groups @@ -332,6 +448,7 @@ def __init__( num_filters: int, num_pool_layers: int, dropout_probability: float, + cwn_conv: bool = False, skip_connection: bool = False, normalized: bool = False, image_initialization: str = "zero_filled", @@ -351,6 +468,8 @@ def __init__( Number of pooling layers. dropout_probability: float Dropout probability. + cwn_conv : bool + Apply centered weigh normalization to convolutions. Default: False. skip_connection: bool If True, skip connection is used for the output. Default: False. 
normalized: bool @@ -375,6 +494,7 @@ def __init__( num_filters=num_filters, num_pool_layers=num_pool_layers, dropout_probability=dropout_probability, + cwn_conv=cwn_conv, ) else: self.unet = UnetModel2d( @@ -383,6 +503,7 @@ def __init__( num_filters=num_filters, num_pool_layers=num_pool_layers, dropout_probability=dropout_probability, + cwn_conv=cwn_conv, ) self.forward_operator = forward_operator self.backward_operator = backward_operator diff --git a/direct/nn/unet/unet_3d.py b/direct/nn/unet/unet_3d.py new file mode 100644 index 00000000..1a7bc435 --- /dev/null +++ b/direct/nn/unet/unet_3d.py @@ -0,0 +1,329 @@ +import math +from typing import List, Tuple + +import torch +from torch import nn +from torch.nn import functional as F + +from direct.nn.conv.conv import CWN_Conv3d, CWN_ConvTranspose3d + + +class ConvBlock3D(nn.Module): + """3D U-Net convolutional block.""" + + def __init__(self, in_channels: int, out_channels: int, dropout_probability: float): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.dropout_probability = dropout_probability + + self.layers = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm3d(out_channels), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout3d(dropout_probability), + nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm3d(out_channels), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout3d(dropout_probability), + ) + + def forward(self, input_data: torch.Tensor) -> torch.Tensor: + return self.layers(input_data) + + +class TransposeConvBlock3D(nn.Module): + """3D U-Net Transpose Convolutional Block.""" + + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + self.layers = nn.Sequential( + nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm3d(out_channels), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, input_data: torch.Tensor) -> torch.Tensor: + return self.layers(input_data) + + +class CWNConvBlock3D(nn.Module): + """U-Net convolutional block for 3D data.""" + + def __init__(self, in_channels: int, out_channels: int, dropout_probability: float): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.dropout_probability = dropout_probability + + self.layers = nn.Sequential( + CWN_Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm3d(out_channels), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout3d(dropout_probability), + CWN_Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm3d(out_channels), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout3d(dropout_probability), + ) + + def forward(self, input_data: torch.Tensor) -> torch.Tensor: + return self.layers(input_data) + + def __repr__(self): + return ( + f"CWNConvBlock3D(in_channels={self.in_channels}, out_channels={self.out_channels}, " + f"dropout_probability={self.dropout_probability})" + ) + + +class CWNTransposeConvBlock3D(nn.Module): + """U-Net Transpose Convolutional Block for 3D data.""" + + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + self.layers = nn.Sequential( + CWN_ConvTranspose3d(in_channels, 
out_channels, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm3d(out_channels), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, input_data: torch.Tensor) -> torch.Tensor: + return self.layers(input_data) + + +class UnetModel3d(nn.Module): + """PyTorch implementation of a 3D U-Net model.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + num_filters: int, + num_pool_layers: int, + dropout_probability: float, + cwn_conv: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.num_filters = num_filters + self.num_pool_layers = num_pool_layers + self.dropout_probability = dropout_probability + + if cwn_conv: + conv_block = CWNConvBlock3D + transpose_conv_block = CWNTransposeConvBlock3D + else: + conv_block = ConvBlock3D + transpose_conv_block = TransposeConvBlock3D + + self.down_sample_layers = nn.ModuleList([conv_block(in_channels, num_filters, dropout_probability)]) + ch = num_filters + for _ in range(num_pool_layers - 1): + self.down_sample_layers += [conv_block(ch, ch * 2, dropout_probability)] + ch *= 2 + self.conv = conv_block(ch, ch * 2, dropout_probability) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv += [transpose_conv_block(ch * 2, ch)] + self.up_conv += [conv_block(ch * 2, ch, dropout_probability)] + ch //= 2 + + self.up_transpose_conv += [transpose_conv_block(ch * 2, ch)] + self.up_conv += [ + nn.Sequential( + conv_block(ch * 2, ch, dropout_probability), + nn.Conv3d(ch, out_channels, kernel_size=1, stride=1), + ) + ] + + def forward(self, input_data: torch.Tensor) -> torch.Tensor: + stack = [] + output, inp_pad = pad_to_pow_of_2(input_data, self.num_pool_layers) + + # Apply down-sampling layers + for _, layer in enumerate(self.down_sample_layers): + output = layer(output) + stack.append(output) + output = F.avg_pool3d(output, kernel_size=2, stride=2, padding=0) + + output = self.conv(output) + + # Apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + padding = [0, 0, 0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 + if output.shape[-3] != downsample_layer.shape[-3]: + padding[5] = 1 + if sum(padding) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + if sum(inp_pad) != 0: + output = output[ + :, + :, + inp_pad[4] : output.shape[2] - inp_pad[5], + inp_pad[2] : output.shape[3] - inp_pad[3], + inp_pad[0] : output.shape[4] - inp_pad[1], + ] + + return output + + +class NormUnetModel3d(nn.Module): + """Implementation of a Normalized U-Net model for 3D data.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + num_filters: int, + num_pool_layers: int, + dropout_probability: float, + norm_groups: int = 2, + cwn_conv: bool = False, + ): + """Inits :class:`NormUnetModel3D`. + + Parameters + ---------- + in_channels: int + Number of input channels to the U-Net. + out_channels: int + Number of output channels to the U-Net. + num_filters: int + Number of output channels of the first convolutional layer. + num_pool_layers: int + Number of down-sampling and up-sampling layers (depth). + dropout_probability: float + Dropout probability. 
+ norm_groups: int, + Number of normalization groups. + cwn_conv : bool + Apply centered weight normalization to convolutions. Default: False. + """ + super().__init__() + + self.unet3d = UnetModel3d( + in_channels=in_channels, + out_channels=out_channels, + num_filters=num_filters, + num_pool_layers=num_pool_layers, + dropout_probability=dropout_probability, + cwn_conv=cwn_conv, + ) + + self.norm_groups = norm_groups + + @staticmethod + def norm(input_data: torch.Tensor, groups: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Performs group normalization.""" + # Group norm + b, c, z, h, w = input_data.shape + input_data = input_data.reshape(b, groups, -1) + + mean = input_data.mean(-1, keepdim=True) + std = input_data.std(-1, keepdim=True) + + output = (input_data - mean) / std + output = output.reshape(b, c, z, h, w) + + return output, mean, std + + @staticmethod + def unnorm( + input_data: torch.Tensor, + mean: torch.Tensor, + std: torch.Tensor, + groups: int, + ) -> torch.Tensor: + b, c, z, h, w = input_data.shape + input_data = input_data.reshape(b, groups, -1) + return (input_data * std + mean).reshape(b, c, z, h, w) + + @staticmethod + def pad( + input_data: torch.Tensor, + ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int, List[int], List[int]]]: + _, _, z, h, w = input_data.shape + w_mult = ((w - 1) | 15) + 1 + h_mult = ((h - 1) | 15) + 1 + z_mult = ((z - 1) | 15) + 1 + w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)] + h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)] + z_pad = [math.floor((z_mult - z) / 2), math.ceil((z_mult - z) / 2)] + + output = F.pad(input_data, w_pad + h_pad + z_pad) + return output, (h_pad, w_pad, z_pad, h_mult, w_mult, z_mult) + + @staticmethod + def unpad( + input_data: torch.Tensor, + h_pad: List[int], + w_pad: List[int], + z_pad: List[int], + h_mult: int, + w_mult: int, + z_mult: int, + ) -> torch.Tensor: + return input_data[ + ..., z_pad[0] : z_mult - z_pad[1], h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1] + ] + + def forward(self, input_data: torch.Tensor) -> torch.Tensor: + """Performs the forward pass of :class:`NormUnetModel3D`. 
+ + Parameters + ---------- + input_data: torch.Tensor + + Returns + ------- + torch.Tensor + """ + + output, mean, std = self.norm(input_data, self.norm_groups) + output, pad_sizes = self.pad(output) + output = self.unet3d(output) + + output = self.unpad(output, *pad_sizes) + output = self.unnorm(output, mean, std, self.norm_groups) + + return output + + +def pad_to_pow_of_2(inp, k): + diffs = [_ - 2**k for _ in inp.shape[2:]] + padding = [0, 0, 0, 0, 0, 0] + for i, diff in enumerate(diffs[::-1]): + if diff < 1: + padding[2 * i] = abs(diff) // 2 + padding[2 * i + 1] = abs(diff) - padding[2 * i] + + if sum(padding) > 0: + inp = F.pad(inp, padding) + + return inp, padding diff --git a/direct/nn/varsplitnet/config.py b/direct/nn/varsplitnet/config.py index d4fdeb3f..7934d4d5 100644 --- a/direct/nn/varsplitnet/config.py +++ b/direct/nn/varsplitnet/config.py @@ -17,31 +17,68 @@ class MRIVarSplitNetConfig(ModelConfig): kspace_no_parameter_sharing: bool = True image_model_architecture: str = ModelName.unet kspace_model_architecture: Optional[str] = None - image_resnet_hidden_channels: Optional[int] = 128 - image_resnet_num_blocks: Optional[int] = 15 - image_resnet_batchnorm: Optional[bool] = True - image_resnet_scale: Optional[float] = 0.1 - image_unet_num_filters: Optional[int] = 32 - image_unet_num_pool_layers: Optional[int] = 4 - image_unet_dropout: Optional[float] = 0.0 - image_didn_hidden_channels: Optional[int] = 16 - image_didn_num_dubs: Optional[int] = 6 - image_didn_num_convs_recon: Optional[int] = 9 - kspace_resnet_hidden_channels: Optional[int] = 64 - kspace_resnet_num_blocks: Optional[int] = 1 - kspace_resnet_batchnorm: Optional[bool] = True - kspace_resnet_scale: Optional[float] = 0.1 - kspace_unet_num_filters: Optional[int] = 16 - kspace_unet_num_pool_layers: Optional[int] = 4 - kspace_unet_dropout: Optional[float] = 0.0 - kspace_didn_hidden_channels: Optional[int] = 8 - kspace_didn_num_dubs: Optional[int] = 6 - kspace_didn_num_convs_recon: Optional[int] = 9 - image_conv_hidden_channels: Optional[int] = 64 - image_conv_n_convs: Optional[int] = 15 - image_conv_activation: Optional[str] = ActivationType.relu - image_conv_batchnorm: Optional[bool] = False - kspace_conv_hidden_channels: Optional[int] = 64 - kspace_conv_n_convs: Optional[int] = 15 - kspace_conv_activation: Optional[str] = ActivationType.prelu - kspace_conv_batchnorm: Optional[bool] = False + image_resnet_hidden_channels: int = 128 + image_resnet_num_blocks: int = 15 + image_resnet_batchnorm: bool = True + image_resnet_scale: float = 0.1 + image_unet_num_filters: int = 32 + image_unet_num_pool_layers: int = 4 + image_unet_dropout: float = 0.0 + image_unet_cwn_conv: bool = False + image_didn_hidden_channels: int = 16 + image_didn_num_dubs: int = 6 + image_didn_num_convs_recon: int = 9 + kspace_resnet_hidden_channels: int = 64 + kspace_resnet_num_blocks: int = 1 + kspace_resnet_batchnorm: bool = True + kspace_resnet_scale: float = 0.1 + kspace_unet_num_filters: int = 16 + kspace_unet_num_pool_layers: int = 4 + kspace_unet_dropout: float = 0.0 + kspace_didn_hidden_channels: int = 8 + kspace_didn_num_dubs: int = 6 + kspace_didn_num_convs_recon: int = 9 + image_conv_hidden_channels: int = 64 + image_conv_n_convs: int = 15 + image_conv_activation: str = ActivationType.relu + image_conv_batchnorm: bool = False + kspace_conv_hidden_channels: int = 64 + kspace_conv_n_convs: int = 15 + kspace_conv_activation: str = ActivationType.prelu + kspace_conv_batchnorm: bool = False + image_uformer_patch_size: int = 256 + 
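A minimal shape-level sketch of the new 3D U-Nets (toy sizes are assumptions). NormUnetModel3d group-normalizes its input and pads every spatial dimension up to the next multiple of 16, e.g. ((100 - 1) | 15) + 1 == 112, before running the wrapped UnetModel3d and undoing both steps on the way out.

# Hypothetical smoke test for the 3D U-Net variants added above.
import torch
from direct.nn.unet.unet_3d import NormUnetModel3d, UnetModel3d

unet3d = UnetModel3d(in_channels=2, out_channels=2, num_filters=8,
                     num_pool_layers=3, dropout_probability=0.0)
norm_unet3d = NormUnetModel3d(in_channels=2, out_channels=2, num_filters=8,
                              num_pool_layers=3, dropout_probability=0.0,
                              norm_groups=2)

x = torch.randn(1, 2, 16, 100, 100)  # (batch, channels, slices, height, width)
with torch.no_grad():
    print(unet3d(x).shape)       # expected: torch.Size([1, 2, 16, 100, 100])
    print(norm_unet3d(x).shape)  # padding and normalization are undone, so same shape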
image_uformer_embedding_dim: int = 32 + image_uformer_encoder_depths: tuple[int, ...] = (2, 2, 2, 2) + image_uformer_encoder_num_heads: tuple[int, ...] = (1, 2, 4, 8) + image_uformer_bottleneck_depth: int = 2 + image_uformer_bottleneck_num_heads: int = 16 + image_uformer_win_size: int = 8 + image_uformer_mlp_ratio: float = 4.0 + image_uformer_qkv_bias: bool = True + image_uformer_qk_scale: Optional[float] = None + image_uformer_drop_rate: float = 0.0 + image_uformer_attn_drop_rate: float = 0.0 + image_uformer_drop_path_rate: float = 0.1 + image_uformer_patch_norm: bool = True + image_uformer_shift_flag: bool = True + image_uformer_modulator: bool = False + image_uformer_cross_modulator: bool = False + image_uformer_normalized: bool = True + kspace_uformer_patch_size: int = 256 + kspace_uformer_embedding_dim: int = 32 + kspace_uformer_encoder_depths: tuple[int, ...] = (2, 2, 2) + kspace_uformer_encoder_num_heads: tuple[int, ...] = (1, 2, 4) + kspace_uformer_bottleneck_depth: int = 2 + kspace_uformer_bottleneck_num_heads: int = 8 + kspace_uformer_win_size: int = 8 + kspace_uformer_mlp_ratio: float = 4.0 + kspace_uformer_qkv_bias: bool = True + kspace_uformer_qk_scale: Optional[float] = None + kspace_uformer_drop_rate: float = 0.0 + kspace_uformer_attn_drop_rate: float = 0.0 + kspace_uformer_drop_path_rate: float = 0.1 + kspace_uformer_patch_norm: bool = True + kspace_uformer_shift_flag: bool = True + kspace_uformer_modulator: bool = False + kspace_uformer_cross_modulator: bool = False + kspace_uformer_normalized: bool = True diff --git a/direct/nn/varsplitnet/varsplitnet.py b/direct/nn/varsplitnet/varsplitnet.py index f75d80b1..1c5097a2 100644 --- a/direct/nn/varsplitnet/varsplitnet.py +++ b/direct/nn/varsplitnet/varsplitnet.py @@ -46,7 +46,7 @@ def __init__( image_init: str = InitType.sense, no_parameter_sharing: bool = True, image_model_architecture: ModelName = ModelName.unet, - kspace_no_parameter_sharing: Optional[bool] = True, + kspace_no_parameter_sharing: bool = True, kspace_model_architecture: Optional[ModelName] = None, **kwargs, ): @@ -63,14 +63,15 @@ def __init__( self.no_parameter_sharing = no_parameter_sharing - if image_model_architecture not in ["unet", "normunet", "resnet", "didn", "conv"]: + if image_model_architecture not in ["unet", "normunet", "resnet", "didn", "conv", "uformer"]: raise ValueError(f"Invalid value {image_model_architecture} for `image_model_architecture`.") - if kspace_model_architecture not in ["unet", "normunet", "resnet", "didn", "conv", None]: + if kspace_model_architecture not in ["unet", "normunet", "resnet", "didn", "conv", "uformer", None]: raise ValueError(f"Invalid value {kspace_model_architecture} for `kspace_model_architecture`.") image_model, image_model_kwargs = _get_model_config( image_model_architecture, in_channels=4, + out_channels=2, **{k.replace("image_", ""): v for (k, v) in kwargs.items() if "image_" in k}, ) for _ in range(self.num_steps_reg if self.no_parameter_sharing else 1): diff --git a/direct/nn/vsharp/__init__.py b/direct/nn/vsharp/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/vsharp/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/vsharp/config.py b/direct/nn/vsharp/config.py new file mode 100644 index 00000000..1b5dc06a --- /dev/null +++ b/direct/nn/vsharp/config.py @@ -0,0 +1,136 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from dataclasses import dataclass +from typing import Optional + +from 
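The long lists of image_*/kspace_* fields above follow a single convention: inside the networks, keyword arguments are filtered by prefix and the prefix is stripped before they reach _get_model_config for the corresponding sub-model. A toy illustration (the values below are made up):

# Hypothetical illustration of the image_/kspace_ prefix routing used above.
kwargs = {
    "image_unet_num_filters": 32,
    "image_unet_num_pool_layers": 4,
    "kspace_unet_num_filters": 16,
}
image_kwargs = {k.replace("image_", ""): v for (k, v) in kwargs.items() if "image_" in k}
print(image_kwargs)  # {'unet_num_filters': 32, 'unet_num_pool_layers': 4}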
direct.config.defaults import ModelConfig +from direct.nn.types import ActivationType, ModelName + + +@dataclass +class VSharpNetConfig(ModelConfig): + num_steps: int = 10 + num_steps_dc_gd: int = 8 + image_init: str = "sense" + no_parameter_sharing: bool = True + auxiliary_steps: int = 0 + image_model_architecture: ModelName = ModelName.unet + initializer_channels: tuple[int, ...] = (32, 32, 64, 64) + initializer_dilations: tuple[int, ...] = (1, 1, 2, 4) + initializer_multiscale: int = 1 + initializer_activation: ActivationType = ActivationType.prelu + image_resnet_hidden_channels: int = 128 + image_resnet_num_blocks: int = 15 + image_resnet_batchnorm: bool = True + image_resnet_scale: float = 0.1 + image_unet_num_filters: int = 32 + image_unet_num_pool_layers: int = 4 + image_unet_dropout: float = 0.0 + image_unet_cwn_conv: bool = False + image_didn_hidden_channels: int = 16 + image_didn_num_dubs: int = 6 + image_didn_num_convs_recon: int = 9 + image_conv_hidden_channels: int = 64 + image_conv_n_convs: int = 15 + image_conv_activation: str = ActivationType.relu + image_conv_batchnorm: bool = False + image_uformer_patch_size: int = 256 + image_uformer_embedding_dim: int = 32 + image_uformer_encoder_depths: tuple[int, ...] = (2, 2, 2, 2) + image_uformer_encoder_num_heads: tuple[int, ...] = (1, 2, 4, 8) + image_uformer_bottleneck_depth: int = 2 + image_uformer_bottleneck_num_heads: int = 16 + image_uformer_win_size: int = 8 + image_uformer_mlp_ratio: float = 4.0 + image_uformer_qkv_bias: bool = True + image_uformer_qk_scale: Optional[float] = None + image_uformer_drop_rate: float = 0.0 + image_uformer_attn_drop_rate: float = 0.0 + image_uformer_drop_path_rate: float = 0.1 + image_uformer_patch_norm: bool = True + image_uformer_shift_flag: bool = True + image_uformer_modulator: bool = False + image_uformer_cross_modulator: bool = False + image_uformer_normalized: bool = True + kspace_model_architecture: Optional[str] = None + kspace_resnet_hidden_channels: int = 64 + kspace_resnet_num_blocks: int = 1 + kspace_resnet_batchnorm: bool = True + kspace_resnet_scale: float = 0.1 + kspace_unet_num_filters: int = 16 + kspace_unet_num_pool_layers: int = 4 + kspace_unet_dropout: float = 0.0 + kspace_didn_hidden_channels: int = 8 + kspace_didn_num_dubs: int = 6 + kspace_didn_num_convs_recon: int = 9 + kspace_conv_hidden_channels: int = 64 + kspace_conv_n_convs: int = 15 + kspace_conv_activation: str = ActivationType.prelu + kspace_conv_batchnorm: bool = False + kspace_uformer_patch_size: int = 256 + kspace_uformer_embedding_dim: int = 32 + kspace_uformer_encoder_depths: tuple[int, ...] = (2, 2, 2) + kspace_uformer_encoder_num_heads: tuple[int, ...] 
= (1, 2, 4) + kspace_uformer_bottleneck_depth: int = 2 + kspace_uformer_bottleneck_num_heads: int = 8 + kspace_uformer_win_size: int = 8 + kspace_uformer_mlp_ratio: float = 4.0 + kspace_uformer_qkv_bias: bool = True + kspace_uformer_qk_scale: Optional[float] = None + kspace_uformer_drop_rate: float = 0.0 + kspace_uformer_attn_drop_rate: float = 0.0 + kspace_uformer_drop_path_rate: float = 0.1 + kspace_uformer_patch_norm: bool = True + kspace_uformer_shift_flag: bool = True + kspace_uformer_modulator: bool = False + kspace_uformer_cross_modulator: bool = False + kspace_uformer_normalized: bool = True + image_vision_transformer_average_img_size: int = 320 + image_vision_transformer_patch_size: int = 10 + image_vision_transformer_embedding_dim: int = 64 + image_vision_transformer_depth: int = 8 + image_vision_transformer_num_heads: int = 9 + image_vision_transformer_mlp_ratio: float = 4.0 + image_vision_transformer_qkv_bias: bool = False + image_vision_transformer_qk_scale: Optional[float] = None + image_vision_transformer_drop_rate: float = (0.0,) + image_vision_transformer_attn_drop_rate: float = (0.0,) + image_vision_transformer_dropout_path_rate: float = 0.0 + image_vision_transformer_gpsa_interval: tuple[int, int] = (-1, -1) + image_vision_transformer_locality_strength: float = 1.0 + image_vision_transformer_use_pos_embedding: bool = True + image_vision_transformer_normalized: bool = True + kspace_vision_transformer_average_img_size: int = 320 + kspace_vision_transformer_patch_size: int = 10 + kspace_vision_transformer_embedding_dim: int = 64 + kspace_vision_transformer_depth: int = 8 + kspace_vision_transformer_num_heads: int = 9 + kspace_vision_transformer_mlp_ratio: float = 4.0 + kspace_vision_transformer_qkv_bias: bool = False + kspace_vision_transformer_qk_scale: Optional[float] = None + kspace_vision_transformer_drop_rate: float = (0.0,) + kspace_vision_transformer_attn_drop_rate: float = (0.0,) + kspace_vision_transformer_dropout_path_rate: float = 0.0 + kspace_vision_transformer_gpsa_interval: tuple[int, int] = (-1, -1) + kspace_vision_transformer_locality_strength: float = 1.0 + kspace_vision_transformer_use_pos_embedding: bool = True + kspace_vision_transformer_normalized: bool = True + + +@dataclass +class VSharpNet3DConfig(ModelConfig): + num_steps: int = 8 + num_steps_dc_gd: int = 6 + image_init: str = "sense" + no_parameter_sharing: bool = True + auxiliary_steps: int = -1 + initializer_channels: tuple[int, ...] = (32, 32, 64, 64) + initializer_dilations: tuple[int, ...] 
= (1, 1, 2, 4) + initializer_multiscale: int = 1 + initializer_activation: ActivationType = ActivationType.prelu + unet_num_filters: int = 32 + unet_num_pool_layers: int = 4 + unet_dropout: float = 0.0 + unet_cwn_conv: bool = False + unet_norm: bool = False diff --git a/direct/nn/vsharp/vsharp.py b/direct/nn/vsharp/vsharp.py new file mode 100644 index 00000000..5e20d425 --- /dev/null +++ b/direct/nn/vsharp/vsharp.py @@ -0,0 +1,601 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + + +from __future__ import annotations + +from typing import Callable, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from direct.constants import COMPLEX_SIZE +from direct.data.transforms import apply_mask, expand_operator, reduce_operator +from direct.nn.get_nn_model_config import ModelName, _get_activation, _get_model_config +from direct.nn.types import ActivationType, InitType +from direct.nn.unet.unet_3d import NormUnetModel3d, UnetModel3d + + +class LagrangeMultipliersInitializer(nn.Module): + """A convolutional neural network model that initializers the Lagrange multiplier of the vSHARPNet.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + channels: tuple[int, ...], + dilations: tuple[int, ...], + multiscale_depth: int = 1, + activation: ActivationType = ActivationType.prelu, + ): + """Inits :class:`LagrangeMultipliersInitializer`. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + channels : tuple of ints + Tuple of integers specifying the number of output channels for each convolutional layer in the network. + dilations : tuple of ints + Tuple of integers specifying the dilation factor for each convolutional layer in the network. + multiscale_depth : int + Number of multiscale features to include in the output. Default: 1. + """ + super().__init__() + + # Define convolutional blocks + self.conv_blocks = nn.ModuleList() + tch = in_channels + for curr_channels, curr_dilations in zip(channels, dilations): + block = nn.Sequential( + nn.ReplicationPad2d(curr_dilations), + nn.Conv2d(tch, curr_channels, 3, padding=0, dilation=curr_dilations), + ) + tch = curr_channels + self.conv_blocks.append(block) + + # Define output block + tch = np.sum(channels[-multiscale_depth:]) + block = nn.Conv2d(tch, out_channels, 1, padding=0) + self.out_block = nn.Sequential(block) + + self.multiscale_depth = multiscale_depth + + self.activation = _get_activation(activation) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`LagrangeMultipliersInitializer`. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, in_channels, height, width). + + Returns + ------- + torch.Tensor + Output tensor of shape (batch_size, out_channels, height, width). + """ + + features = [] + for block in self.conv_blocks: + x = F.relu(block(x), inplace=True) + if self.multiscale_depth > 1: + features.append(x) + + if self.multiscale_depth > 1: + x = torch.cat(features[-self.multiscale_depth :], dim=1) + + return self.activation(self.out_block(x)) + + +class VSharpNet(nn.Module): + """Variable Splitting Half-quadratic ADMM algorithm for Reconstruction of Parallel MRI. + + Variable Splitting Half Quadratic VSharpNet is a deep learning model that solves the augmented Lagrangian derivation + of the variable half quadratic splitting problem using ADMM (Alternating Direction Method of Multipliers). 
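A quick shape check for the Lagrange multiplier initializer defined above (toy sizes are assumptions): each block is a replication-padded, dilated 3x3 convolution, so spatial dimensions are preserved, and a final 1x1 convolution maps back to out_channels.

# Hypothetical smoke test for LagrangeMultipliersInitializer.
import torch
from direct.nn.vsharp.vsharp import LagrangeMultipliersInitializer

init_net = LagrangeMultipliersInitializer(
    in_channels=2, out_channels=2, channels=(32, 32, 64, 64), dilations=(1, 1, 2, 4)
)
u0 = init_net(torch.randn(1, 2, 40, 40))
print(u0.shape)  # expected: torch.Size([1, 2, 40, 40])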
+ It is designed for solving inverse problems in magnetic resonance imaging (MRI). + + The VSharpNet model incorporates an iterative optimization algorithm that consists of three steps: z-step, x-step, + and u-step: + + .. math :: + \vec{z}^{t+1} = \argmin_{\vec{z}}\, \lambda \, \mathcal{G}(\vec{z}) + + \frac{\rho}{2} \big | \big | \vec{x}^{t} - \vec{z} + \frac{\vec{u}^t}{\rho} \big | \big |_2^2 + \quad \Big[\vec{z}\text{-step}\Big] + \vec{x}^{t+1} = \argmin_{\vec{x}}\, \frac{1}{2} \big | \big | \mathcal{A}_{\mat{U},\mat{S}}(\vec{x}) - + \tilde{\vec{y}} \big | \big |_2^2 + \frac{\rho}{2} \big | \big | \vec{x} - \vec{z}^{t+1} + + \frac{\vec{u}^t}{\rho} \big | \big |_2^2 \quad \Big[\vec{x}\text{-step}\Big] + \vec{u}^{t+1} = \vec{u}^t + \rho (\vec{x}^{t+1} - \vec{z}^{t+1}) \quad \Big[\vec{u}\text{-step}\Big] + + + In the z-step, the model minimizes the augmented Lagrangian function with respect to z using DL based + denoisers. + + In the x-step, it optimizes x by minimizing the data consistency term by unrolling a + gradient descent scheme (DC-GD). + + In the u-step, the model updates the Lagrange multiplier u. These steps are performed iteratively for + a specified number of steps. + + The VSharpNet model supports both image and k-space domain parameterizations. It includes an initializer for + Lagrange multipliers. + + It can also incorporate auxiliary steps during training for improved performance. + """ + + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + num_steps: int, + num_steps_dc_gd: int, + image_init: str = InitType.sense, + no_parameter_sharing: bool = True, + image_model_architecture: ModelName = ModelName.unet, + initializer_channels: tuple[int, ...] = (32, 32, 64, 64), + initializer_dilations: tuple[int, ...] = (1, 1, 2, 4), + initializer_multiscale: int = 1, + initializer_activation: ActivationType = ActivationType.prelu, + kspace_no_parameter_sharing: bool = True, + kspace_model_architecture: Optional[ModelName] = None, + auxiliary_steps: int = 0, + **kwargs, + ): + """Inits :class:`VSharpNet`. + + Parameters + ---------- + forward_operator : Callable + Forward operator function. + backward_operator : Callable + Backward operator function. + num_steps : int + Number of steps in the ADMM algorithm. + num_steps_dc_gd : int + Number of steps in the Data Consistency using Gradient Descent step of ADMM. + image_init : str + Image initialization method. Default: 'sense'. + no_parameter_sharing : bool + Flag indicating whether parameter sharing is enabled in the denoiser blocks. + image_model_architecture : ModelName + Image model architecture. Default: ModelName.unet. + initializer_channels : tuple[int, ...] + Tuple of integers specifying the number of output channels for each convolutional layer in the + Lagrange multiplier initializer. Default: (32, 32, 64, 64). + initializer_dilations : tuple[int, ...] + Tuple of integers specifying the dilation factor for each convolutional layer in the Lagrange multiplier + initializer. Default: (1, 1, 2, 4). + initializer_multiscale : int + Number of multiscale features to include in the Lagrange multiplier initializer output. Default: 1. + initializer_activation : ActivationType + Activation type for the Lagrange multiplier initializer. Default: ActivationType.relu. + kspace_no_parameter_sharing : bool + Flag indicating whether parameter sharing is enabled in the k-space denoiser. Ignored if input for + `kspace_model_architecture` is None. Default: True. 
+ kspace_model_architecture : ModelName, optional + K-space model architecture. Default: None. + auxiliary_steps : int + Number of auxiliary steps to output. Can be -1 or a positive integer lower or equal to `num_steps`. + If -1, it uses all steps. + **kwargs: Additional keyword arguments. + """ + super().__init__() + self.num_steps = num_steps + self.num_steps_dc_gd = num_steps_dc_gd + + self.no_parameter_sharing = no_parameter_sharing + + if image_model_architecture not in [ + "unet", + "normunet", + "resnet", + "didn", + "conv", + "uformer", + "vision_transformer", + ]: + raise ValueError(f"Invalid value {image_model_architecture} for `image_model_architecture`.") + if kspace_model_architecture not in [ + "unet", + "normunet", + "resnet", + "didn", + "conv", + "uformer", + "vision_transformer", + None, + ]: + raise ValueError(f"Invalid value {kspace_model_architecture} for `kspace_model_architecture`.") + + image_model, image_model_kwargs = _get_model_config( + image_model_architecture, + in_channels=COMPLEX_SIZE * 4 if kspace_model_architecture else COMPLEX_SIZE * 3, + out_channels=COMPLEX_SIZE, + **{k.replace("image_", ""): v for (k, v) in kwargs.items() if "image_" in k}, + ) + + if kspace_model_architecture: + self.kspace_no_parameter_sharing = kspace_no_parameter_sharing + kspace_model, kspace_model_kwargs = _get_model_config( + kspace_model_architecture, + in_channels=COMPLEX_SIZE, + out_channels=COMPLEX_SIZE, + **{k.replace("kspace_", ""): v for (k, v) in kwargs.items() if "kspace_" in k}, + ) + self.kspace_denoiser = kspace_model(**kspace_model_kwargs) + self.scale_k = nn.Parameter(torch.ones(1, requires_grad=True)) + nn.init.trunc_normal_(self.scale_k, 0, 0.1, 0.0) + else: + self.kspace_denoiser = None + + self.denoiser_blocks = nn.ModuleList() + for _ in range(num_steps if self.no_parameter_sharing else 1): + self.denoiser_blocks.append(image_model(**image_model_kwargs)) + + self.initializer = LagrangeMultipliersInitializer( + COMPLEX_SIZE, + COMPLEX_SIZE, + channels=initializer_channels, + dilations=initializer_dilations, + multiscale_depth=initializer_multiscale, + activation=initializer_activation, + ) + + self.learning_rate_eta = nn.Parameter(torch.ones(num_steps_dc_gd, requires_grad=True)) + nn.init.trunc_normal_(self.learning_rate_eta, 0.0, 1.0, 0.0) + + self.rho = nn.Parameter(torch.ones(num_steps, requires_grad=True)) + nn.init.trunc_normal_(self.rho, 0, 0.1, 0.0) + + self.forward_operator = forward_operator + self.backward_operator = backward_operator + + if image_init not in ["sense", "zero_filled"]: + raise ValueError(f"Unknown image_initialization. Expected 'sense' or 'zero_filled'. " f"Got {image_init}.") + + self.image_init = image_init + + if not (auxiliary_steps == -1 or 0 < auxiliary_steps <= num_steps): + raise ValueError( + f"Number of auxiliary steps should be -1 to use all steps or a positive" + f" integer <= than `num_steps`. Received {auxiliary_steps}." + ) + if auxiliary_steps == -1: + self.auxiliary_steps = list(range(num_steps)) + else: + self.auxiliary_steps = list(range(num_steps - min(auxiliary_steps, num_steps), num_steps)) + + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + def forward( + self, + masked_kspace: torch.Tensor, + sensitivity_map: torch.Tensor, + sampling_mask: torch.Tensor, + ) -> list[torch.Tensor]: + """Computes forward pass of :class:`MRIVarSplitNet`. + + Parameters + ---------- + masked_kspace: torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). 
+ sensitivity_map: torch.Tensor + Sensitivity map of shape (N, coil, height, width, complex=2). Default: None. + sampling_mask: torch.Tensor + + Returns + ------- + image: torch.Tensor + Output image of shape (N, height, width, complex=2). + """ + out = [] + if self.image_init == "sense": + x = reduce_operator( + coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), + sensitivity_map=sensitivity_map, + dim=self._coil_dim, + ) + else: + x = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim) + + z = x.clone() + + u = self.initializer(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + + for admm_step in range(self.num_steps): + if self.kspace_denoiser: + kspace_z = self.kspace_denoiser( + self.forward_operator(z.contiguous(), dim=[_ - 1 for _ in self._spatial_dims]).permute(0, 3, 1, 2) + ).permute(0, 2, 3, 1) + kspace_z = self.backward_operator(kspace_z.contiguous(), dim=[_ - 1 for _ in self._spatial_dims]) + + z = self.denoiser_blocks[admm_step if self.no_parameter_sharing else 0]( + torch.cat( + [z, x, u / self.rho[admm_step]] + ([self.scale_k * kspace_z] if self.kspace_denoiser else []), + dim=self._complex_dim, + ).permute(0, 3, 1, 2) + ).permute(0, 2, 3, 1) + + for dc_gd_step in range(self.num_steps_dc_gd): + dc = apply_mask( + self.forward_operator(expand_operator(x, sensitivity_map, self._coil_dim), dim=self._spatial_dims) + - masked_kspace, + sampling_mask, + return_mask=False, + ) + dc = self.backward_operator(dc, dim=self._spatial_dims) + dc = reduce_operator(dc, sensitivity_map, self._coil_dim) + + x = x - self.learning_rate_eta[dc_gd_step] * (dc + self.rho[admm_step] * (x - z) + u) + + if admm_step in self.auxiliary_steps: + out.append(x) + + u = u + self.rho[admm_step] * (x - z) + + return out + + +class LagrangeMultipliersInitializer3D(torch.nn.Module): + """A convolutional neural network model that initializes the Lagrange multiplier of the vSHARPNet for 3D data.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + channels: tuple[int, ...], + dilations: tuple[int, ...], + multiscale_depth: int = 1, + activation: nn.Module = nn.PReLU(), + ): + """Initializes LagrangeMultipliersInitializer3D. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + channels : tuple of ints + Tuple of integers specifying the number of output channels for each convolutional layer in the network. + dilations : tuple of ints + Tuple of integers specifying the dilation factor for each convolutional layer in the network. + multiscale_depth : int + Number of multiscale features to include in the output. Default: 1. + activation : nn.Module + Activation function. Default: PReLU. + """ + super().__init__() + + # Define convolutional blocks + self.conv_blocks = nn.ModuleList() + tch = in_channels + for curr_channels, curr_dilations in zip(channels, dilations): + block = nn.Sequential( + nn.ReplicationPad3d(curr_dilations), + nn.Conv3d(tch, curr_channels, 3, padding=0, dilation=curr_dilations), + ) + tch = curr_channels + self.conv_blocks.append(block) + + # Define output block + tch = np.sum(channels[-multiscale_depth:]) + block = nn.Conv3d(tch, out_channels, 1, padding=0) + self.out_block = nn.Sequential(block) + + self.multiscale_depth = multiscale_depth + self.activation = _get_activation(activation) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of LagrangeMultipliersInitializer3D. 
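A minimal end-to-end sketch of the 2D VSharpNet (all sizes are toy assumptions; the centered FFT pair fft2/ifft2 exported by direct.data.transforms is assumed as the forward/backward operator, mirroring how the engines wire the model up). Image_-prefixed keyword arguments are forwarded to the image denoiser, and with auxiliary_steps=2 the forward pass returns the last two ADMM iterates.

# Hypothetical smoke test for the 2D VSharpNet.
import torch
from direct.data.transforms import fft2, ifft2
from direct.nn.vsharp.vsharp import VSharpNet

model = VSharpNet(
    forward_operator=fft2,
    backward_operator=ifft2,
    num_steps=4,
    num_steps_dc_gd=2,
    image_model_architecture="unet",
    auxiliary_steps=2,         # keep only the last two iterates
    image_unet_num_filters=8,  # routed to the image denoiser after prefix stripping
)

coils, height, width = 4, 32, 32
masked_kspace = torch.randn(1, coils, height, width, 2)
sensitivity_map = torch.randn(1, coils, height, width, 2)
sampling_mask = torch.rand(1, 1, height, width, 1) > 0.5

with torch.no_grad():
    outputs = model(masked_kspace, sensitivity_map, sampling_mask)
print(len(outputs), outputs[-1].shape)  # 2 torch.Size([1, 32, 32, 2])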
+ + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, in_channels, z, x, y). + + Returns + ------- + torch.Tensor + Output tensor of shape (batch_size, out_channels, z, x, y). + """ + + features = [] + for block in self.conv_blocks: + x = F.relu(block(x), inplace=True) + if self.multiscale_depth > 1: + features.append(x) + + if self.multiscale_depth > 1: + x = torch.cat(features[-self.multiscale_depth :], dim=1) + + return self.activation(self.out_block(x)) + + +class VSharpNet3D(nn.Module): + """VharpNet 3D version.""" + + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + num_steps: int, + num_steps_dc_gd: int, + image_init: str = InitType.sense, + no_parameter_sharing: bool = True, + initializer_channels: tuple[int, ...] = (32, 32, 64, 64), + initializer_dilations: tuple[int, ...] = (1, 1, 2, 4), + initializer_multiscale: int = 1, + initializer_activation: ActivationType = ActivationType.prelu, + auxiliary_steps: int = -1, + unet_num_filters: int = 32, + unet_num_pool_layers: int = 4, + unet_dropout: float = 0.0, + unet_cwn_conv: bool = False, + unet_norm: bool = False, + **kwargs, + ): + """Inits :class:`VSharpNet`. + + Parameters + ---------- + forward_operator : Callable + Forward operator function. + backward_operator : Callable + Backward operator function. + num_steps : int + Number of steps in the ADMM algorithm. + num_steps_dc_gd : int + Number of steps in the Data Consistency using Gradient Descent step of ADMM. + image_init : str + Image initialization method. Default: 'sense'. + no_parameter_sharing : bool + Flag indicating whether parameter sharing is enabled in the denoiser blocks. + image_model_architecture : ModelName + Image model architecture. Default: ModelName.unet. + initializer_channels : tuple[int, ...] + Tuple of integers specifying the number of output channels for each convolutional layer in the + Lagrange multiplier initializer. Default: (32, 32, 64, 64). + initializer_dilations : tuple[int, ...] + Tuple of integers specifying the dilation factor for each convolutional layer in the Lagrange multiplier + initializer. Default: (1, 1, 2, 4). + initializer_multiscale : int + Number of multiscale features to include in the Lagrange multiplier initializer output. Default: 1. + initializer_activation : ActivationType + Activation type for the Lagrange multiplier initializer. Default: ActivationType.relu. + kspace_no_parameter_sharing : bool + Flag indicating whether parameter sharing is enabled in the k-space denoiser. Ignored if input for + `kspace_model_architecture` is None. Default: True. + kspace_model_architecture : ModelName, optional + K-space model architecture. Default: None. + auxiliary_steps : int + Number of auxiliary steps to output. Can be -1 or a positive integer lower or equal to `num_steps`. + If -1, it uses all steps. + **kwargs: Additional keyword arguments. 
+ """ + super().__init__() + self.num_steps = num_steps + self.num_steps_dc_gd = num_steps_dc_gd + + self.no_parameter_sharing = no_parameter_sharing + + unet = UnetModel3d if not unet_norm else NormUnetModel3d + + self.denoiser_blocks = nn.ModuleList() + for _ in range(num_steps if self.no_parameter_sharing else 1): + self.denoiser_blocks.append( + unet( + in_channels=COMPLEX_SIZE * 3, + out_channels=COMPLEX_SIZE, + num_filters=unet_num_filters, + num_pool_layers=unet_num_pool_layers, + dropout_probability=unet_dropout, + cwn_conv=unet_cwn_conv, + ) + ) + + self.initializer = LagrangeMultipliersInitializer3D( + COMPLEX_SIZE, + COMPLEX_SIZE, + channels=initializer_channels, + dilations=initializer_dilations, + multiscale_depth=initializer_multiscale, + activation=initializer_activation, + ) + + self.learning_rate_eta = nn.Parameter(torch.ones(num_steps_dc_gd, requires_grad=True)) + nn.init.trunc_normal_(self.learning_rate_eta, 0.0, 1.0, 0.0) + + self.rho = nn.Parameter(torch.ones(num_steps, requires_grad=True)) + nn.init.trunc_normal_(self.rho, 0, 0.1, 0.0) + + self.forward_operator = forward_operator + self.backward_operator = backward_operator + + if image_init not in ["sense", "zero_filled"]: + raise ValueError(f"Unknown image_initialization. Expected 'sense' or 'zero_filled'. " f"Got {image_init}.") + + self.image_init = image_init + + if not (auxiliary_steps == -1 or 0 < auxiliary_steps <= num_steps): + raise ValueError( + f"Number of auxiliary steps should be -1 to use all steps or a positive" + f" integer <= than `num_steps`. Received {auxiliary_steps}." + ) + if auxiliary_steps == -1: + self.auxiliary_steps = list(range(num_steps)) + else: + self.auxiliary_steps = list(range(num_steps - min(auxiliary_steps, num_steps), num_steps)) + + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (3, 4) + + def forward( + self, + masked_kspace: torch.Tensor, + sensitivity_map: torch.Tensor, + sampling_mask: torch.Tensor, + ) -> list[torch.Tensor]: + """Computes forward pass of :class:`MRIVarSplitNet`. + + Parameters + ---------- + masked_kspace: torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sensitivity_map: torch.Tensor + Sensitivity map of shape (N, coil, height, width, complex=2). Default: None. + sampling_mask: torch.Tensor + + Returns + ------- + image: torch.Tensor + Output image of shape (N, slice, height, width, complex=2). 
+ """ + out = [] + if self.image_init == "sense": + x = reduce_operator( + coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), + sensitivity_map=sensitivity_map, + dim=self._coil_dim, + ) + else: + x = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim) + + z = x.clone() + + u = self.initializer(x.permute(0, 4, 1, 2, 3)).permute(0, 2, 3, 4, 1) + + for admm_step in range(self.num_steps): + z = self.denoiser_blocks[admm_step if self.no_parameter_sharing else 0]( + torch.cat( + [z, x, u / self.rho[admm_step]], + dim=self._complex_dim, + ).permute(0, 4, 1, 2, 3) + ).permute(0, 2, 3, 4, 1) + + for dc_gd_step in range(self.num_steps_dc_gd): + dc = apply_mask( + self.forward_operator(expand_operator(x, sensitivity_map, self._coil_dim), dim=self._spatial_dims) + - masked_kspace, + sampling_mask, + return_mask=False, + ) + dc = self.backward_operator(dc, dim=self._spatial_dims) + dc = reduce_operator(dc, sensitivity_map, self._coil_dim) + + x = x - self.learning_rate_eta[dc_gd_step] * (dc + self.rho[admm_step] * (x - z) + u) + + if admm_step in self.auxiliary_steps: + out.append(x) + + u = u + self.rho[admm_step] * (x - z) + + return out diff --git a/direct/nn/vsharp/vsharp_engine.py b/direct/nn/vsharp/vsharp_engine.py new file mode 100644 index 00000000..e314618f --- /dev/null +++ b/direct/nn/vsharp/vsharp_engine.py @@ -0,0 +1,273 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from typing import Any, Callable, Dict, Optional, Tuple + +import torch +from torch import nn +from torch.cuda.amp import autocast + +from direct.config import BaseConfig +from direct.data import transforms as T +from direct.engine import DoIterationOutput +from direct.nn.mri_models import MRIModelEngine +from direct.types import TensorOrNone +from direct.utils import detach_dict, dict_to_device + + +class VSharpNet3DEngine(MRIModelEngine): + """VSharpNet Engine.""" + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + """Inits :class:`VSharpNetEngine`. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable, optional + The forward operator. Default: None. + backward_operator: Callable, optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + self._spatial_dims = (3, 4) + + def _do_iteration( + self, + data: Dict[str, Any], + loss_fns: Optional[Dict[str, Callable]] = None, + regularizer_fns: Optional[Dict[str, Callable]] = None, + ) -> DoIterationOutput: + """Performs forward method and calculates loss functions. + + Parameters + ---------- + data : Dict[str, Any] + Data containing keys with values tensors such as k-space, image, sensitivity map, etc. + loss_fns : Optional[Dict[str, Callable]] + Callable loss functions. + regularizer_fns : Optional[Dict[str, Callable]] + Callable regularization functions. + + Returns + ------- + DoIterationOutput + Contains outputs. + """ + + # loss_fns can be None, e.g. 
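The 3D variant follows the same recipe on (batch, coil, slice, height, width, 2) tensors, with data consistency applied over the in-plane dimensions (3, 4) and a fixed 3D U-Net denoiser. A minimal sketch with toy sizes, under the same assumptions as the 2D example above:

# Hypothetical smoke test for VSharpNet3D.
import torch
from direct.data.transforms import fft2, ifft2
from direct.nn.vsharp.vsharp import VSharpNet3D

model_3d = VSharpNet3D(
    forward_operator=fft2,
    backward_operator=ifft2,
    num_steps=2,
    num_steps_dc_gd=2,
    unet_num_filters=4,
    unet_num_pool_layers=2,
)

masked_kspace = torch.randn(1, 4, 8, 32, 32, 2)    # (batch, coil, slice, height, width, 2)
sensitivity_map = torch.randn(1, 4, 8, 32, 32, 2)
sampling_mask = torch.rand(1, 1, 8, 32, 32, 1) > 0.5

with torch.no_grad():
    outputs = model_3d(masked_kspace, sensitivity_map, sampling_mask)
print(len(outputs), outputs[-1].shape)  # auxiliary_steps=-1 returns all 2 iterates, each (1, 8, 32, 32, 2)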
during validation + if loss_fns is None: + loss_fns = {} + + data = dict_to_device(data, self.device) + + output_image: TensorOrNone + output_kspace: TensorOrNone + + with autocast(enabled=self.mixed_precision): + output_images, output_kspace = self.forward_function(data) + output_images = [T.modulus_if_complex(_, complex_axis=self._complex_dim) for _ in output_images] + loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} + + auxiliary_loss_weights = torch.logspace(-1, 0, steps=len(output_images)).to(output_images[0]) + for i in range(len(output_images)): + loss_dict = self.compute_loss_on_data( + loss_dict, loss_fns, data, output_images[i], None, auxiliary_loss_weights[i] + ) + + loss_dict = self.compute_loss_on_data( + loss_dict, loss_fns, data, None, output_kspace, auxiliary_loss_weights[i] + ) + + loss = sum(loss_dict.values()) # type: ignore + + if self.model.training: + self._scaler.scale(loss).backward() + + loss_dict = detach_dict(loss_dict) # Detach dict, only used for logging. + + output_image = output_images[-1] + return DoIterationOutput( + output_image=output_image, + sensitivity_map=data["sensitivity_map"], + data_dict={**loss_dict}, + ) + + def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, None]: + data["sensitivity_map"] = self.compute_sensitivity_map(data["sensitivity_map"]) + + output_images = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], + sensitivity_map=data["sensitivity_map"], + ) # shape (batch, height, width, complex[=2]) + + output_image = output_images[-1] + output_kspace = data["masked_kspace"] + T.apply_mask( + T.apply_padding( + self.forward_operator( + T.expand_operator(output_image, data["sensitivity_map"], dim=self._coil_dim), + dim=self._spatial_dims, + ), + padding=data.get("padding", None), + ), + ~data["sampling_mask"], + return_mask=False, + ) + + return output_images, output_kspace + + +class VSharpNetEngine(MRIModelEngine): + """VSharpNet Engine.""" + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + """Inits :class:`VSharpNetEngine`. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable, optional + The forward operator. Default: None. + backward_operator: Callable, optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + def _do_iteration( + self, + data: Dict[str, Any], + loss_fns: Optional[Dict[str, Callable]] = None, + regularizer_fns: Optional[Dict[str, Callable]] = None, + ) -> DoIterationOutput: + """Performs forward method and calculates loss functions. + + Parameters + ---------- + data : Dict[str, Any] + Data containing keys with values tensors such as k-space, image, sensitivity map, etc. + loss_fns : Optional[Dict[str, Callable]] + Callable loss functions. + regularizer_fns : Optional[Dict[str, Callable]] + Callable regularization functions. + + Returns + ------- + DoIterationOutput + Contains outputs. + """ + + # loss_fns can be None, e.g. 
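The auxiliary outputs above are weighted with torch.logspace(-1, 0, steps=n), i.e. geometrically increasing weights from 0.1 for the earliest returned iterate up to 1.0 for the final reconstruction, for example:

# The auxiliary loss weights for four returned iterates.
import torch
print(torch.logspace(-1, 0, steps=4))  # tensor([0.1000, 0.2154, 0.4642, 1.0000])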
during validation + if loss_fns is None: + loss_fns = {} + + data = dict_to_device(data, self.device) + + output_image: TensorOrNone + output_kspace: TensorOrNone + + with autocast(enabled=self.mixed_precision): + output_images, output_kspace = self.forward_function(data) + output_images = [T.modulus_if_complex(_, complex_axis=self._complex_dim) for _ in output_images] + loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} + + auxiliary_loss_weights = torch.logspace(-1, 0, steps=len(output_images)).to(output_images[0]) + for i in range(len(output_images)): + loss_dict = self.compute_loss_on_data( + loss_dict, loss_fns, data, output_images[i], None, auxiliary_loss_weights[i] + ) + + loss_dict = self.compute_loss_on_data( + loss_dict, loss_fns, data, None, output_kspace, auxiliary_loss_weights[i] + ) + + loss = sum(loss_dict.values()) # type: ignore + + if self.model.training: + self._scaler.scale(loss).backward() + + loss_dict = detach_dict(loss_dict) # Detach dict, only used for logging. + + output_image = output_images[-1] + return DoIterationOutput( + output_image=output_image, + sensitivity_map=data["sensitivity_map"], + data_dict={**loss_dict}, + ) + + def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, None]: + data["sensitivity_map"] = self.compute_sensitivity_map(data["sensitivity_map"]) + + output_images = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], + sensitivity_map=data["sensitivity_map"], + ) # shape (batch, height, width, complex[=2]) + + output_image = output_images[-1] + output_kspace = data["masked_kspace"] + T.apply_mask( + T.apply_padding( + self.forward_operator( + T.expand_operator(output_image, data["sensitivity_map"], dim=self._coil_dim), + dim=self._spatial_dims, + ), + padding=data.get("padding", None), + ), + ~data["sampling_mask"], + return_mask=False, + ) + + return output_images, output_kspace diff --git a/direct/predict.py b/direct/predict.py index b5759f5a..1bb1cfa6 100644 --- a/direct/predict.py +++ b/direct/predict.py @@ -18,7 +18,8 @@ def _get_transforms(env): dataset_cfg = env.cfg.inference.dataset - mask_func = build_masking_function(**dataset_cfg.transforms.masking) + masking = dataset_cfg.transforms.masking # Can be None + mask_func = None if masking is None else build_masking_function(**masking) transforms = build_inference_transforms(env, mask_func, dataset_cfg) return dataset_cfg, transforms diff --git a/direct/train.py b/direct/train.py index dad9bf1f..ed0216b5 100644 --- a/direct/train.py +++ b/direct/train.py @@ -75,11 +75,13 @@ def get_root_of_file(filename: PathOrString): def build_transforms_from_environment(env, dataset_config: DictConfig) -> Callable: + masking = dataset_config.transforms.masking # Masking func can be None + mask_func = None if masking is None else build_masking_function(**masking) mri_transforms_func = functools.partial( build_mri_transforms, forward_operator=env.engine.forward_operator, backward_operator=env.engine.backward_operator, - mask_func=build_masking_function(**dataset_config.transforms.masking), + mask_func=mask_func, ) return mri_transforms_func(**dict_flatten(dict(remove_keys(dataset_config.transforms, "masking")))) # type: ignore diff --git a/direct/types.py b/direct/types.py index d51dcf2b..b5d73e36 100644 --- a/direct/types.py +++ b/direct/types.py @@ -48,6 +48,16 @@ class KspaceKey(DirectEnum): masked_kspace = "masked_kspace" +class TransformKey(DirectEnum): + sensitivity_map = 
"sensitivity_map" + target = "target" + kspace = "kspace" + masked_kspace = "masked_kspace" + sampling_mask = "sampling_mask" + acs_mask = "acs_mask" + scaling_factor = "scaling_factor" + + class IntegerListOrTupleStringMeta(type): """Metaclass for the :class:`IntegerListOrTupleString` class. diff --git a/direct/utils/__init__.py b/direct/utils/__init__.py index 2917e96b..8d1164dc 100644 --- a/direct/utils/__init__.py +++ b/direct/utils/__init__.py @@ -232,6 +232,24 @@ def merge_list_of_dicts(list_of_dicts: List[Dict]) -> Dict: return functools.reduce(lambda a, b: {**dict(a), **dict(b)}, list_of_dicts) +def merge_list_of_lists(list_of_lists: List[List]) -> List: + """A list of lists is merged into one list. + + Parameters + ---------- + list_of_lists: List[List] + + Returns + ------- + List + """ + + if not list_of_lists: + return [] + + return functools.reduce(lambda a, b: a + b, list_of_lists) + + def evaluate_dict( fns_dict: Dict[str, Callable], source: torch.Tensor, target: torch.Tensor, reduction: str = "mean" ) -> Dict: @@ -351,7 +369,7 @@ def __init__(self): """Inits DirectTransform.""" super().__init__() self.coil_dim = 1 - self.spatial_dims = (2, 3) + self.spatial_dims = {"2D": (1, 2), "3D": (2, 3)} self.complex_dim = -1 def __repr__(self): @@ -385,6 +403,9 @@ class DirectModule(DirectTransform, abc.ABC, torch.nn.Module): @abc.abstractmethod def __init__(self): super().__init__() + self.coil_dim = 1 + self.spatial_dims = {"2D": (2, 3), "3D": (3, 4)} + self.complex_dim = -1 def forward(self, sample: Dict): pass # This comment passes "Function/method with an empty body PTC-W0049" error. diff --git a/tests/tests_common/test_subsample.py b/tests/tests_common/test_subsample.py index 7d979928..a35da453 100644 --- a/tests/tests_common/test_subsample.py +++ b/tests/tests_common/test_subsample.py @@ -44,6 +44,31 @@ def test_mask_reuse(mask_func, center_fracs, accelerations, batch_size, dim): assert torch.all(mask2 == mask3) +@pytest.mark.parametrize( + "mask_func", + [ + CartesianEquispacedMaskFunc, + CartesianMagicMaskFunc, + CartesianRandomMaskFunc, + ], +) +@pytest.mark.parametrize( + "center_fracs, accelerations, batch_size, dim", + [ + ([10], [4], 4, 320), + ([30, 20], [4, 8], 2, 368), + ], +) +def test_mask_reuse_cartesian(mask_func, center_fracs, accelerations, batch_size, dim): + mask_func = mask_func(center_fractions=center_fracs, accelerations=accelerations) + shape = (batch_size, dim, dim, 2) + mask1 = mask_func(shape, seed=123) + mask2 = mask_func(shape, seed=123) + mask3 = mask_func(shape, seed=123) + assert torch.all(mask1 == mask2) + assert torch.all(mask2 == mask3) + + @pytest.mark.parametrize( "mask_func", [FastMRIRandomMaskFunc, FastMRIEquispacedMaskFunc, FastMRIMagicMaskFunc, Gaussian1DMaskFunc], @@ -108,6 +133,25 @@ def test_apply_mask_cartesian(mask_func, shape, center_fractions, accelerations) ([2, 64, 64, 2], [0.04, 0.08], [8, 4]), ], ) +def test_same_across_volumes_mask_cartesian_fraction_center(mask_func, shape, center_fractions, accelerations): + mask_func = mask_func(center_fractions=center_fractions, accelerations=accelerations) + num_slices = shape[0] + masks = [mask_func(shape[1:], seed=123) for _ in range(num_slices)] + + assert all(np.allclose(masks[_], masks[_ + 1]) for _ in range(num_slices - 1)) + + +@pytest.mark.parametrize( + "mask_func", + [CartesianEquispacedMaskFunc, CartesianMagicMaskFunc, CartesianRandomMaskFunc], +) +@pytest.mark.parametrize( + "shape, center_fractions, accelerations", + [ + ([4, 32, 32, 2], [6], [4]), + ([2, 64, 64, 2], [4, 6], 
[8, 4]), + ], +) def test_same_across_volumes_mask_cartesian(mask_func, shape, center_fractions, accelerations): mask_func = mask_func(center_fractions=center_fractions, accelerations=accelerations) num_slices = shape[0] diff --git a/tests/tests_data/test_mri_transforms.py b/tests/tests_data/test_mri_transforms.py index 76b55523..cae13221 100644 --- a/tests/tests_data/test_mri_transforms.py +++ b/tests/tests_data/test_mri_transforms.py @@ -9,6 +9,7 @@ import pytest import torch +from direct.common.subsample import FastMRIRandomMaskFunc from direct.data.mri_transforms import ( ApplyMask, ApplyZeroPadding, @@ -25,6 +26,7 @@ PadCoilDimension, RandomFlip, RandomFlipType, + RandomReverse, RandomRotation, ReconstructionType, SensitivityMapType, @@ -50,18 +52,18 @@ def create_sample(shape, **kwargs): def _mask_func(shape, seed=None, return_acs=False): - shape = shape[:-1] - mask = torch.zeros(shape).bool() + num_rows, num_cols = shape[:2] + mask = torch.zeros(num_rows, num_cols).bool() mask[ - shape[0] // 2 - shape[0] // 4 : shape[0] // 2 + shape[0] // 4, - shape[1] // 2 - shape[1] // 4 : shape[1] // 2 + shape[1] // 4, + num_rows // 2 - num_rows // 4 : num_rows // 2 + num_rows // 4, + num_cols // 2 - num_cols // 4 : num_cols // 2 + num_cols // 4, ] = True if return_acs: return mask.unsqueeze(0).unsqueeze(-1) if seed: rng = np.random.RandomState() rng.seed(seed) - mask = mask | torch.from_numpy(np.random.rand(*shape)).round().bool() + mask = mask | torch.from_numpy(np.random.rand(num_rows, num_cols)).round().bool() return mask.unsqueeze(0).unsqueeze(-1) @@ -91,18 +93,26 @@ def test_Compose(shape): "shape", [(5, 7, 6), (3, 4, 6, 4)], ) -def test_ComputeZeroPadding(shape): +@pytest.mark.parametrize( + "eps", + [0.00001, None], +) +def test_ComputeZeroPadding(shape, eps): sample = create_sample(shape + (2,)) + transform = ComputeZeroPadding(eps=eps) + if eps: + pad_shape = [1 for _ in range(len(sample["kspace"].shape))] + pad_shape[-2] = sample["kspace"].shape[-2] + pad_shape[-3] = sample["kspace"].shape[-3] + padding = torch.from_numpy(np.random.randn(*pad_shape)).round().bool() + sample["kspace"] = (~padding) * sample["kspace"] - pad_shape = [1 for _ in range(len(sample["kspace"].shape))] - pad_shape[1:-1] = sample["kspace"].shape[1:-1] - padding = torch.from_numpy(np.random.randn(*pad_shape)).round().bool() - sample["kspace"] = (~padding) * sample["kspace"] - - transform = ComputeZeroPadding() - sample = transform(sample) + sample = transform(sample) - assert torch.allclose(sample["padding"], padding) + assert torch.allclose(sample["padding"], padding) + else: + sample = transform(sample) + assert sample == sample @pytest.mark.parametrize( @@ -112,7 +122,8 @@ def test_ComputeZeroPadding(shape): def test_ApplyZeroPadding(shape): sample = create_sample(shape + (2,)) pad_shape = [1 for _ in range(len(sample["kspace"].shape))] - pad_shape[1:-1] = sample["kspace"].shape[1:-1] + pad_shape[-2] = sample["kspace"].shape[-2] + pad_shape[-3] = sample["kspace"].shape[-3] padding = torch.from_numpy(np.random.randn(*pad_shape)).round().bool() sample.update({"padding": padding}) @@ -125,30 +136,31 @@ def test_ApplyZeroPadding(shape): @pytest.mark.parametrize( "shape", - [(1, 4, 6), (5, 7, 6), (2, None, None), (3, 4, 6, 4)], + [(1, 9, 8), (5, 7, 6), (2, None, None), (3, 5, 6, 4), (1, 1, 4, 9)], ) @pytest.mark.parametrize( "return_acs", [True, False], ) -@pytest.mark.parametrize( - "padding", - [None, True], -) @pytest.mark.parametrize( "use_shape", [True, False], ) -def test_CreateSamplingMask(shape, return_acs, 
padding, use_shape): - sample = create_sample(shape + (2,)) - if padding: - pad_shape = [1 for _ in range(len(sample["kspace"].shape))] - pad_shape[1:-1] = sample["kspace"].shape[1:-1] - sample.update({"padding": torch.from_numpy(np.random.randn(*pad_shape))}) - transform = CreateSamplingMask(mask_func=_mask_func, shape=shape[1:] if use_shape else None, return_acs=return_acs) +def test_CreateSamplingMask(shape, return_acs, use_shape): + shape = shape + (2,) + sample = create_sample(shape) + + transform = CreateSamplingMask( + mask_func=_mask_func, shape=shape[-3:-1] if use_shape else None, return_acs=return_acs + ) sample = transform(sample) assert "sampling_mask" in sample - assert tuple(sample["sampling_mask"].shape) == (1,) + sample["kspace"].shape[1:-1] + (1,) + + mask_shape = torch.ones(len(shape)) + mask_shape[-3] = sample["kspace"].shape[-3] + mask_shape[-2] = sample["kspace"].shape[-2] + assert list(sample["sampling_mask"].shape) == mask_shape.int().tolist() + if return_acs: assert "acs_mask" in sample @@ -177,7 +189,7 @@ def test_ApplyMask(shape): ) @pytest.mark.parametrize( "crop", - [(5, 6), "reconstruction_size", "[5, 6]", "(5, 6)", None, "invalid_key"], + [(5, 6), "reconstruction_size", "[5, 6]"], ) @pytest.mark.parametrize( "image_space_center_crop", @@ -186,8 +198,6 @@ def test_ApplyMask(shape): @pytest.mark.parametrize( "random_crop_sampler_type, random_crop_sampler_gaussian_sigma", [ - ["uniform", None], - ["gaussian", None], ["gaussian", [1.0, 2.0]], ], ) @@ -207,7 +217,7 @@ def test_CropKspace( shape=shape + (2,), sensitivity_map=torch.rand(shape + (2,)), sampling_mask=torch.rand(shape[1:]).round().unsqueeze(0).unsqueeze(-1), - input_image=torch.rand((1,) + shape[1:] + (2,)), + acs_mask=torch.rand(shape[1:]).round().unsqueeze(0).unsqueeze(-1), ) args = { "crop": crop, @@ -227,7 +237,8 @@ def test_CropKspace( else: if crop == "reconstruction_size": crop_shape = tuple((d // 2 for d in shape[1:])) - sample.update({"reconstruction_size": crop_shape}) + print(crop_shape) + sample.update({"reconstruction_size": crop_shape + (2,)}) elif isinstance(crop, IntegerListOrTupleString): crop_shape = tuple(IntegerListOrTupleString(crop)) @@ -239,7 +250,74 @@ def test_CropKspace( @pytest.mark.parametrize( "shape", - [(3, 10, 16)], + [(3, 21, 10, 16)], +) +@pytest.mark.parametrize( + "crop", + [(10, 5, 6), "reconstruction_size", "[10, 5, 6]", "(10, 5, 6)", None, "invalid_key"], +) +@pytest.mark.parametrize( + "image_space_center_crop", + [True, False], +) +@pytest.mark.parametrize( + "random_crop_sampler_type, random_crop_sampler_gaussian_sigma", + [ + ["uniform", None], + ["gaussian", None], + ["gaussian", [1.0, 1.0, 2.0]], + ], +) +@pytest.mark.parametrize( + "random_crop_sampler_use_seed", + [True, False], +) +def test_CropKspace3D( + shape, + crop, + image_space_center_crop, + random_crop_sampler_type, + random_crop_sampler_use_seed, + random_crop_sampler_gaussian_sigma, +): + sample = create_sample( + shape=shape + (2,), + sensitivity_map=torch.rand(shape + (2,)), + sampling_mask=torch.rand(shape[1:]).round().unsqueeze(0).unsqueeze(-1), + acs_mask=torch.rand(shape[1:]).round().unsqueeze(0).unsqueeze(-1), + ) + args = { + "crop": crop, + "image_space_center_crop": image_space_center_crop, + "random_crop_sampler_type": random_crop_sampler_type, + "random_crop_sampler_use_seed": random_crop_sampler_use_seed, + "random_crop_sampler_gaussian_sigma": random_crop_sampler_gaussian_sigma, + } + crop_shape = crop + if crop is None: + with pytest.raises(ValueError): + transform = 
CropKspace(**args) + elif crop == "invalid_key": + with pytest.raises(AssertionError): + transform = CropKspace(**args) + sample = transform(sample) + else: + if crop == "reconstruction_size": + crop_shape = tuple((d // 2 for d in shape[1:])) + print(crop_shape) + sample.update({"reconstruction_size": crop_shape + (2,)}) + elif isinstance(crop, IntegerListOrTupleString): + crop_shape = tuple(IntegerListOrTupleString(crop)) + + transform = CropKspace(**args) + + sample = transform(sample) + assert sample["kspace"].shape == (shape[0],) + crop_shape + (2,) + + +@pytest.mark.parametrize( + "shape", + [(3, 10, 16), (3, 11, 10, 16)], ) @pytest.mark.parametrize( "type", @@ -252,38 +330,52 @@ def test_random_flip(shape, type): sample = transform(sample) flipped_kspace = sample["kspace"] if type == "horizontal": - assert np.allclose(np.flip(kspace, 2), flipped_kspace, 0.0001) + assert np.allclose(np.flip(kspace, -3), flipped_kspace, 0.0001) elif type == "vertical": - assert np.allclose(np.flip(kspace, 1), flipped_kspace, 0.0001) + assert np.allclose(np.flip(kspace, -2), flipped_kspace, 0.0001) else: - assert np.allclose(np.flip(kspace, 1), flipped_kspace, 0.0001) | np.allclose( - np.flip(kspace, 2), flipped_kspace, 0.0001 + assert np.allclose(np.flip(kspace, -3), flipped_kspace, 0.0001) | np.allclose( + np.flip(kspace, -2), flipped_kspace, 0.0001 ) @pytest.mark.parametrize( "shape", - [(3, 10, 16)], + [(20, 10, 16), (11, 20, 10, 16)], +) +def test_random_reverse(shape): + sample = create_sample(shape=shape + (2,)) + kspace = sample["kspace"].numpy() + transform = RandomReverse(dim=-3, p=1) + sample = transform(sample) + flipped_kspace = sample["kspace"] + + assert np.allclose(np.flip(kspace, -3), flipped_kspace, 0.0001) + + +@pytest.mark.parametrize( + "shape", + [(3, 10, 16), (3, 11, 10, 16)], ) @pytest.mark.parametrize( "degree", [90, -90, 180], ) def test_random_rotation(shape, degree): - sample = create_sample(shape=shape + (2,), reconstruction_size=shape[1:] + (1,)) + sample = create_sample(shape=shape + (2,), reconstruction_size=shape[1:] + (2,)) kspace = sample["kspace"].numpy() transform = RandomRotation(degrees=(degree,), p=1) sample = transform(sample) rot_kspace = sample["kspace"].numpy() k = degree // 90 - assert np.allclose(np.rot90(kspace, k=k, axes=(1, 2)), rot_kspace, 0.0001) + assert np.allclose(np.rot90(kspace, k=k, axes=(-3, -2)), rot_kspace, 0.0001) @pytest.mark.parametrize( "shape", [ (4, 5, 5), - (4, 6, 4), + (4, 7, 6, 4), ], ) @pytest.mark.parametrize( @@ -308,21 +400,27 @@ def test_ComputeImage(shape, type_recon, complex_output): assert sample["target"].shape == (shape[1:] + (2,) if complex_output else shape[1:]) +# +# @pytest.mark.parametrize( - "shape, spatial_dims", + "shape", [ - [(1, 4, 6), (1, 2)], - [(5, 7, 6), (1, 2)], - [(4, 5, 5), (1, 2)], - [(3, 4, 6, 4), (2, 3)], + (1, 4, 6), + (5, 7, 6), + (4, 5, 5), + (3, 4, 6, 4), ], ) @pytest.mark.parametrize("use_seed", [True, False]) -def test_EstimateBodyCoilImage(shape, spatial_dims, use_seed): - sample = create_sample(shape=shape + (2,), sensitivity_map=torch.rand(shape + (2,))) +def test_EstimateBodyCoilImage(shape, use_seed): + sample = create_sample( + shape=shape + (2,), + sensitivity_map=torch.rand(shape + (2,)), + acs_mask=torch.rand(shape[1:]).round().unsqueeze(0).unsqueeze(-1), + ) transform = EstimateBodyCoilImage( mask_func=_mask_func, - backward_operator=functools.partial(ifft2, dim=spatial_dims), + backward_operator=functools.partial(ifft2), use_seed=use_seed, ) sample = transform(sample) @@ -330,12 +428,12 
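RandomReverse, exercised in test_random_reverse above, reverses the k-space along the configured axis with probability p, so p=1 makes the flip deterministic. A minimal usage sketch, assuming (as the test does) that a sample containing only a "kspace" entry is accepted:

import torch
from direct.data.mri_transforms import RandomReverse

sample = {"kspace": torch.rand(20, 10, 16, 2)}   # last dimension holds the complex channels
original = sample["kspace"].clone()

transform = RandomReverse(dim=-3, p=1)           # always reverse along dim -3
sample = transform(sample)

assert torch.allclose(sample["kspace"], torch.flip(original, dims=(-3,)))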
@@ def test_EstimateBodyCoilImage(shape, spatial_dims, use_seed): assert sample["body_coil_image"].shape == shape[1:] +# +# @pytest.mark.parametrize( "shape", [ (1, 54, 46), - (5, 37, 26), - (4, 35, 35), ], ) @pytest.mark.parametrize( @@ -380,6 +478,55 @@ def test_EstimateSensitivityMap(shape, type_of_map, gaussian_sigma, espirit_iter assert sample["sensitivity_map"].shape == shape + (2,) +@pytest.mark.parametrize( + "shape", + [ + (4, 20, 35, 35), + ], +) +@pytest.mark.parametrize( + "type_of_map, gaussian_sigma, espirit_iters, expect_error, sense_map_in_sample", + [ + [SensitivityMapType.unit, None, None, False, False], + [SensitivityMapType.rss_estimate, 0.5, None, False, False], + [SensitivityMapType.rss_estimate, None, None, False, False], + [SensitivityMapType.rss_estimate, None, None, False, True], + [SensitivityMapType.espirit, None, 5, True, True], + ], +) +def test_EstimateSensitivityMap3D( + shape, type_of_map, gaussian_sigma, espirit_iters, expect_error, sense_map_in_sample +): + sample = create_sample( + shape=shape + (2,), + acs_mask=torch.rand((1,) + shape[1:] + (1,)).round(), + sampling_mask=torch.rand((1,) + shape[1:] + (1,)).round(), + ) + if sense_map_in_sample: + sample.update({"sensitivity_map": torch.rand(shape + (2,))}) + args = { + "kspace_key": "kspace", + "backward_operator": functools.partial(ifft2), + "type_of_map": type_of_map, + "gaussian_sigma": gaussian_sigma, + "espirit_max_iters": espirit_iters, + "espirit_kernel_size": 3, + } + if expect_error: + with pytest.raises(NotImplementedError): + transform = EstimateSensitivityMap(**args) + sample = transform(sample) + else: + transform = EstimateSensitivityMap(**args) + if shape[0] == 1 or sense_map_in_sample: + with pytest.warns(None): + sample = transform(sample) + else: + sample = transform(sample) + assert "sensitivity_map" in sample + assert sample["sensitivity_map"].shape == shape + (2,) + + @pytest.mark.parametrize( "shape", [(5, 3, 4)], @@ -533,8 +680,8 @@ def test_ToTensor(shape, key, is_multicoil, is_complex, is_scalar): ) def test_build_mri_transforms(shape, spatial_dims, estimate_body_coil_image, image_center_crop): transform = build_mri_transforms( - forward_operator=functools.partial(fft2, dim=spatial_dims), - backward_operator=functools.partial(ifft2, dim=spatial_dims), + forward_operator=functools.partial(fft2), + backward_operator=functools.partial(ifft2), mask_func=_mask_func, crop=None, crop_type="uniform", @@ -552,5 +699,9 @@ def test_build_mri_transforms(shape, spatial_dims, estimate_body_coil_image, ima ) assert sample["masked_kspace"].shape == shape + (2,) assert sample["sensitivity_map"].shape == shape + (2,) - assert sample["sampling_mask"].shape == (1,) + shape[1:] + (1,) assert sample["target"].shape == shape[1:] + + mask_shape = torch.ones(len(shape) + 1).int().tolist() + mask_shape[-3] = shape[-2] + mask_shape[-2] = shape[-1] + assert list(sample["sampling_mask"].shape) == mask_shape diff --git a/tests/tests_functionals/test_hfen.py b/tests/tests_functionals/test_hfen.py new file mode 100644 index 00000000..163828c1 --- /dev/null +++ b/tests/tests_functionals/test_hfen.py @@ -0,0 +1,42 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import numpy as np +import pytest +import torch +from skimage.color import rgb2gray +from sklearn.datasets import load_sample_image + +from direct.functionals import HFENL1Loss, HFENL2Loss, hfen_l1, hfen_l2 + +# Load two images and convert them to grayscale +flower = rgb2gray(load_sample_image("flower.jpg"))[None].astype(np.float32) +china = 
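test_EstimateSensitivityMap3D above checks that RSS-based sensitivity estimation also handles 3D (coil, slice, height, width, complex) k-space, while the espirit option is expected to raise NotImplementedError for such inputs. A minimal sketch with the constructor arguments used in that test; treating both acs_mask and sampling_mask as required sample keys is an assumption carried over from the test setup:

import functools

import torch

from direct.data.mri_transforms import EstimateSensitivityMap, SensitivityMapType
from direct.data.transforms import ifft2

shape = (4, 20, 35, 35)   # (coil, slice, height, width)
sample = {
    "kspace": torch.rand(*shape, 2),
    "acs_mask": torch.rand(1, *shape[1:], 1).round(),
    "sampling_mask": torch.rand(1, *shape[1:], 1).round(),
}

estimate_sense = EstimateSensitivityMap(
    kspace_key="kspace",
    backward_operator=functools.partial(ifft2),
    type_of_map=SensitivityMapType.rss_estimate,
    gaussian_sigma=None,
    espirit_max_iters=None,
    espirit_kernel_size=3,
)
sample = estimate_sense(sample)
assert sample["sensitivity_map"].shape == (*shape, 2)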
rgb2gray(load_sample_image("china.jpg"))[None].astype(np.float32) + + +@pytest.mark.parametrize("image", [flower, china]) +@pytest.mark.parametrize("reduction", ["sum", "mean"]) +@pytest.mark.parametrize("kernel_size", [10, 15]) +@pytest.mark.parametrize("norm", [True, False]) +def test_hfen_l1(image, reduction, kernel_size, norm): + image = torch.from_numpy(image).unsqueeze(0) + + noise = 0.5 * torch.randn(*image.shape) + image_noise = image + noise + hfenl1loss = HFENL1Loss(reduction=reduction, kernel_size=kernel_size, norm=norm).forward(image_noise, image) + hfenl1metric = hfen_l1(input=image_noise, target=image, reduction=reduction, kernel_size=kernel_size, norm=norm) + assert hfenl1loss == hfenl1metric + + +@pytest.mark.parametrize("image", [flower, china]) +@pytest.mark.parametrize("reduction", ["sum", "mean"]) +@pytest.mark.parametrize("kernel_size", [10, 15]) +@pytest.mark.parametrize("norm", [True, False]) +def test_hfen_l2(image, reduction, kernel_size, norm): + image = torch.from_numpy(image).unsqueeze(0) + + noise = 0.5 * torch.randn(*image.shape) + image_noise = image + noise + hfenl2loss = HFENL2Loss(reduction=reduction, kernel_size=kernel_size, norm=norm).forward(image_noise, image) + hfenl2metric = hfen_l2(input=image_noise, target=image, reduction=reduction, kernel_size=kernel_size, norm=norm) + assert hfenl2loss == hfenl2metric diff --git a/tests/tests_functionals/test_snr.py b/tests/tests_functionals/test_snr.py new file mode 100644 index 00000000..e71dae96 --- /dev/null +++ b/tests/tests_functionals/test_snr.py @@ -0,0 +1,36 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import numpy as np +import pytest +import torch +from skimage.color import rgb2gray +from skimage.metrics import peak_signal_noise_ratio +from sklearn.datasets import load_sample_image + +from direct.functionals.snr import SNRLoss + +# Load two images and convert them to grayscale +flower = rgb2gray(load_sample_image("flower.jpg"))[None].astype(np.float32) +china = rgb2gray(load_sample_image("china.jpg"))[None].astype(np.float32) + + +@pytest.mark.parametrize("image", [flower, china]) +@pytest.mark.parametrize("reduction", ["sum", "mean"]) +def test_snr(image, reduction): + image = torch.from_numpy(image).unsqueeze(0).unsqueeze(0) + image_noise_batch = [] + single_image_snr = [] + + for sigma in range(0, 101, 20): + noise = sigma * torch.randn(*image.shape) + image_noise = image + noise + snr_torch = SNRLoss(reduction=reduction).forward(image_noise, image) + image_noise_batch.append(image_noise) + single_image_snr.append(snr_torch) + + image_batch = torch.cat([image] * len(image_noise_batch), dim=0) + image_noise_batch = torch.cat(image_noise_batch, dim=0) + snr_batch = SNRLoss(reduction=reduction).forward(image_noise_batch, image_batch) + # Assert that batch snr matches single snrs + assert np.allclose(snr_batch, np.average(single_image_snr), atol=5e-4) diff --git a/tests/tests_functionals/test_ssim.py b/tests/tests_functionals/test_ssim.py index aabc11db..c7b7b733 100644 --- a/tests/tests_functionals/test_ssim.py +++ b/tests/tests_functionals/test_ssim.py @@ -9,7 +9,7 @@ from sklearn.datasets import load_sample_image from direct.functionals.challenges import calgary_campinas_ssim, fastmri_ssim -from direct.functionals.ssim import SSIMLoss +from direct.functionals.ssim import SSIM3DLoss, SSIMLoss # Load two images and convert them to grayscale flower = rgb2gray(load_sample_image("flower.jpg"))[None].astype(np.float32) @@ -105,7 +105,6 @@ def test_calgary_campinas_ssim(image): def 
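The HFEN tests above check that the class-based losses (HFENL1Loss, HFENL2Loss) agree with their functional counterparts (hfen_l1, hfen_l2). HFEN is commonly computed by filtering both images with a Laplacian-of-Gaussian kernel and taking the L1 or L2 norm of the difference, optionally normalized by the filtered target; that reading of kernel_size and norm is an assumption here, while the call signature below is taken directly from the tests:

import torch
from direct.functionals import HFENL1Loss, hfen_l1

target = torch.rand(1, 1, 64, 64)                     # (batch, channel, height, width)
prediction = target + 0.1 * torch.randn_like(target)

loss_module = HFENL1Loss(reduction="mean", kernel_size=15, norm=False)
loss_value = loss_module(prediction, target)

# The functional form is expected to match the module output.
assert torch.allclose(
    loss_value,
    hfen_l1(input=prediction, target=target, reduction="mean", kernel_size=15, norm=False),
)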
test_fastmri_ssim(image): image_batch = [] image_noise_batch = [] - single_image_ssim = [] for sigma in range(1, 5): noise = sigma * np.random.rand(*image.shape) @@ -129,3 +128,32 @@ def test_fastmri_ssim(image): fastmri_ssim_batch = fastmri_ssim(image_batch_torch, image_noise_batch_torch) assert np.allclose(fastmri_ssim_batch, ssim_skimage_batch, atol=5e-4) + + +@pytest.mark.parametrize("data_range_255", [True, False]) +@pytest.mark.parametrize("win_size", [7]) +@pytest.mark.parametrize("k1, k2", [[0.01, 0.03], [0.05, 0.1]]) +def test_ssim_3de(data_range_255, win_size, k1, k2): + image = torch.from_numpy(np.concatenate([flower, china] * 4, 0)).unsqueeze(0).unsqueeze(0) + image_noise_batch = [] + + single_image_ssim = [] + + for sigma in range(0, 101, 20): + noise = sigma * torch.randn(*image.shape) + image_noise = image + noise + ssim_torch = 1 - SSIM3DLoss(win_size=win_size, k1=k1, k2=k2).forward( + image_noise, image, data_range=torch.tensor([255 if data_range_255 else image.max()]) + ) + image_noise_batch.append(image_noise) + single_image_ssim.append(ssim_torch) + + image_batch = torch.cat([image] * len(image_noise_batch), dim=0) + image_noise_batch = torch.cat(image_noise_batch, dim=0) + ssim_batch = 1 - SSIM3DLoss(win_size=win_size, k1=k1, k2=k2).forward( + X=image_noise_batch, + Y=image_batch, + data_range=torch.tensor([255]) if data_range_255 else image_batch.amax((1, 2, 3, 4)), + ) + # Assert that batch ssim matches single ssims + assert np.allclose(ssim_batch, np.average(single_image_ssim), atol=5e-4) diff --git a/tests/tests_nn/test_conjgradnet_engine.py b/tests/tests_nn/test_conjgradnet_engine.py index 33f8f35e..0eadad71 100644 --- a/tests/tests_nn/test_conjgradnet_engine.py +++ b/tests/tests_nn/test_conjgradnet_engine.py @@ -68,6 +68,7 @@ def test_resnetconjgrad_engine( config = DefaultConfig(training=training_config, validation=validation_config) # Define engine engine = ConjGradNetEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) + engine.ndim = 2 # Test _do_iteration function with a single data batch data = create_sample( shape, diff --git a/tests/tests_nn/test_iterdualnet_engine.py b/tests/tests_nn/test_iterdualnet_engine.py index 8a5ab39b..e4cbbb2b 100644 --- a/tests/tests_nn/test_iterdualnet_engine.py +++ b/tests/tests_nn/test_iterdualnet_engine.py @@ -59,7 +59,7 @@ def test_iterdualnet_engine(shape, loss_fns, num_iter, compute_per_coil): config = DefaultConfig(training=training_config, validation=validation_config, inference=inference_config) # Define engine engine = IterDualNetEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) - + engine.ndim = 2 # Test _do_iteration function with a single data batch data = create_sample( shape, diff --git a/tests/tests_nn/test_jointicnet_engine.py b/tests/tests_nn/test_jointicnet_engine.py index a4850669..0ba9c95f 100644 --- a/tests/tests_nn/test_jointicnet_engine.py +++ b/tests/tests_nn/test_jointicnet_engine.py @@ -65,7 +65,7 @@ def test_jointicnet_engine( config = DefaultConfig(training=training_config, validation=validation_config, inference=inference_config) # Define engine engine = JointICNetEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) - + engine.ndim = 2 # Test _do_iteration function with a single data batch data = create_sample( shape, diff --git a/tests/tests_nn/test_kikinet_engine.py b/tests/tests_nn/test_kikinet_engine.py index ff93612e..3496e6b9 100644 --- a/tests/tests_nn/test_kikinet_engine.py +++ 
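test_ssim_3de above exercises the new SSIM3DLoss on 5D (batch, channel, depth, height, width) volumes, passing data_range either as 255 or as the per-sample maximum. A minimal usage sketch following that call pattern:

import torch
from direct.functionals.ssim import SSIM3DLoss

volume = torch.rand(2, 1, 8, 64, 64)                  # (batch, channel, depth, height, width)
noisy = volume + 0.05 * torch.randn_like(volume)

loss = SSIM3DLoss(win_size=7, k1=0.01, k2=0.03)
data_range = volume.amax((1, 2, 3, 4))                # per-sample maximum, as in the batch test above
ssim_value = 1 - loss(noisy, volume, data_range=data_range)   # as in the test, 1 - loss recovers the SSIM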
b/tests/tests_nn/test_kikinet_engine.py @@ -67,7 +67,7 @@ def test_kikinet_engine(shape, loss_fns, num_iter): config = DefaultConfig(training=training_config, validation=validation_config, inference=inference_config) # Define engine engine = KIKINetEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) - + engine.ndim = 2 # Test _do_iteration function with a single data batch data = create_sample( shape, diff --git a/tests/tests_nn/test_lpd_engine.py b/tests/tests_nn/test_lpd_engine.py index 22a669a7..e9f91386 100644 --- a/tests/tests_nn/test_lpd_engine.py +++ b/tests/tests_nn/test_lpd_engine.py @@ -51,6 +51,7 @@ def test_lpd_engine(shape, loss_fns, num_iter, num_primal, num_dual): config = DefaultConfig(training=training_config, validation=validation_config) # Define engine engine = LPDNetEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) + engine.ndim = 2 # Test _do_iteration function with a single data batch data = create_sample( shape, diff --git a/tests/tests_nn/test_multidomainnet_engine.py b/tests/tests_nn/test_multidomainnet_engine.py index 442e8a66..b3f49fb4 100644 --- a/tests/tests_nn/test_multidomainnet_engine.py +++ b/tests/tests_nn/test_multidomainnet_engine.py @@ -70,7 +70,7 @@ def test_multidomainnet_engine(shape, loss_fns, standardization, num_filters, nu config = DefaultConfig(training=training_config, validation=validation_config, inference=inference_config) # Define engine engine = MultiDomainNetEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) - + engine.ndim = 2 # Test _do_iteration function with a single data batch data = create_sample( shape, diff --git a/tests/tests_nn/test_recurrentvarnet_engine.py b/tests/tests_nn/test_recurrentvarnet_engine.py index 7ecc9226..1a07190e 100644 --- a/tests/tests_nn/test_recurrentvarnet_engine.py +++ b/tests/tests_nn/test_recurrentvarnet_engine.py @@ -73,6 +73,7 @@ def test_recurrentvarnet_engine(shape, loss_fns, num_steps): config = DefaultConfig(training=training_config, validation=validation_config) # Define engine engine = RecurrentVarNetEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) + engine.ndim = 2 # Test _do_iteration function with a single data batch data = create_sample( shape, diff --git a/tests/tests_nn/test_rim_engine.py b/tests/tests_nn/test_rim_engine.py index c0f62341..1d850cfb 100644 --- a/tests/tests_nn/test_rim_engine.py +++ b/tests/tests_nn/test_rim_engine.py @@ -61,6 +61,7 @@ def test_lpd_engine(shape, loss_fns, length, depth, scale_log): # Define engine engine = RIMEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) engine.ndim = 2 + engine.ndim = 2 # Test _do_iteration function with a single data batch data = create_sample( shape, diff --git a/tests/tests_nn/test_unet_engine.py b/tests/tests_nn/test_unet_engine.py index cb4da6b7..99bc27b3 100644 --- a/tests/tests_nn/test_unet_engine.py +++ b/tests/tests_nn/test_unet_engine.py @@ -66,6 +66,7 @@ def test_unet_engine(shape, loss_fns, num_filters, num_pool_layers, normalized, sensitivity_model = torch.nn.Conv2d(2, 2, kernel_size=1) # Define engine engine = Unet2dEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) + engine.ndim = 2 # Test _do_iteration function with a single data batch data = create_sample( shape, diff --git a/tests/tests_nn/test_varnet_engine.py b/tests/tests_nn/test_varnet_engine.py index 02079644..cf1fcfbe 100644 --- a/tests/tests_nn/test_varnet_engine.py +++ 
b/tests/tests_nn/test_varnet_engine.py @@ -57,6 +57,7 @@ def test_lpd_engine(shape, loss_fns, num_layers, num_filters, num_pull_layers): config = DefaultConfig(training=training_config, validation=validation_config) # Define engine engine = EndToEndVarNetEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) + engine.ndim = 2 # Test _do_iteration function with a single data batch data = create_sample( shape, diff --git a/tests/tests_nn/test_varsplitnet_engine.py b/tests/tests_nn/test_varsplitnet_engine.py index 2378f5c0..ec83bc59 100644 --- a/tests/tests_nn/test_varsplitnet_engine.py +++ b/tests/tests_nn/test_varsplitnet_engine.py @@ -69,6 +69,7 @@ def test_varsplitnet_engine(shape, loss_fns, num_steps_reg, num_steps_dc, image_ sensitivity_model = torch.nn.Conv2d(2, 2, kernel_size=1) # Define engine engine = MRIVarSplitNetEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) + engine.ndim = 2 # Test _do_iteration function with a single data batch data = create_sample( shape, diff --git a/tests/tests_nn/test_vsharp.py b/tests/tests_nn/test_vsharp.py new file mode 100644 index 00000000..4d9e1941 --- /dev/null +++ b/tests/tests_nn/test_vsharp.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import pytest +import torch + +from direct.data.transforms import fft2, ifft2 +from direct.nn.get_nn_model_config import ModelName +from direct.nn.types import InitType +from direct.nn.vsharp.vsharp import VSharpNet, VSharpNet3D + + +def create_input(shape): + data = torch.rand(shape).float() + + return data + + +@pytest.mark.parametrize("shape", [[1, 3, 16, 16]]) +@pytest.mark.parametrize("num_steps", [3]) +@pytest.mark.parametrize("num_steps_dc_gd", [2]) +@pytest.mark.parametrize("image_init", [InitType.sense, InitType.zero_filled]) +@pytest.mark.parametrize( + "image_model_architecture, image_model_kwargs", + [ + [ModelName.unet, {"image_unet_num_filters": 4, "image_unet_num_pool_layers": 2}], + [ModelName.didn, {"image_didn_hidden_channels": 4, "image_didn_num_dubs": 2, "image_didn_num_convs_recon": 2}], + [ + ModelName.uformer, + { + "image_uformer_patch_size": 4, + "image_uformer_embedding_dim": 2, + "image_uformer_encoder_depths": (2,), + "image_uformer_encoder_num_heads": (1,), + "image_uformer_bottleneck_depth": 2, + "image_uformer_bottleneck_num_heads": 2, + "image_uformer_normalized": True, + "image_uformer_win_size": 2, + }, + ], + ], +) +@pytest.mark.parametrize( + "initializer_channels, initializer_dilations", + [ + [(8, 8, 16), (1, 1, 4)], + ], +) +@pytest.mark.parametrize("aux_steps", [-1, 1]) +def test_varsplitnet( + shape, + num_steps, + num_steps_dc_gd, + image_init, + image_model_architecture, + image_model_kwargs, + initializer_channels, + initializer_dilations, + aux_steps, +): + model = VSharpNet( + fft2, + ifft2, + num_steps=num_steps, + num_steps_dc_gd=num_steps_dc_gd, + image_init=image_init, + no_parameter_sharing=False, + initializer_channels=initializer_channels, + initializer_dilations=initializer_dilations, + auxiliary_steps=aux_steps, + image_model_architecture=image_model_architecture, + **image_model_kwargs, + ).cpu() + + kspace = create_input(shape + [2]).cpu() + mask = create_input([shape[0]] + [1] + shape[2:] + [1]).round().int().cpu() + sens = create_input(shape + [2]).cpu() + out = model(kspace, sens, mask) + + for i in range(len(out)): + assert list(out[i].shape) == [shape[0]] + shape[2:] + [2] + + +@pytest.mark.parametrize("shape", [[1, 3, 10, 16, 16]]) 
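Condensing the parametrized test_varsplitnet above into one concrete call, the sketch below shows how VSharpNet is constructed with the U-Net image model and what its forward pass returns; the argument values are taken from a single test configuration, and the output is a list of intermediate reconstructions:

import torch

from direct.data.transforms import fft2, ifft2
from direct.nn.get_nn_model_config import ModelName
from direct.nn.types import InitType
from direct.nn.vsharp.vsharp import VSharpNet

model = VSharpNet(
    fft2,
    ifft2,
    num_steps=3,
    num_steps_dc_gd=2,
    image_init=InitType.sense,
    no_parameter_sharing=False,
    initializer_channels=(8, 8, 16),
    initializer_dilations=(1, 1, 4),
    auxiliary_steps=-1,
    image_model_architecture=ModelName.unet,
    image_unet_num_filters=4,
    image_unet_num_pool_layers=2,
)

kspace = torch.rand(1, 3, 16, 16, 2)                  # (batch, coil, height, width, complex)
sensitivity_map = torch.rand(1, 3, 16, 16, 2)
sampling_mask = torch.rand(1, 1, 16, 16, 1).round().int()

outputs = model(kspace, sensitivity_map, sampling_mask)
assert all(out.shape == (1, 16, 16, 2) for out in outputs)

The engine tests throughout this patch also set engine.ndim immediately after constructing each engine (2 for the 2D engines, 3 for VSharpNet3DEngine further below), which the updated MRIModelEngine apparently relies on to select dimension-dependent behaviour.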
+@pytest.mark.parametrize("num_steps", [3]) +@pytest.mark.parametrize("num_steps_dc_gd", [2]) +@pytest.mark.parametrize("image_init", [InitType.sense, InitType.zero_filled]) +@pytest.mark.parametrize( + "image_model_kwargs", + [ + {"unet_num_filters": 4, "unet_num_pool_layers": 2}, + ], +) +@pytest.mark.parametrize( + "initializer_channels, initializer_dilations", + [ + [(8, 8, 8, 16), (1, 1, 2, 4)], + ], +) +@pytest.mark.parametrize("aux_steps", [-1, 1]) +def test_varsplitnet3d( + shape, + num_steps, + num_steps_dc_gd, + image_init, + image_model_kwargs, + initializer_channels, + initializer_dilations, + aux_steps, +): + model = VSharpNet3D( + fft2, + ifft2, + num_steps=num_steps, + num_steps_dc_gd=num_steps_dc_gd, + image_init=image_init, + no_parameter_sharing=False, + initializer_channels=initializer_channels, + initializer_dilations=initializer_dilations, + auxiliary_steps=aux_steps, + **image_model_kwargs, + ).cpu() + + kspace = create_input(shape + [2]).cpu() + mask = create_input([shape[0]] + [1] + shape[2:] + [1]).round().int().cpu() + sens = create_input(shape + [2]).cpu() + out = model(kspace, sens, mask) + + for i in range(len(out)): + assert list(out[i].shape) == [shape[0]] + shape[2:] + [2] diff --git a/tests/tests_nn/test_vsharp_engine.py b/tests/tests_nn/test_vsharp_engine.py new file mode 100644 index 00000000..c9c646eb --- /dev/null +++ b/tests/tests_nn/test_vsharp_engine.py @@ -0,0 +1,140 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import functools + +import numpy as np +import pytest +import torch + +from direct.config.defaults import DefaultConfig, FunctionConfig, LossConfig, TrainingConfig, ValidationConfig +from direct.data.transforms import fft2, ifft2 +from direct.nn.vsharp.config import VSharpNet3DConfig, VSharpNetConfig +from direct.nn.vsharp.vsharp import VSharpNet, VSharpNet3D +from direct.nn.vsharp.vsharp_engine import VSharpNet3DEngine, VSharpNetEngine + + +def create_sample(shape, **kwargs): + sample = dict() + sample["masked_kspace"] = torch.from_numpy(np.random.randn(*shape)).float() + sample["kspace"] = torch.from_numpy(np.random.randn(*shape)).float() + sample["sensitivity_map"] = torch.from_numpy(np.random.randn(*shape)).float() + for k, v in locals()["kwargs"].items(): + sample[k] = v + return sample + + +@pytest.mark.parametrize( + "shape", + [(4, 3, 10, 16, 2), (5, 1, 10, 12, 2)], +) +@pytest.mark.parametrize( + "loss_fns", + [["l1_loss", "kspace_nmse_loss", "kspace_nmae_loss"]], +) +@pytest.mark.parametrize( + "num_steps, num_steps_dc_gd, num_filters, num_pool_layers", + [[4, 2, 10, 2]], +) +@pytest.mark.parametrize( + "normalized", + [True, False], +) +def test_unet_engine(shape, loss_fns, num_steps, num_steps_dc_gd, num_filters, num_pool_layers, normalized): + # Operators + forward_operator = functools.partial(fft2, centered=True) + backward_operator = functools.partial(ifft2, centered=True) + # Configs + loss_config = LossConfig(losses=[FunctionConfig(loss) for loss in loss_fns]) + training_config = TrainingConfig(loss=loss_config) + validation_config = ValidationConfig(crop=None) + model_config = VSharpNetConfig( + num_steps=num_steps, + num_steps_dc_gd=num_steps_dc_gd, + image_unet_num_filters=num_filters, + image_unet_num_pool_layers=num_pool_layers, + auxiliary_steps=-1, + ) + config = DefaultConfig(training=training_config, validation=validation_config, model=model_config) + # Models + model = VSharpNet( + forward_operator, + backward_operator, + num_steps=model_config.num_steps, + num_steps_dc_gd=model_config.num_steps_dc_gd, 
+ image_unet_num_filters=model_config.image_unet_num_filters, + image_unet_num_pool_layers=model_config.image_unet_num_pool_layers, + auxiliary_steps=model_config.auxiliary_steps, + ) + sensitivity_model = torch.nn.Conv2d(2, 2, kernel_size=1) + # Define engine + engine = VSharpNetEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) + engine.ndim = 2 + # Test _do_iteration function with a single data batch + data = create_sample( + shape, + sampling_mask=torch.from_numpy(np.random.randn(1, 1, shape[2], shape[3], 1)).bool(), + target=torch.from_numpy(np.random.randn(shape[0], shape[2], shape[3])).float(), + scaling_factor=torch.ones(shape[0]), + ) + loss_fns = engine.build_loss() + out = engine._do_iteration(data, loss_fns) + out.output_image.shape == (shape[0],) + tuple(shape[2:-1]) + + +@pytest.mark.parametrize( + "shape", + [(2, 3, 4, 10, 16, 2), (1, 11, 8, 12, 16, 2)], +) +@pytest.mark.parametrize( + "loss_fns", + [["l1_loss", "kspace_nmse_loss", "kspace_nmae_loss"]], +) +@pytest.mark.parametrize( + "num_steps, num_steps_dc_gd, num_filters, num_pool_layers", + [[4, 2, 10, 2]], +) +@pytest.mark.parametrize( + "normalized", + [True, False], +) +def test_unet_engine(shape, loss_fns, num_steps, num_steps_dc_gd, num_filters, num_pool_layers, normalized): + # Operators + forward_operator = functools.partial(fft2, centered=True) + backward_operator = functools.partial(ifft2, centered=True) + # Configs + loss_config = LossConfig(losses=[FunctionConfig(loss) for loss in loss_fns]) + training_config = TrainingConfig(loss=loss_config) + validation_config = ValidationConfig(crop=None) + model_config = VSharpNet3DConfig( + num_steps=num_steps, + num_steps_dc_gd=num_steps_dc_gd, + unet_num_filters=num_filters, + unet_num_pool_layers=num_pool_layers, + auxiliary_steps=-1, + ) + config = DefaultConfig(training=training_config, validation=validation_config, model=model_config) + # Models + model = VSharpNet3D( + forward_operator, + backward_operator, + num_steps=model_config.num_steps, + num_steps_dc_gd=model_config.num_steps_dc_gd, + unet_num_filters=model_config.unet_num_filters, + unet_num_pool_layers=model_config.unet_num_pool_layers, + auxiliary_steps=model_config.auxiliary_steps, + ) + sensitivity_model = torch.nn.Conv2d(2, 2, kernel_size=1) + # Define engine + engine = VSharpNet3DEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) + engine.ndim = 3 + # Test _do_iteration function with a single data batch + data = create_sample( + shape, + sampling_mask=torch.from_numpy(np.random.randn(1, 1, 1, shape[3], shape[4], 1)).bool(), + target=torch.from_numpy(np.random.randn(shape[0], shape[2], shape[3], shape[4])).float(), + scaling_factor=torch.ones(shape[0]), + ) + loss_fns = engine.build_loss() + out = engine._do_iteration(data, loss_fns) + out.output_image.shape == (shape[0],) + tuple(shape[2:-1]) diff --git a/tests/tests_nn/test_xpdnet_engine.py b/tests/tests_nn/test_xpdnet_engine.py index 144039d2..45e5a264 100644 --- a/tests/tests_nn/test_xpdnet_engine.py +++ b/tests/tests_nn/test_xpdnet_engine.py @@ -64,6 +64,7 @@ def test_xpdnet_engine( config = DefaultConfig(training=training_config, validation=validation_config) # Define engine engine = XPDNetEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) + engine.ndim = 2 # Test _do_iteration function with a single data batch data = create_sample( shape,