diff --git a/README.rst b/README.rst
index 4f1601822..e0879dfcc 100644
--- a/README.rst
+++ b/README.rst
@@ -25,8 +25,11 @@ DIRECT: Deep Image REConstruction Toolkit
 =========================================
-``DIRECT`` is a Python, end-to-end pipeline for solving Inverse Problems emerging in Imaging Processing. It is built with PyTorch and stores state-of-the-art Deep Learning imaging inverse problem solvers such as denoising, dealiasing and reconstruction. By defining a base forward linear or non-linear operator, ``DIRECT`` can be used for training models for recovering images such as MRIs from partially observed or noisy input data.
-``DIRECT`` stores inverse problem solvers such as the Learned Primal Dual algorithm, Recurrent Inference Machine and Recurrent Variational Network, which were part of the winning solution in Facebook & NYUs FastMRI challenge in 2019 and the Calgary-Campinas MRI reconstruction challenge at MIDL 2020. For a full list of the baselines currently implemented in DIRECT see `here <#baselines-and-trained-models>`_.
+``DIRECT`` is an end-to-end Python pipeline for solving inverse problems that arise in image processing.
+It is built with PyTorch and provides state-of-the-art deep learning solvers for imaging inverse problems such as denoising, dealiasing and reconstruction.
+By defining a base forward linear or non-linear operator, ``DIRECT`` can be used to train models that recover images, such as MRIs, from partially observed or noisy input data.
+``DIRECT`` includes inverse problem solvers such as vSHARP, the Learned Primal-Dual algorithm, the Recurrent Inference Machine and the Recurrent Variational Network, which were part of the winning solutions in the Facebook & NYU fastMRI challenge in 2019, the Calgary-Campinas MRI reconstruction challenge at MIDL 2020, and the CMRxRecon challenge in 2023.
+For a full list of the baselines currently implemented in DIRECT see `here <#baselines-and-trained-models>`_.

 .. raw:: html

@@ -49,7 +52,7 @@ In the `projects `_ folder

 Baselines and trained models
 ----------------------------
-We provide a set of baseline results and trained models in the `DIRECT Model Zoo `_. Baselines and trained models include the `Recurrent Variational Network (RecurrentVarNet) `_, the `Recurrent Inference Machine (RIM) `_, the `End-to-end Variational Network (VarNet) `_, the `Learned Primal Dual Network (LDPNet) `_, the `X-Primal Dual Network (XPDNet) `_, the `KIKI-Net `_, the `U-Net `_, the `Joint-ICNet `_, and the `AIRS Medical fastmri model (MultiDomainNet) `_.
+We provide a set of baseline results and trained models in the `DIRECT Model Zoo `_. Baselines and trained models include the `vSHARP `_, the `Recurrent Variational Network (RecurrentVarNet) `_, the `Recurrent Inference Machine (RIM) `_, the `End-to-end Variational Network (VarNet) `_, the `Learned Primal-Dual Network (LPDNet) `_, the `X-Primal Dual Network (XPDNet) `_, the `KIKI-Net `_, the `U-Net `_, the `Joint-ICNet `_, and the `AIRS Medical fastMRI model (MultiDomainNet) `_.

 License and usage
 -----------------

@@ -63,15 +66,15 @@ If you use DIRECT in your own research, or want to refer to baseline results pub
 .. code-block:: BibTeX

-    @misc{DIRECTTOOLKIT,
-    doi = {10.21105/joss.04278},
-    url = {https://doi.org/10.21105/joss.04278},
-    year = {2022},
-    publisher = {The Open Journal},
-    volume = {7},
-    number = {73},
-    pages = {4278},
-    author = {George Yiasemis and Nikita Moriakov and Dimitrios Karkalousos and Matthan Caan and Jonas Teuwen},
-    title = {DIRECT: Deep Image REConstruction Toolkit},
-    journal = {Journal of Open Source Software}
-    }
+    @article{DIRECTTOOLKIT,
+    doi = {10.21105/joss.04278},
+    url = {https://doi.org/10.21105/joss.04278},
+    year = {2022},
+    publisher = {The Open Journal},
+    volume = {7},
+    number = {73},
+    pages = {4278},
+    author = {George Yiasemis and Nikita Moriakov and Dimitrios Karkalousos and Matthan Caan and Jonas Teuwen},
+    title = {DIRECT: Deep Image REConstruction Toolkit},
+    journal = {Journal of Open Source Software}
+    }
diff --git a/direct/data/datasets_config.py b/direct/data/datasets_config.py
index ecacfadd7..24e5cefb8 100644
--- a/direct/data/datasets_config.py
+++ b/direct/data/datasets_config.py
@@ -46,7 +46,7 @@ class NormalizationTransformConfig(BaseConfig):
 @dataclass
 class TransformsConfig(BaseConfig):
-    masking: MaskingConfig = MaskingConfig()
+    masking: Optional[MaskingConfig] = MaskingConfig()
     cropping: CropTransformConfig = CropTransformConfig()
     random_augmentations: RandomAugmentationTransformsConfig = RandomAugmentationTransformsConfig()
     padding_eps: float = 0.001
diff --git a/direct/nn/vsharp/__init__.py b/direct/nn/vsharp/__init__.py
new file mode 100644
index 000000000..b63c8cc1b
--- /dev/null
+++ b/direct/nn/vsharp/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/vsharp/config.py b/direct/nn/vsharp/config.py
new file mode 100644
index 000000000..3acfd84f7
--- /dev/null
+++ b/direct/nn/vsharp/config.py
@@ -0,0 +1,36 @@
+# Copyright (c) DIRECT Contributors
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from direct.config.defaults import ModelConfig
+from direct.nn.types import ActivationType, InitType, ModelName
+
+
+@dataclass
+class VSharpNetConfig(ModelConfig):
+    num_steps: int = 10
+    num_steps_dc_gd: int = 8
+    image_init: InitType = InitType.SENSE
+    no_parameter_sharing: bool = True
+    auxiliary_steps: int = 0
+    image_model_architecture: ModelName = ModelName.UNET
+    initializer_channels: tuple[int, ...] = (32, 32, 64, 64)
+    initializer_dilations: tuple[int, ...] = (1, 1, 2, 4)
+    initializer_multiscale: int = 1
+    initializer_activation: ActivationType = ActivationType.PRELU
+    image_resnet_hidden_channels: int = 128
+    image_resnet_num_blocks: int = 15
+    image_resnet_batchnorm: bool = True
+    image_resnet_scale: float = 0.1
+    image_unet_num_filters: int = 32
+    image_unet_num_pool_layers: int = 4
+    image_unet_dropout: float = 0.0
+    image_didn_hidden_channels: int = 16
+    image_didn_num_dubs: int = 6
+    image_didn_num_convs_recon: int = 9
+    image_conv_hidden_channels: int = 64
+    image_conv_n_convs: int = 15
+    image_conv_activation: ActivationType = ActivationType.RELU
+    image_conv_batchnorm: bool = False
diff --git a/direct/nn/vsharp/vsharp.py b/direct/nn/vsharp/vsharp.py
new file mode 100644
index 000000000..cd5bbac20
--- /dev/null
+++ b/direct/nn/vsharp/vsharp.py
@@ -0,0 +1,321 @@
+# Copyright (c) DIRECT Contributors
+
+"""This module provides the implementation of the variable Splitting Half-quadratic ADMM algorithm for Reconstruction
+of inverse-Problems (vSHARP) model as presented in [1]_.
+
+References
+----------
+.. [1] George Yiasemis et al.
+    vSHARP: variable Splitting Half-quadratic ADMM algorithm for Reconstruction of inverse-Problems (2023).
+    https://arxiv.org/abs/2309.09954.
+"""
+
+
+from __future__ import annotations
+
+from typing import Callable
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from direct.constants import COMPLEX_SIZE
+from direct.data.transforms import apply_mask, expand_operator, reduce_operator
+from direct.nn.get_nn_model_config import ModelName, _get_model_config, _get_relu_activation
+from direct.nn.types import ActivationType, InitType
+
+
+class LagrangeMultipliersInitializer(nn.Module):
+    """A convolutional neural network that initializes the Lagrange multipliers of :class:`VSharpNet`."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        channels: tuple[int, ...],
+        dilations: tuple[int, ...],
+        multiscale_depth: int = 1,
+        activation: ActivationType = ActivationType.PRELU,
+    ) -> None:
+        """Inits :class:`LagrangeMultipliersInitializer`.
+
+        Parameters
+        ----------
+        in_channels : int
+            Number of input channels.
+        out_channels : int
+            Number of output channels.
+        channels : tuple of ints
+            Tuple of integers specifying the number of output channels for each convolutional layer in the network.
+        dilations : tuple of ints
+            Tuple of integers specifying the dilation factor for each convolutional layer in the network.
+        multiscale_depth : int
+            Number of multiscale features to include in the output. Default: 1.
+        activation : ActivationType
+            Activation applied to the output block. Default: ActivationType.PRELU.
+        """
+        super().__init__()
+
+        # Define convolutional blocks
+        self.conv_blocks = nn.ModuleList()
+        tch = in_channels
+        for curr_channels, curr_dilations in zip(channels, dilations):
+            block = nn.Sequential(
+                nn.ReplicationPad2d(curr_dilations),
+                nn.Conv2d(tch, curr_channels, 3, padding=0, dilation=curr_dilations),
+            )
+            tch = curr_channels
+            self.conv_blocks.append(block)
+
+        # Define output block
+        tch = np.sum(channels[-multiscale_depth:])
+        block = nn.Conv2d(tch, out_channels, 1, padding=0)
+        self.out_block = nn.Sequential(block)
+
+        self.multiscale_depth = multiscale_depth
+
+        self.activation = _get_relu_activation(activation)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of :class:`LagrangeMultipliersInitializer`.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape (batch_size, in_channels, height, width).
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor of shape (batch_size, out_channels, height, width).
+        """
+
+        features = []
+        for block in self.conv_blocks:
+            x = F.relu(block(x), inplace=True)
+            if self.multiscale_depth > 1:
+                features.append(x)
+
+        if self.multiscale_depth > 1:
+            x = torch.cat(features[-self.multiscale_depth :], dim=1)
+
+        return self.activation(self.out_block(x))
+
+
+class VSharpNet(nn.Module):
+    """Variable Splitting Half-quadratic ADMM algorithm for Reconstruction of Parallel MRI [1]_.
+
+    :class:`VSharpNet` is a deep learning model that solves the augmented Lagrangian formulation of the
+    variable half-quadratic splitting problem using ADMM (Alternating Direction Method of Multipliers).
+    It is specifically designed for solving inverse problems in magnetic resonance imaging (MRI).
+
+    The model unrolls an iterative optimization algorithm that consists of three steps, the z-step, the
+    x-step, and the u-step, detailed mathematically as follows:
+
+    .. math::
+
+        z^{t+1} = \mathrm{argmin}_{z} \\lambda \mathcal{G}(z) + \\frac{\\rho}{2} || x^{t} - z +
+        \\frac{u^t}{\\rho} ||_2^2 \\quad \mathrm{[z-step]}
+
+    .. math::
+
+        x^{t+1} = \mathrm{argmin}_{x} \\frac{1}{2} || \mathcal{A}_{\mathbf{U},\mathbf{S}}(x) - \\tilde{y} ||_2^2 +
+        \\frac{\\rho}{2} || x - z^{t+1} + \\frac{u^t}{\\rho} ||_2^2 \\quad \mathrm{[x-step]}
+
+    .. math::
+
+        u^{t+1} = u^t + \\rho (x^{t+1} - z^{t+1}) \\quad \mathrm{[u-step]}
+
+    During the z-step, the model minimizes the augmented Lagrangian function with respect to z, utilizing
+    DL-based denoisers. In the x-step, it optimizes x by minimizing the data consistency term through
+    unrolling a gradient descent scheme (DC-GD). The u-step updates the Lagrange multipliers u.
+    These steps are iterated for a specified number of cycles.
+
+    The model includes a learned initializer for the Lagrange multipliers and can additionally output
+    auxiliary (intermediate) reconstructions.
+
+    :class:`VSharpNet` is tailored for 2D MRI data reconstruction.
+
+    References
+    ----------
+
+    .. [1] George Yiasemis et al., "VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction
+        of Inverse Problems" (2023). https://arxiv.org/abs/2309.09954.
+
+    """
+
+    def __init__(
+        self,
+        forward_operator: Callable,
+        backward_operator: Callable,
+        num_steps: int,
+        num_steps_dc_gd: int,
+        image_init: InitType = InitType.SENSE,
+        no_parameter_sharing: bool = True,
+        image_model_architecture: ModelName = ModelName.UNET,
+        initializer_channels: tuple[int, ...] = (32, 32, 64, 64),
+        initializer_dilations: tuple[int, ...] = (1, 1, 2, 4),
+        initializer_multiscale: int = 1,
+        initializer_activation: ActivationType = ActivationType.PRELU,
+        auxiliary_steps: int = 0,
+        **kwargs,
+    ) -> None:
+        """Inits :class:`VSharpNet`.
+
+        Parameters
+        ----------
+        forward_operator : Callable
+            Forward operator function.
+        backward_operator : Callable
+            Backward operator function.
+        num_steps : int
+            Number of steps in the ADMM algorithm.
+        num_steps_dc_gd : int
+            Number of steps in the Data Consistency using Gradient Descent (DC-GD) step of ADMM.
+        image_init : InitType
+            Image initialization method. Default: InitType.SENSE.
+        no_parameter_sharing : bool
+            If True, each ADMM step uses its own denoiser block; if False, a single denoiser block is
+            shared across all steps. Default: True.
+        image_model_architecture : ModelName
+            Image model architecture. Default: ModelName.UNET.
+        initializer_channels : tuple[int, ...]
+            Tuple of integers specifying the number of output channels for each convolutional layer in the
+            Lagrange multiplier initializer. Default: (32, 32, 64, 64).
+        initializer_dilations : tuple[int, ...]
+            Tuple of integers specifying the dilation factor for each convolutional layer in the Lagrange multiplier
+            initializer. Default: (1, 1, 2, 4).
+        initializer_multiscale : int
+            Number of multiscale features to include in the Lagrange multiplier initializer output. Default: 1.
+        initializer_activation : ActivationType
+            Activation type for the Lagrange multiplier initializer. Default: ActivationType.PRELU.
+        auxiliary_steps : int
+            Number of auxiliary steps to output. Can be -1 or a positive integer less than or equal to
+            `num_steps`. If -1, the reconstruction of every step is output; if a positive integer `k`,
+            only the last `k` steps are output.
+        **kwargs : Additional keyword arguments.
+            Can be `model_name` or `image_model_<param>`, where `<param>` is a parameter of the selected
+            image model architecture beyond the standard parameters.
+            Depending on the `image_model_architecture` chosen, different kwargs will be applicable.
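+
+        Example
+        -------
+        A minimal, illustrative instantiation (small sizes for demonstration only; the experiment
+        configurations live under ``projects/vSHARP``). It uses ``direct.data.transforms.fft2``/``ifft2``
+        as operators, as the unit tests in this PR do; any centered FFT pair with the same signature works:
+
+        .. code-block:: python
+
+            from direct.data.transforms import fft2, ifft2
+            from direct.nn.vsharp.vsharp import VSharpNet
+
+            model = VSharpNet(
+                forward_operator=fft2,
+                backward_operator=ifft2,
+                num_steps=4,
+                num_steps_dc_gd=2,
+                auxiliary_steps=-1,  # output the reconstruction after every ADMM step
+                image_unet_num_filters=8,  # forwarded to the default U-Net denoiser
+                image_unet_num_pool_layers=2,
+            )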
+ """ + # pylint: disable=too-many-locals + super().__init__() + for extra_key in kwargs: + if extra_key != "model_name" and not extra_key.startswith("image_"): + raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") + self.num_steps = num_steps + self.num_steps_dc_gd = num_steps_dc_gd + + self.no_parameter_sharing = no_parameter_sharing + + image_model, image_model_kwargs = _get_model_config( + image_model_architecture, + in_channels=COMPLEX_SIZE * 3, + out_channels=COMPLEX_SIZE, + **{k.replace("image_", ""): v for (k, v) in kwargs.items() if "image_" in k}, + ) + + self.denoiser_blocks = nn.ModuleList() + for _ in range(num_steps if self.no_parameter_sharing else 1): + self.denoiser_blocks.append(image_model(**image_model_kwargs)) + + self.initializer = LagrangeMultipliersInitializer( + in_channels=COMPLEX_SIZE, + out_channels=COMPLEX_SIZE, + channels=initializer_channels, + dilations=initializer_dilations, + multiscale_depth=initializer_multiscale, + activation=initializer_activation, + ) + + self.learning_rate_eta = nn.Parameter(torch.ones(num_steps_dc_gd, requires_grad=True)) + nn.init.trunc_normal_(self.learning_rate_eta, 0.0, 1.0, 0.0) + + self.rho = nn.Parameter(torch.ones(num_steps, requires_grad=True)) + nn.init.trunc_normal_(self.rho, 0, 0.1, 0.0) + + self.forward_operator = forward_operator + self.backward_operator = backward_operator + + if image_init not in [InitType.SENSE, InitType.ZERO_FILLED]: + raise ValueError( + f"Unknown image_initialization. Expected `InitType.SENSE` or `InitType.ZERO_FILLED`. " + f"Got {image_init}." + ) + + self.image_init = image_init + + if not ((auxiliary_steps == -1) or (0 < auxiliary_steps <= num_steps)): + raise ValueError( + f"Number of auxiliary steps should be -1 to use all steps or a positive" + f" integer <= than `num_steps`. Received {auxiliary_steps}." + ) + if auxiliary_steps == -1: + self.auxiliary_steps = list(range(num_steps)) + else: + self.auxiliary_steps = list(range(num_steps - min(auxiliary_steps, num_steps), num_steps)) + + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + def forward( + self, + masked_kspace: torch.Tensor, + sensitivity_map: torch.Tensor, + sampling_mask: torch.Tensor, + ) -> list[torch.Tensor]: + """Computes forward pass of :class:`VSharpNet`. + + Parameters + ---------- + masked_kspace: torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sensitivity_map: torch.Tensor + Sensitivity map of shape (N, coil, height, width, complex=2). Default: None. + sampling_mask: torch.Tensor + + Returns + ------- + image: torch.Tensor + Output image of shape (N, height, width, complex=2). 
+ """ + out = [] + if self.image_init == "sense": + x = reduce_operator( + coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), + sensitivity_map=sensitivity_map, + dim=self._coil_dim, + ) + else: + x = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim) + + z = x.clone() + + u = self.initializer(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + + for admm_step in range(self.num_steps): + z = self.denoiser_blocks[admm_step if self.no_parameter_sharing else 0]( + torch.cat( + [z, x, u / self.rho[admm_step]], + dim=self._complex_dim, + ).permute(0, 3, 1, 2) + ).permute(0, 2, 3, 1) + + for dc_gd_step in range(self.num_steps_dc_gd): + dc = apply_mask( + self.forward_operator(expand_operator(x, sensitivity_map, self._coil_dim), dim=self._spatial_dims) + - masked_kspace, + sampling_mask, + return_mask=False, + ) + dc = self.backward_operator(dc, dim=self._spatial_dims) + dc = reduce_operator(dc, sensitivity_map, self._coil_dim) + + x = x - self.learning_rate_eta[dc_gd_step] * (dc + self.rho[admm_step] * (x - z) + u) + + if admm_step in self.auxiliary_steps: + out.append(x) + + u = u + self.rho[admm_step] * (x - z) + + return out diff --git a/direct/nn/vsharp/vsharp_engine.py b/direct/nn/vsharp/vsharp_engine.py new file mode 100644 index 000000000..06dc2cd88 --- /dev/null +++ b/direct/nn/vsharp/vsharp_engine.py @@ -0,0 +1,144 @@ +# Copyright (c) DIRECT Contributors + +"""Engine for vSHARP 2D model.""" + +from __future__ import annotations + +from typing import Any, Optional + +import torch +from torch import nn +from torch.cuda.amp import autocast + +from direct.config import BaseConfig +from direct.data import transforms as T +from direct.engine import DoIterationOutput +from direct.nn.mri_models import MRIModelEngine +from direct.types import TensorOrNone +from direct.utils import detach_dict, dict_to_device + + +class VSharpNetEngine(MRIModelEngine): + """VSharpNet 2D Model Engine.""" + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[callable] = None, + backward_operator: Optional[callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ) -> None: + """Inits :class:`VSharpNetEngine`. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: callable, optional + The forward operator. Default: None. + backward_operator: callable, optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models for secondary tasks, such as sensitivity map estimation model. + """ + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + def _do_iteration( + self, + data: dict[str, Any], + loss_fns: Optional[dict[str, callable]] = None, + regularizer_fns: Optional[dict[str, callable]] = None, + ) -> DoIterationOutput: + """Performs forward method and calculates loss functions. + + Parameters + ---------- + data : dict[str, Any] + Data containing keys with values tensors such as k-space, image, sensitivity map, etc. + loss_fns : Optional[dict[str, callable]] + callable loss functions. + regularizer_fns : Optional[dict[str, callable]] + callable regularization functions. + + Returns + ------- + DoIterationOutput + Contains outputs. + """ + + # loss_fns can be None, e.g. 
+        # loss_fns can be None, e.g. during validation
+        if loss_fns is None:
+            loss_fns = {}
+
+        data = dict_to_device(data, self.device)
+
+        output_image: TensorOrNone
+        output_kspace: TensorOrNone
+
+        with autocast(enabled=self.mixed_precision):
+            output_images, output_kspace = self.forward_function(data)
+            output_images = [T.modulus_if_complex(_, complex_axis=self._complex_dim) for _ in output_images]
+            loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()}
+
+            # Weight the auxiliary outputs: earlier ADMM steps get smaller weights, log-spaced from 0.1 to 1.0.
+            auxiliary_loss_weights = torch.logspace(-1, 0, steps=len(output_images)).to(output_images[0])
+            for i, output_image in enumerate(output_images):
+                loss_dict = self.compute_loss_on_data(
+                    loss_dict, loss_fns, data, output_image, None, auxiliary_loss_weights[i]
+                )
+            # Compute loss on k-space
+            loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, None, output_kspace)
+
+            loss = sum(loss_dict.values())  # type: ignore
+
+        if self.model.training:
+            self._scaler.scale(loss).backward()
+
+        loss_dict = detach_dict(loss_dict)  # Detach dict, only used for logging.
+
+        output_image = output_images[-1]
+        return DoIterationOutput(
+            output_image=output_image,
+            sensitivity_map=data["sensitivity_map"],
+            data_dict={**loss_dict},
+        )
+
+    def forward_function(self, data: dict[str, Any]) -> tuple[list[torch.Tensor], torch.Tensor]:
+        """Computes the model outputs and the corresponding k-space prediction."""
+        data["sensitivity_map"] = self.compute_sensitivity_map(data["sensitivity_map"])
+
+        output_images = self.model(
+            masked_kspace=data["masked_kspace"],
+            sampling_mask=data["sampling_mask"],
+            sensitivity_map=data["sensitivity_map"],
+        )  # list of tensors of shape (batch, height, width, complex[=2])
+
+        output_image = output_images[-1]
+        # Fill the non-sampled k-space locations with the forward projection of the final reconstruction.
+        output_kspace = data["masked_kspace"] + T.apply_mask(
+            T.apply_padding(
+                self.forward_operator(
+                    T.expand_operator(output_image, data["sensitivity_map"], dim=self._coil_dim),
+                    dim=self._spatial_dims,
+                ),
+                padding=data.get("padding", None),
+            ),
+            ~data["sampling_mask"],
+            return_mask=False,
+        )
+
+        return output_images, output_kspace
diff --git a/direct/predict.py b/direct/predict.py
index b5759f5a8..1bb1cfa65 100644
--- a/direct/predict.py
+++ b/direct/predict.py
@@ -18,7 +18,8 @@ def _get_transforms(env):
     dataset_cfg = env.cfg.inference.dataset
-    mask_func = build_masking_function(**dataset_cfg.transforms.masking)
+    masking = dataset_cfg.transforms.masking  # Can be None
+    mask_func = None if masking is None else build_masking_function(**masking)
     transforms = build_inference_transforms(env, mask_func, dataset_cfg)
     return dataset_cfg, transforms
diff --git a/direct/train.py b/direct/train.py
index dad9bf1f6..31207ba03 100644
--- a/direct/train.py
+++ b/direct/train.py
@@ -75,11 +75,13 @@ def get_root_of_file(filename: PathOrString):

 def build_transforms_from_environment(env, dataset_config: DictConfig) -> Callable:
+    masking = dataset_config.transforms.masking  # Masking func can be None
+    mask_func = None if masking is None else build_masking_function(**masking)
     mri_transforms_func = functools.partial(
         build_mri_transforms,
         forward_operator=env.engine.forward_operator,
         backward_operator=env.engine.backward_operator,
-        mask_func=build_masking_function(**dataset_config.transforms.masking),
+        mask_func=mask_func,
     )
     return mri_transforms_func(**dict_flatten(dict(remove_keys(dataset_config.transforms, "masking"))))  # type: ignore

@@ -101,6 +103,11 @@ def build_training_datasets_from_environment(
             dataset_config.text_description = f"ds{idx}" if len(datasets_config) > 1 else None
         else:
             dataset_config.text_description = None
+        if dataset_config.transforms.masking is None:  # type: ignore
+            logger.info(
+                "Masking function set to None for %s.",
+                dataset_config.text_description,  # type: ignore
+            )
         transforms = build_transforms_from_environment(env, dataset_config)
         dataset_args = {"transforms": transforms, "dataset_config": dataset_config}
         if initial_images is not None:
diff --git a/projects/vSHARP/README.rst b/projects/vSHARP/README.rst
new file mode 100644
index 000000000..254df525f
--- /dev/null
+++ b/projects/vSHARP/README.rst
@@ -0,0 +1,121 @@
+===============================================================================================
+vSHARP: variable Splitting Half-quadratic ADMM algorithm for Reconstruction of inverse-Problems
+===============================================================================================
+
+This folder contains the training code specific to the reproduction of our experiments as presented in our paper
+`vSHARP: variable Splitting Half-quadratic ADMM algorithm for Reconstruction of inverse-Problems (pre-print) `__.
+
+.. figure:: https://github.com/NKI-AI/direct/assets/71031687/493701b6-6efa-427d-9b4f-94a0ebcf3142
+   :alt: fig
+   :name: fig1
+
+   Figure 1: Overview of our proposed method vSHARP.
+
+Dataset
+=======
+* For the proposed model, the comparison, and the ablation studies we used the `fastMRI prostate T2 dataset `__.
+  To construct the training, validation and test data from the raw ISMRMRD data format, we used the code provided in
+  https://github.com/cai2r/fastMRI_prostate.
+
+* We employed a retrospective Cartesian equispaced scheme to undersample our data.
+
+Training
+========
+
+Assuming the data are stored in ``data_root``, the standard ``direct train`` command can be used for training.
+
+Configuration files for our model and the baselines can be found in the
+`vSHARP project folder `_.
+
+To train vSHARP or any of the baselines presented in the paper, use the following command:
+
+.. code-block:: bash
+
+    direct train \
+    --training-root /.../data_root/ \
+    --validation-root /.../data_root/ \
+    --cfg projects/vSHARP/fastmri_prostate/configs/base_<model_name>.yaml \
+    --num-gpus <number_of_gpus> \
+    --num-workers <number_of_workers>
+
+
+For further information about training see `Training `__.
+
+During training, the training loss, validation metrics and validation image predictions are logged.
+Additionally, `Tensorboard `__ allows for visualization of the above.
+
+Inference
+=========
+
+To perform inference on the test set, run:
+
+.. code-block:: bash
+
+    direct predict \
+    --checkpoint <path_or_url_to_checkpoint> \
+    --cfg projects/vSHARP/fastmri_prostate/configs/base_<model_name>.yaml \
+    --data-root /.../data_root/ \
+    --num-gpus <number_of_gpus> \
+    --num-workers <number_of_workers> \
+    [--other-flags]
+
+Note that the above command will produce reconstructions for 4x-accelerated data. To change the acceleration factor,
+make sure to adapt the ``inference`` field in the respective YAML file. For instance:
+
+.. code-block:: yaml
+
+    inference:
+        crop: header
+        batch_size: 5
+        dataset:
+            name: FastMRI
+            transforms:
+                use_seed: True
+                masking:
+                    name: FastMRIEquispaced
+                    accelerations: [8]
+                    center_fractions: [0.04]
+                cropping:
+                    crop: null
+                sensitivity_map_estimation:
+                    estimate_sensitivity_maps: true
+                normalization:
+                    scaling_key: masked_kspace
+                    scale_percentile: 0.995
+            text_description: inference-8x  # Description for logging
+
+can be used for an acceleration factor of 8.
+
+Citing this work
+================
+
+Please use the following BibTeX entries if you use vSHARP in your work:
+
+..
code-block:: BibTeX + + @article{yiasemis2023vsharp, + title = {vSHARP: variable Splitting Half-quadratic ADMM algorithm for Reconstruction of inverse-Problems}, + author = {George Yiasemis and Nikita Moriakov and Jan-Jakob Sonke and Jonas Teuwen}, + month = {Sep}, + year = {2023}, + eprint = {2309.09954}, + archivePrefix = {arXiv}, + journal = {arXiv.org}, + doi = {10.48550/arXiv.2309.09954}, + url = {https://doi.org/10.48550/arXiv.2309.09954}, + note = {arXiv:2309.09954 [eess.IV]}, + primaryClass = {eess.IV} + } + + @article{DIRECTTOOLKIT, + doi = {10.21105/joss.04278}, + url = {https://doi.org/10.21105/joss.04278}, + year = {2022}, + publisher = {The Open Journal}, + volume = {7}, + number = {73}, + pages = {4278}, + author = {George Yiasemis and Nikita Moriakov and Dimitrios Karkalousos and Matthan Caan and Jonas Teuwen}, + title = {DIRECT: Deep Image REConstruction Toolkit}, + journal = {Journal of Open Source Software} + } diff --git a/projects/vSHARP/fastmri_prostate/configs/base_recurrentvarnet.yaml b/projects/vSHARP/fastmri_prostate/configs/base_recurrentvarnet.yaml new file mode 100644 index 000000000..a0d00a71b --- /dev/null +++ b/projects/vSHARP/fastmri_prostate/configs/base_recurrentvarnet.yaml @@ -0,0 +1,234 @@ +physics: + forward_operator: fft2 + backward_operator: ifft2 +training: + datasets: + - name: FastMRI + filenames_lists: + - ../lists/train_10_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_14_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_16_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_20_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_24_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image 
normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_26_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_30_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + batch_size: 1 # This is the batch size per GPU! + optimizer: Adam + lr: 0.002 + weight_decay: 0.0 + lr_step_size: 30000 + lr_gamma: 0.2 + lr_warmup_iter: 1000 + num_iterations: 500000 + gradient_steps: 1 + gradient_clipping: 0.0 + gradient_debug: false + checkpointer: + checkpoint_steps: 2000 + validation_steps: 4000 + loss: + crop: header + losses: + - function: l1_loss + multiplier: 1.0 + - function: ssim_loss + multiplier: 1.0 + - function: hfen_l2_norm_loss + multiplier: 1.0 + - function: hfen_l1_norm_loss + multiplier: 1.0 + - function: kspace_nmae_loss + multiplier: 1.0 + - function: kspace_nmse_loss + multiplier: 1.0 +validation: + datasets: + - name: FastMRI + transforms: + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + masking: + name: FastMRIEquispaced + accelerations: [4] + center_fractions: [0.08] + scale_percentile: 0.995 + use_seed: true + text_description: 4x # Description for logging + - name: FastMRI + transforms: + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + masking: + name: FastMRIEquispaced + accelerations: [8] + center_fractions: [0.04] + scale_percentile: 0.995 + use_seed: true + text_description: 8x # Description for logging + - name: FastMRI + transforms: + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + masking: + name: FastMRIEquispaced + accelerations: [16] + center_fractions: [0.02] + scale_percentile: 0.995 + use_seed: true + text_description: 16x # Description for logging + crop: header # This sets the cropping for the DoIterationOutput + metrics: # These are obtained from direct.functionals + - fastmri_psnr + - fastmri_ssim + - fastmri_nmse + batch_size: 5 +model: + model_name: recurrentvarnet.recurrentvarnet.RecurrentVarNet + num_steps: 8 + recurrent_hidden_channels: 128 + recurrent_num_layers: 4 + initializer_initialization: sense + learned_initializer: true + initializer_channels: [32, 32, 64, 64] + initializer_dilations: [1, 1, 2, 4] + initializer_multiscale: 3 +additional_models: + 
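+  # Auxiliary model for coil sensitivity map estimation/refinement, configured below as a small 2D U-Net.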
sensitivity_model: + model_name: unet.unet_2d.UnetModel2d + in_channels: 2 + out_channels: 2 + num_filters: 16 + num_pool_layers: 4 + dropout_probability: 0.0 +logging: + tensorboard: + num_images: 4 +inference: + crop: header + batch_size: 5 + dataset: + name: FastMRI + transforms: + use_seed: True + masking: + name: FastMRIEquispaced + accelerations: [4] + center_fractions: [0.08] + cropping: + crop: null + sensitivity_map_estimation: + estimate_sensitivity_maps: true + normalization: + scaling_key: masked_kspace + scale_percentile: 0.995 + text_description: inference-4x # Description for logging diff --git a/projects/vSHARP/fastmri_prostate/configs/base_unet.yaml b/projects/vSHARP/fastmri_prostate/configs/base_unet.yaml new file mode 100644 index 000000000..2640d747e --- /dev/null +++ b/projects/vSHARP/fastmri_prostate/configs/base_unet.yaml @@ -0,0 +1,228 @@ +physics: + forward_operator: fft2 + backward_operator: ifft2 +training: + datasets: + - name: FastMRI + filenames_lists: + - ../lists/train_10_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_14_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_16_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_20_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_24_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - 
../lists/train_26_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_30_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + batch_size: 2 # This is the batch size per GPU! + optimizer: Adam + lr: 0.002 + weight_decay: 0.0 + lr_step_size: 30000 + lr_gamma: 0.2 + lr_warmup_iter: 1000 + num_iterations: 500000 + gradient_steps: 1 + gradient_clipping: 0.0 + gradient_debug: false + checkpointer: + checkpoint_steps: 2000 + validation_steps: 4000 + loss: + crop: header + losses: + - function: l1_loss + multiplier: 1.0 + - function: ssim_loss + multiplier: 1.0 + - function: hfen_l2_norm_loss + multiplier: 1.0 + - function: hfen_l1_norm_loss + multiplier: 1.0 + - function: kspace_nmae_loss + multiplier: 1.0 + - function: kspace_nmse_loss + multiplier: 1.0 +validation: + datasets: + - name: FastMRI + transforms: + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + masking: + name: FastMRIEquispaced + accelerations: [4] + center_fractions: [0.08] + scale_percentile: 0.995 + use_seed: true + text_description: 4x # Description for logging + - name: FastMRI + transforms: + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + masking: + name: FastMRIEquispaced + accelerations: [8] + center_fractions: [0.04] + scale_percentile: 0.995 + use_seed: true + text_description: 8x # Description for logging + - name: FastMRI + transforms: + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + masking: + name: FastMRIEquispaced + accelerations: [16] + center_fractions: [0.02] + scale_percentile: 0.995 + use_seed: true + text_description: 16x # Description for logging + crop: header # This sets the cropping for the DoIterationOutput + metrics: # These are obtained from direct.functionals + - fastmri_psnr + - fastmri_ssim + - fastmri_nmse + batch_size: 20 +additional_models: + sensitivity_model: + model_name: unet.unet_2d.UnetModel2d + in_channels: 2 + out_channels: 2 + num_filters: 16 + num_pool_layers: 4 + dropout_probability: 0.0 +model: + model_name: unet.unet_2d.Unet2d + num_filters: 64 + image_initialization: sense +logging: + tensorboard: + num_images: 4 +inference: + crop: header + batch_size: 5 + dataset: + name: FastMRI + transforms: + use_seed: True + masking: + name: FastMRIEquispaced + accelerations: [4] + center_fractions: [0.08] + cropping: + crop: null + sensitivity_map_estimation: + estimate_sensitivity_maps: true + normalization: + scaling_key: masked_kspace + scale_percentile: 0.995 + text_description: inference-4x # Description 
for logging diff --git a/projects/vSHARP/fastmri_prostate/configs/base_varnet.yaml b/projects/vSHARP/fastmri_prostate/configs/base_varnet.yaml new file mode 100644 index 000000000..f7f45be99 --- /dev/null +++ b/projects/vSHARP/fastmri_prostate/configs/base_varnet.yaml @@ -0,0 +1,229 @@ +physics: + forward_operator: fft2 + backward_operator: ifft2 +training: + datasets: + - name: FastMRI + filenames_lists: + - ../lists/train_10_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_14_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_16_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_20_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_24_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_26_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_30_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: 
true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + batch_size: 2 # This is the batch size per GPU! + optimizer: Adam + lr: 0.002 + weight_decay: 0.0 + lr_step_size: 30000 + lr_gamma: 0.2 + lr_warmup_iter: 1000 + num_iterations: 500000 + gradient_steps: 1 + gradient_clipping: 0.0 + gradient_debug: false + checkpointer: + checkpoint_steps: 2000 + validation_steps: 4000 + loss: + crop: header + losses: + - function: l1_loss + multiplier: 1.0 + - function: ssim_loss + multiplier: 1.0 + - function: hfen_l2_norm_loss + multiplier: 1.0 + - function: hfen_l1_norm_loss + multiplier: 1.0 + - function: kspace_nmae_loss + multiplier: 1.0 + - function: kspace_nmse_loss + multiplier: 1.0 +validation: + datasets: + - name: FastMRI + transforms: + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + masking: + name: FastMRIEquispaced + accelerations: [4] + center_fractions: [0.08] + scale_percentile: 0.995 + use_seed: true + text_description: 4x # Description for logging + - name: FastMRI + transforms: + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + masking: + name: FastMRIEquispaced + accelerations: [8] + center_fractions: [0.04] + scale_percentile: 0.995 + use_seed: true + text_description: 8x # Description for logging + - name: FastMRI + transforms: + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + masking: + name: FastMRIEquispaced + accelerations: [16] + center_fractions: [0.02] + scale_percentile: 0.995 + use_seed: true + text_description: 16x # Description for logging + crop: header # This sets the cropping for the DoIterationOutput + metrics: # These are obtained from direct.functionals + - fastmri_psnr + - fastmri_ssim + - fastmri_nmse + batch_size: 20 +model: + model_name: varnet.varnet.EndToEndVarNet + num_layers: 12 + regularizer_num_filters: 64 + regularizer_num_pull_layers: 4 +additional_models: + sensitivity_model: + model_name: unet.unet_2d.UnetModel2d + in_channels: 2 + out_channels: 2 + num_filters: 16 + num_pool_layers: 4 + dropout_probability: 0.0 +logging: + tensorboard: + num_images: 4 +inference: + crop: header + batch_size: 5 + dataset: + name: FastMRI + transforms: + use_seed: True + masking: + name: FastMRIEquispaced + accelerations: [4] + center_fractions: [0.08] + cropping: + crop: null + sensitivity_map_estimation: + estimate_sensitivity_maps: true + normalization: + scaling_key: masked_kspace + scale_percentile: 0.995 + text_description: inference-4x # Description for logging diff --git a/projects/vSHARP/fastmri_prostate/configs/base_vsharp.yaml b/projects/vSHARP/fastmri_prostate/configs/base_vsharp.yaml new file mode 100644 index 000000000..0426553e6 --- /dev/null +++ b/projects/vSHARP/fastmri_prostate/configs/base_vsharp.yaml @@ -0,0 +1,233 @@ +physics: + forward_operator: fft2 + backward_operator: ifft2 +training: + datasets: + - name: FastMRI + filenames_lists: + - ../lists/train_10_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image 
normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_14_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_16_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_20_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_24_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_26_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + - name: FastMRI + filenames_lists: + - ../lists/train_30_coils.lst + transforms: + crop: reconstruction_size + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + random_flip_probability: 0.5 + random_rotation_probability: 0.5 + masking: + name: FastMRIEquispaced + accelerations: [4, 8, 16] + center_fractions: [0.08, 0.04, 0.02] + scale_percentile: 0.995 + use_seed: false + delete_kspace: false + batch_size: 2 # This is the batch size per GPU! 
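+  # Optimization settings shared by all experiments in this project: Adam with an initial LR of 0.002,
+  # a 1000-iteration warmup, and LR decay by a factor of 0.2 every 30000 iterations.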
+ optimizer: Adam + lr: 0.002 + weight_decay: 0.0 + lr_step_size: 30000 + lr_gamma: 0.2 + lr_warmup_iter: 1000 + num_iterations: 500000 + gradient_steps: 1 + gradient_clipping: 0.0 + gradient_debug: false + checkpointer: + checkpoint_steps: 2000 + validation_steps: 4000 + loss: + crop: header + losses: + - function: l1_loss + multiplier: 1.0 + - function: ssim_loss + multiplier: 1.0 + - function: hfen_l2_norm_loss + multiplier: 1.0 + - function: hfen_l1_norm_loss + multiplier: 1.0 + - function: kspace_nmae_loss + multiplier: 1.0 + - function: kspace_nmse_loss + multiplier: 1.0 +validation: + datasets: + - name: FastMRI + transforms: + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + masking: + name: FastMRIEquispaced + accelerations: [4] + center_fractions: [0.08] + scale_percentile: 0.995 + use_seed: true + text_description: 4x # Description for logging + - name: FastMRI + transforms: + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + masking: + name: FastMRIEquispaced + accelerations: [8] + center_fractions: [0.04] + scale_percentile: 0.995 + use_seed: true + text_description: 8x # Description for logging + - name: FastMRI + transforms: + estimate_sensitivity_maps: true + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + masking: + name: FastMRIEquispaced + accelerations: [16] + center_fractions: [0.02] + scale_percentile: 0.995 + use_seed: true + text_description: 16x # Description for logging + crop: header # This sets the cropping for the DoIterationOutput + metrics: # These are obtained from direct.functionals + - fastmri_psnr + - fastmri_ssim + - fastmri_nmse + batch_size: 20 +model: + model_name: vsharp.vsharp.VSharpNet + num_steps: 12 + num_steps_dc_gd: 10 + image_init: SENSE + no_parameter_sharing: true + image_model_architecture: UNET + image_unet_num_filters: 32 + auxiliary_steps: -1 +additional_models: + sensitivity_model: + model_name: unet.unet_2d.UnetModel2d + in_channels: 2 + out_channels: 2 + num_filters: 16 + num_pool_layers: 4 + dropout_probability: 0.0 +logging: + tensorboard: + num_images: 4 +inference: + crop: header + batch_size: 5 + dataset: + name: FastMRI + transforms: + use_seed: True + masking: + name: FastMRIEquispaced + accelerations: [4] + center_fractions: [0.08] + cropping: + crop: null + sensitivity_map_estimation: + estimate_sensitivity_maps: true + normalization: + scaling_key: masked_kspace + scale_percentile: 0.995 + text_description: inference-4x # Description for logging diff --git a/projects/vSHARP/fastmri_prostate/lists/train_10_coils.lst b/projects/vSHARP/fastmri_prostate/lists/train_10_coils.lst new file mode 100644 index 000000000..6138cd506 --- /dev/null +++ b/projects/vSHARP/fastmri_prostate/lists/train_10_coils.lst @@ -0,0 +1 @@ +file_prostate_AXT2_0023.h5 diff --git a/projects/vSHARP/fastmri_prostate/lists/train_14_coils.lst b/projects/vSHARP/fastmri_prostate/lists/train_14_coils.lst new file mode 100644 index 000000000..587f66c51 --- /dev/null +++ b/projects/vSHARP/fastmri_prostate/lists/train_14_coils.lst @@ -0,0 +1,7 @@ +file_prostate_AXT2_0190.h5 +file_prostate_AXT2_0174.h5 +file_prostate_AXT2_0233.h5 +file_prostate_AXT2_0002.h5 +file_prostate_AXT2_0147.h5 +file_prostate_AXT2_0290.h5 +file_prostate_AXT2_0056.h5 diff --git a/projects/vSHARP/fastmri_prostate/lists/train_16_coils.lst 
b/projects/vSHARP/fastmri_prostate/lists/train_16_coils.lst new file mode 100644 index 000000000..b5d6ebc7d --- /dev/null +++ b/projects/vSHARP/fastmri_prostate/lists/train_16_coils.lst @@ -0,0 +1,14 @@ +file_prostate_AXT2_0007.h5 +file_prostate_AXT2_0139.h5 +file_prostate_AXT2_0005.h5 +file_prostate_AXT2_0243.h5 +file_prostate_AXT2_0015.h5 +file_prostate_AXT2_0003.h5 +file_prostate_AXT2_0227.h5 +file_prostate_AXT2_0151.h5 +file_prostate_AXT2_0123.h5 +file_prostate_AXT2_0310.h5 +file_prostate_AXT2_0306.h5 +file_prostate_AXT2_0068.h5 +file_prostate_AXT2_0020.h5 +file_prostate_AXT2_0154.h5 diff --git a/projects/vSHARP/fastmri_prostate/lists/train_20_coils.lst b/projects/vSHARP/fastmri_prostate/lists/train_20_coils.lst new file mode 100644 index 000000000..5fc5f60d2 --- /dev/null +++ b/projects/vSHARP/fastmri_prostate/lists/train_20_coils.lst @@ -0,0 +1,153 @@ +file_prostate_AXT2_0244.h5 +file_prostate_AXT2_0058.h5 +file_prostate_AXT2_0309.h5 +file_prostate_AXT2_0234.h5 +file_prostate_AXT2_0223.h5 +file_prostate_AXT2_0172.h5 +file_prostate_AXT2_0060.h5 +file_prostate_AXT2_0103.h5 +file_prostate_AXT2_0252.h5 +file_prostate_AXT2_0011.h5 +file_prostate_AXT2_0235.h5 +file_prostate_AXT2_0308.h5 +file_prostate_AXT2_0006.h5 +file_prostate_AXT2_0245.h5 +file_prostate_AXT2_0081.h5 +file_prostate_AXT2_0114.h5 +file_prostate_AXT2_0076.h5 +file_prostate_AXT2_0166.h5 +file_prostate_AXT2_0191.h5 +file_prostate_AXT2_0237.h5 +file_prostate_AXT2_0074.h5 +file_prostate_AXT2_0116.h5 +file_prostate_AXT2_0083.h5 +file_prostate_AXT2_0218.h5 +file_prostate_AXT2_0247.h5 +file_prostate_AXT2_0250.h5 +file_prostate_AXT2_0094.h5 +file_prostate_AXT2_0063.h5 +file_prostate_AXT2_0220.h5 +file_prostate_AXT2_0186.h5 +file_prostate_AXT2_0171.h5 +file_prostate_AXT2_0012.h5 +file_prostate_AXT2_0187.h5 +file_prostate_AXT2_0062.h5 +file_prostate_AXT2_0095.h5 +file_prostate_AXT2_0251.h5 +file_prostate_AXT2_0082.h5 +file_prostate_AXT2_0075.h5 +file_prostate_AXT2_0148.h5 +file_prostate_AXT2_0236.h5 +file_prostate_AXT2_0039.h5 +file_prostate_AXT2_0016.h5 +file_prostate_AXT2_0163.h5 +file_prostate_AXT2_0001.h5 +file_prostate_AXT2_0194.h5 +file_prostate_AXT2_0113.h5 +file_prostate_AXT2_0071.h5 +file_prostate_AXT2_0242.h5 +file_prostate_AXT2_0087.h5 +file_prostate_AXT2_0070.h5 +file_prostate_AXT2_0112.h5 +file_prostate_AXT2_0162.h5 +file_prostate_AXT2_0175.h5 +file_prostate_AXT2_0182.h5 +file_prostate_AXT2_0067.h5 +file_prostate_AXT2_0105.h5 +file_prostate_AXT2_0090.h5 +file_prostate_AXT2_0254.h5 +file_prostate_AXT2_0226.h5 +file_prostate_AXT2_0279.h5 +file_prostate_AXT2_0177.h5 +file_prostate_AXT2_0128.h5 +file_prostate_AXT2_0209.h5 +file_prostate_AXT2_0092.h5 +file_prostate_AXT2_0107.h5 +file_prostate_AXT2_0065.h5 +file_prostate_AXT2_0158.h5 +file_prostate_AXT2_0110.h5 +file_prostate_AXT2_0230.h5 +file_prostate_AXT2_0161.h5 +file_prostate_AXT2_0240.h5 +file_prostate_AXT2_0073.h5 +file_prostate_AXT2_0111.h5 +file_prostate_AXT2_0159.h5 +file_prostate_AXT2_0064.h5 +file_prostate_AXT2_0093.h5 +file_prostate_AXT2_0257.h5 +file_prostate_AXT2_0129.h5 +file_prostate_AXT2_0014.h5 +file_prostate_AXT2_0176.h5 +file_prostate_AXT2_0181.h5 +file_prostate_AXT2_0278.h5 +file_prostate_AXT2_0266.h5 +file_prostate_AXT2_0239.h5 +file_prostate_AXT2_0055.h5 +file_prostate_AXT2_0137.h5 +file_prostate_AXT2_0291.h5 +file_prostate_AXT2_0249.h5 +file_prostate_AXT2_0118.h5 +file_prostate_AXT2_0032.h5 +file_prostate_AXT2_0150.h5 +file_prostate_AXT2_0201.h5 +file_prostate_AXT2_0286.h5 +file_prostate_AXT2_0120.h5 +file_prostate_AXT2_0271.h5 
+file_prostate_AXT2_0188.h5
+file_prostate_AXT2_0287.h5
+file_prostate_AXT2_0200.h5
+file_prostate_AXT2_0146.h5
+file_prostate_AXT2_0024.h5
+file_prostate_AXT2_0248.h5
+file_prostate_AXT2_0217.h5
+file_prostate_AXT2_0169.h5
+file_prostate_AXT2_0054.h5
+file_prostate_AXT2_0238.h5
+file_prostate_AXT2_0267.h5
+file_prostate_AXT2_0134.h5
+file_prostate_AXT2_0272.h5
+file_prostate_AXT2_0031.h5
+file_prostate_AXT2_0153.h5
+file_prostate_AXT2_0099.h5
+file_prostate_AXT2_0202.h5
+file_prostate_AXT2_0203.h5
+file_prostate_AXT2_0098.h5
+file_prostate_AXT2_0152.h5
+file_prostate_AXT2_0311.h5
+file_prostate_AXT2_0284.h5
+file_prostate_AXT2_0293.h5
+file_prostate_AXT2_0057.h5
+file_prostate_AXT2_0264.h5
+file_prostate_AXT2_0078.h5
+file_prostate_AXT2_0027.h5
+file_prostate_AXT2_0214.h5
+file_prostate_AXT2_0179.h5
+file_prostate_AXT2_0044.h5
+file_prostate_AXT2_0109.h5
+file_prostate_AXT2_0156.h5
+file_prostate_AXT2_0034.h5
+file_prostate_AXT2_0258.h5
+file_prostate_AXT2_0302.h5
+file_prostate_AXT2_0131.h5
+file_prostate_AXT2_0053.h5
+file_prostate_AXT2_0303.h5
+file_prostate_AXT2_0261.h5
+file_prostate_AXT2_0022.h5
+file_prostate_AXT2_0211.h5
+file_prostate_AXT2_0206.h5
+file_prostate_AXT2_0259.h5
+file_prostate_AXT2_0108.h5
+file_prostate_AXT2_0276.h5
+file_prostate_AXT2_0127.h5
+file_prostate_AXT2_0283.h5
+file_prostate_AXT2_0018.h5
+file_prostate_AXT2_0047.h5
+file_prostate_AXT2_0274.h5
+file_prostate_AXT2_0263.h5
+file_prostate_AXT2_0132.h5
+file_prostate_AXT2_0088.h5
+file_prostate_AXT2_0143.h5
+file_prostate_AXT2_0295.h5
+file_prostate_AXT2_0133.h5
+file_prostate_AXT2_0205.h5
+file_prostate_AXT2_0069.h5
diff --git a/projects/vSHARP/fastmri_prostate/lists/train_24_coils.lst b/projects/vSHARP/fastmri_prostate/lists/train_24_coils.lst
new file mode 100644
index 000000000..fa7b45ee1
--- /dev/null
+++ b/projects/vSHARP/fastmri_prostate/lists/train_24_coils.lst
@@ -0,0 +1,4 @@
+file_prostate_AXT2_0101.h5
+file_prostate_AXT2_0221.h5
+file_prostate_AXT2_0033.h5
+file_prostate_AXT2_0213.h5
diff --git a/projects/vSHARP/fastmri_prostate/lists/train_26_coils.lst b/projects/vSHARP/fastmri_prostate/lists/train_26_coils.lst
new file mode 100644
index 000000000..e6e221f99
--- /dev/null
+++ b/projects/vSHARP/fastmri_prostate/lists/train_26_coils.lst
@@ -0,0 +1,34 @@
+file_prostate_AXT2_0080.h5
+file_prostate_AXT2_0253.h5
+file_prostate_AXT2_0097.h5
+file_prostate_AXT2_0102.h5
+file_prostate_AXT2_0170.h5
+file_prostate_AXT2_0246.h5
+file_prostate_AXT2_0255.h5
+file_prostate_AXT2_0197.h5
+file_prostate_AXT2_0106.h5
+file_prostate_AXT2_0208.h5
+file_prostate_AXT2_0025.h5
+file_prostate_AXT2_0270.h5
+file_prostate_AXT2_0043.h5
+file_prostate_AXT2_0119.h5
+file_prostate_AXT2_0030.h5
+file_prostate_AXT2_0122.h5
+file_prostate_AXT2_0145.h5
+file_prostate_AXT2_0280.h5
+file_prostate_AXT2_0126.h5
+file_prostate_AXT2_0277.h5
+file_prostate_AXT2_0199.h5
+file_prostate_AXT2_0296.h5
+file_prostate_AXT2_0130.h5
+file_prostate_AXT2_0035.h5
+file_prostate_AXT2_0157.h5
+file_prostate_AXT2_0045.h5
+file_prostate_AXT2_0155.h5
+file_prostate_AXT2_0301.h5
+file_prostate_AXT2_0142.h5
+file_prostate_AXT2_0021.h5
+file_prostate_AXT2_0051.h5
+file_prostate_AXT2_0046.h5
+file_prostate_AXT2_0019.h5
+file_prostate_AXT2_0282.h5
diff --git a/projects/vSHARP/fastmri_prostate/lists/train_30_coils.lst b/projects/vSHARP/fastmri_prostate/lists/train_30_coils.lst
new file mode 100644
index 000000000..68c7f762a
--- /dev/null
+++ b/projects/vSHARP/fastmri_prostate/lists/train_30_coils.lst
@@ -0,0 +1,5 @@
+file_prostate_AXT2_0268.h5
+file_prostate_AXT2_0026.h5
+file_prostate_AXT2_0009.h5
+file_prostate_AXT2_0228.h5
+file_prostate_AXT2_0294.h5
diff --git a/tests/tests_nn/test_vsharp.py b/tests/tests_nn/test_vsharp.py
new file mode 100644
index 000000000..2d4d10a23
--- /dev/null
+++ b/tests/tests_nn/test_vsharp.py
@@ -0,0 +1,96 @@
+# Copyright (c) DIRECT Contributors
+
+"""Tests for the direct.nn.vsharp module."""
+
+import pytest
+import torch
+
+from direct.data.transforms import fft2, ifft2
+from direct.nn.get_nn_model_config import ModelName
+from direct.nn.types import InitType
+from direct.nn.vsharp.vsharp import VSharpNet
+
+
+def create_input(shape):
+    data = torch.rand(shape).float()
+
+    return data
+
+
+@pytest.mark.parametrize("shape", [[1, 3, 16, 16]])
+@pytest.mark.parametrize("num_steps", [3, 4])
+@pytest.mark.parametrize("num_steps_dc_gd", [2])
+@pytest.mark.parametrize("image_init", [InitType.SENSE, InitType.ZERO_FILLED, InitType.ZEROS])
+@pytest.mark.parametrize(
+    "image_model_architecture, image_model_kwargs",
+    [
+        [ModelName.UNET, {"image_unet_num_filters": 4, "image_unet_num_pool_layers": 2}],
+        [ModelName.DIDN, {"image_didn_hidden_channels": 4, "image_didn_num_dubs": 2, "image_didn_num_convs_recon": 2}],
+    ],
+)
+@pytest.mark.parametrize(
+    "initializer_channels, initializer_dilations, initializer_multiscale",
+    [
+        [(8, 8, 16), (1, 1, 4), 1],
+        [(8, 8, 16), (1, 1, 4), 2],
+        [(2, 2, 4), (1, 1, 1), 3],
+    ],
+)
+@pytest.mark.parametrize("aux_steps", [-1, 1, -2, 4])
+def test_varsplitnet(
+    shape,
+    num_steps,
+    num_steps_dc_gd,
+    image_init,
+    image_model_architecture,
+    image_model_kwargs,
+    initializer_channels,
+    initializer_dilations,
+    initializer_multiscale,
+    aux_steps,
+):
+    if (
+        (image_init not in [InitType.SENSE, InitType.ZERO_FILLED])
+        or (aux_steps == 0)
+        or (aux_steps <= -2)
+        or (aux_steps > num_steps)
+    ):
+        with pytest.raises(ValueError):
+            model = VSharpNet(
+                fft2,
+                ifft2,
+                num_steps=num_steps,
+                num_steps_dc_gd=num_steps_dc_gd,
+                image_init=image_init,
+                no_parameter_sharing=False,
+                initializer_channels=initializer_channels,
+                initializer_dilations=initializer_dilations,
+                initializer_multiscale=initializer_multiscale,
+                auxiliary_steps=aux_steps,
+                image_model_architecture=image_model_architecture,
+                **image_model_kwargs,
+            ).cpu()
+
+    else:
+        model = VSharpNet(
+            fft2,
+            ifft2,
+            num_steps=num_steps,
+            num_steps_dc_gd=num_steps_dc_gd,
+            image_init=image_init,
+            no_parameter_sharing=False,
+            initializer_channels=initializer_channels,
+            initializer_dilations=initializer_dilations,
+            initializer_multiscale=initializer_multiscale,
+            auxiliary_steps=aux_steps,
+            image_model_architecture=image_model_architecture,
+            **image_model_kwargs,
+        ).cpu()
+
+        kspace = create_input(shape + [2]).cpu()
+        mask = create_input([shape[0]] + [1] + shape[2:] + [1]).round().int().cpu()
+        sens = create_input(shape + [2]).cpu()
+        out = model(kspace, sens, mask)
+
+        for i in range(len(out)):
+            assert list(out[i].shape) == [shape[0]] + shape[2:] + [2]
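Outside pytest, the same construction can be exercised directly. The sketch below mirrors one passing parametrization of `test_varsplitnet` above; all argument names and shapes come from that test, and nothing beyond it is assumed.

```python
# Standalone forward pass mirroring one passing parametrization of
# test_varsplitnet; a sketch, not a canonical usage example.
import torch

from direct.data.transforms import fft2, ifft2
from direct.nn.get_nn_model_config import ModelName
from direct.nn.types import InitType
from direct.nn.vsharp.vsharp import VSharpNet

model = VSharpNet(
    fft2,
    ifft2,
    num_steps=3,
    num_steps_dc_gd=2,
    image_init=InitType.SENSE,
    no_parameter_sharing=False,
    initializer_channels=(8, 8, 16),
    initializer_dilations=(1, 1, 4),
    initializer_multiscale=1,
    auxiliary_steps=-1,  # per the test, -1 yields one image estimate per unrolled step
    image_model_architecture=ModelName.UNET,
    image_unet_num_filters=4,
    image_unet_num_pool_layers=2,
)

kspace = torch.rand(1, 3, 16, 16, 2)               # (batch, coil, height, width, complex)
sens = torch.rand(1, 3, 16, 16, 2)                 # coil sensitivity maps
mask = torch.rand(1, 1, 16, 16, 1).round().int()   # undersampling mask

out = model(kspace, sens, mask)
# One image estimate per step, each of shape (1, 16, 16, 2).
assert all(list(o.shape) == [1, 16, 16, 2] for o in out)
```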
diff --git a/tests/tests_nn/test_vsharp_engine.py b/tests/tests_nn/test_vsharp_engine.py
new file mode 100644
index 000000000..1bb30ce82
--- /dev/null
+++ b/tests/tests_nn/test_vsharp_engine.py
@@ -0,0 +1,83 @@
+# Copyright (c) DIRECT Contributors
+
+"""Tests for direct.nn.vsharp.vsharp_engine module."""
+
+import functools
+
+import numpy as np
+import pytest
+import torch
+
+from direct.config.defaults import DefaultConfig, FunctionConfig, LossConfig, TrainingConfig, ValidationConfig
+from direct.data.transforms import fft2, ifft2
+from direct.nn.vsharp.config import VSharpNetConfig
+from direct.nn.vsharp.vsharp import VSharpNet
+from direct.nn.vsharp.vsharp_engine import VSharpNetEngine
+
+
+def create_sample(shape, **kwargs):
+    sample = dict()
+    sample["masked_kspace"] = torch.from_numpy(np.random.randn(*shape)).float()
+    sample["kspace"] = torch.from_numpy(np.random.randn(*shape)).float()
+    sample["sensitivity_map"] = torch.from_numpy(np.random.randn(*shape)).float()
+    for k, v in kwargs.items():
+        sample[k] = v
+    return sample
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [(4, 3, 10, 16, 2), (5, 1, 10, 12, 2)],
+)
+@pytest.mark.parametrize(
+    "loss_fns",
+    [["l1_loss", "kspace_nmse_loss", "kspace_nmae_loss"]],
+)
+@pytest.mark.parametrize(
+    "num_steps, num_steps_dc_gd, num_filters, num_pool_layers",
+    [[4, 2, 10, 2]],
+)
+@pytest.mark.parametrize(
+    "normalized",
+    [True, False],
+)
+def test_vsharpnet_engine(shape, loss_fns, num_steps, num_steps_dc_gd, num_filters, num_pool_layers, normalized):
+    # Operators
+    forward_operator = functools.partial(fft2, centered=True)
+    backward_operator = functools.partial(ifft2, centered=True)
+    # Configs
+    loss_config = LossConfig(losses=[FunctionConfig(loss) for loss in loss_fns])
+    training_config = TrainingConfig(loss=loss_config)
+    validation_config = ValidationConfig(crop=None)
+    model_config = VSharpNetConfig(
+        num_steps=num_steps,
+        num_steps_dc_gd=num_steps_dc_gd,
+        image_unet_num_filters=num_filters,
+        image_unet_num_pool_layers=num_pool_layers,
+        auxiliary_steps=-1,
+    )
+    config = DefaultConfig(training=training_config, validation=validation_config, model=model_config)
+    # Models
+    model = VSharpNet(
+        forward_operator,
+        backward_operator,
+        num_steps=model_config.num_steps,
+        num_steps_dc_gd=model_config.num_steps_dc_gd,
+        image_unet_num_filters=model_config.image_unet_num_filters,
+        image_unet_num_pool_layers=model_config.image_unet_num_pool_layers,
+        auxiliary_steps=model_config.auxiliary_steps,
+    )
+    sensitivity_model = torch.nn.Conv2d(2, 2, kernel_size=1)
+    # Define engine
+    engine = VSharpNetEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model)
+    engine.ndim = 2
+    # Test _do_iteration function with a single data batch
+    data = create_sample(
+        shape,
+        sampling_mask=torch.from_numpy(np.random.randn(1, 1, shape[2], shape[3], 1)).bool(),
+        target=torch.from_numpy(np.random.randn(shape[0], shape[2], shape[3])).float(),
+        scaling_factor=torch.ones(shape[0]),
+    )
+    loss_fns = engine.build_loss()
+    out = engine._do_iteration(data, loss_fns)
+    assert out.output_image.shape == (shape[0],) + tuple(shape[2:-1])
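Assuming a working checkout with pytest installed, the two new test modules can be run in isolation; the snippet below is one hypothetical way to do so from Python, executed from the repository root.

```python
# Hypothetical local run of the two test modules added by this patch;
# assumes pytest is available in the current environment.
import sys

import pytest

sys.exit(pytest.main(["tests/tests_nn/test_vsharp.py", "tests/tests_nn/test_vsharp_engine.py", "-q"]))
```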