diff --git a/direct/common/subsample.py b/direct/common/subsample.py
index 7b41bd10..50d7e9c5 100644
--- a/direct/common/subsample.py
+++ b/direct/common/subsample.py
@@ -26,8 +26,8 @@
 from direct.common._gaussian import gaussian_mask_1d, gaussian_mask_2d  # pylint: disable=no-name-in-module
 from direct.common._poisson import poisson as _poisson  # pylint: disable=no-name-in-module
 from direct.environment import DIRECT_CACHE_DIR
-from direct.types import DirectEnum, MaskFuncMode, Number
-from direct.utils import reshape_array_to_shape, str_to_class
+from direct.types import DirectEnum, MaskFuncMode, Number, TensorOrNdarray
+from direct.utils import str_to_class
 from direct.utils.io import download_url

 # pylint: disable=arguments-differ
@@ -184,6 +184,46 @@ def mask_func(self, *args, **kwargs) -> torch.Tensor:
         """
         raise NotImplementedError("This method should be implemented by a child class.")

+    def _reshape_and_add_coil_axis(self, mask: TensorOrNdarray, shape: tuple[int, ...]) -> torch.Tensor:
+        """Reshape the mask to match shape using singleton dimensions and add a coil axis.
+
+        Parameters
+        ----------
+        mask : np.ndarray or torch.Tensor
+            Input mask of shape (num_rows, num_cols) if mode is MaskFuncMode.STATIC, and
+            (nt or num_slices, num_rows, num_cols) if mode is MaskFuncMode.DYNAMIC or
+            MaskFuncMode.MULTISLICE to be reshaped.
+        shape : tuple of ints
+            Shape of the output array after reshaping. Expects shape to be (\*, num_rows, num_cols, channels) for
+            mode MaskFuncMode.STATIC, and (\*, nt or num_slices, num_rows, num_cols, channels) for mode
+            MaskFuncMode.DYNAMIC or MaskFuncMode.MULTISLICE, where \* is any number of dimensions.
+
+        Returns
+        -------
+        torch.Tensor
+            Reshaped boolean mask tensor with an added coil axis.
+        """
+        num_cols = shape[-2]
+        num_rows = shape[-3]
+
+        mask_shape = [1 for _ in shape]
+        mask_shape[-2] = num_cols
+        mask_shape[-3] = num_rows
+
+        # If mode is dynamic or multislice, dim -4 should not be 1
+        if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]:
+            mask_shape[-4] = shape[-4]
+
+        if isinstance(mask, np.ndarray):
+            mask = torch.from_numpy(mask)
+
+        # Reshape the mask to match the shape and make boolean
+        mask = mask.reshape(*mask_shape).bool()
+        # Add coil axis
+        mask = mask[None, ...]
+
+        return mask
+
     def __call__(self, shape: tuple[int, ...], *args, **kwargs) -> torch.Tensor:
         """Calls the mask function.

@@ -290,36 +330,40 @@ def center_mask_func(num_cols: int, num_low_freqs: int) -> np.ndarray:

         return mask

-    def _reshape_and_broadcast_mask(self, shape: tuple[int, ...], mask: np.ndarray) -> np.ndarray:
-        """Broadcasts and reshapes the mask to the shape of the input k-space data.
+    @staticmethod
+    def _broadcast_mask(mask: np.ndarray, num_rows: int) -> np.ndarray:
+        """Broadcast the input mask array to match a specified number of rows.
+
+        Useful for line masks that need to be broadcasted to match the number of rows in the k-space.

         Parameters
         ----------
-        shape : tuple of ints
-            Shape of the k-space data.
         mask : np.ndarray
-            Mask to be reshaped and broadcasted.
+            Input mask array to be broadcasted.
+        num_rows : int
+            Number of rows to which the mask array will be broadcasted.

         Returns
         -------
         np.ndarray
-            Reshaped and broadcasted mask.
-        """
-        num_cols = shape[-2]
-        num_rows = shape[-3]
-
-        # Reshape the mask
-        mask_shape = [1 for _ in shape]
-        mask_shape[-2] = num_cols
-        if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]:
-            mask_shape[-4] = shape[-4]
-        mask = mask.reshape(*mask_shape).astype(bool)
-        mask_shape[-3] = num_rows
+            Broadcasted mask array.

-        # Add coil axis, make array writable.
-        mask = np.broadcast_to(mask, mask_shape)[np.newaxis, ...].copy()
+        Raises
+        ------
+        ValueError
+            If the input mask array has an unsupported number of dimensions.
+        """
+        if mask.ndim == 1:
+            broadcast_mask = np.tile(mask, (num_rows, 1))
+        elif mask.ndim == 2:
+            broadcast_mask = np.tile(mask[:, np.newaxis, :], (1, num_rows, 1))
+        else:
+            raise ValueError(
+                f"Mask should have 1 dimension for mode STATIC "
+                f"and 2 dimensions for mode DYNAMIC or MULTISLICE. Got mask of shape {mask.shape}."
+            )

-        return mask
+        return broadcast_mask


 class RandomMaskFunc(CartesianVerticalMaskFunc):
@@ -420,6 +464,7 @@ def mask_func(
             The sampling mask.
         """
         num_cols = shape[-2]
+        num_rows = shape[-3]
         num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1

         with temp_seed(self.rng, seed):
@@ -437,7 +482,7 @@
                 mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0)

             if return_acs:
-                return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask))
+                return self._reshape_and_add_coil_axis(self._broadcast_mask(mask, num_rows), shape)

             # Create the mask
             mask = mask.reshape(num_slc_or_time, -1)  # In case mode != MaskFuncMode.STATIC:
@@ -445,7 +490,7 @@
                 prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs)
                 mask[i] = mask[i] | (self.rng.uniform(size=num_cols) < prob)

-        return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask))
+        return self._reshape_and_add_coil_axis(self._broadcast_mask(mask, num_rows), shape)


 class FastMRIRandomMaskFunc(RandomMaskFunc):
@@ -682,6 +727,7 @@
             The sampling mask.
         """
         num_cols = shape[-2]
+        num_rows = shape[-3]
         num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1

         with temp_seed(self.rng, seed):
@@ -699,7 +745,7 @@
                 mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0)

             if return_acs:
-                return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask))
+                return self._reshape_and_add_coil_axis(self._broadcast_mask(mask, num_rows), shape)

             # determine acceleration rate by adjusting for the number of low frequencies
             adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols)
@@ -711,7 +757,7 @@
                 accel_samples = np.around(accel_samples).astype(np.uint)
                 mask[i, accel_samples] = True

-        return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask))
+        return self._reshape_and_add_coil_axis(self._broadcast_mask(mask, num_rows), shape)


 class FastMRIEquispacedMaskFunc(EquispacedMaskFunc):
@@ -941,6 +987,7 @@
             The sampling mask.
         """
         num_cols = shape[-2]
+        num_rows = shape[-3]
         num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1

         with temp_seed(self.rng, seed):
@@ -965,7 +1012,7 @@
                 acs_mask = acs_mask[np.newaxis].repeat(num_slc_or_time, axis=0)

             if return_acs:
-                return torch.from_numpy(self._reshape_and_broadcast_mask(shape, acs_mask))
+                return self._reshape_and_add_coil_axis(self._broadcast_mask(acs_mask, num_rows), shape)

             # adjust acceleration rate based on target acceleration.
             adjusted_target_cols_to_sample = target_cols_to_sample - num_low_freqs
@@ -998,7 +1045,7 @@
                 mask[i] = np.logical_or(mask[i], acs_mask[i])
             mask = np.stack(mask, axis=0).squeeze()

-        return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask))
+        return self._reshape_and_add_coil_axis(self._broadcast_mask(mask, num_rows), shape)


 class FastMRIMagicMaskFunc(MagicMaskFunc):
@@ -1619,7 +1666,7 @@
                     acs_mask.append(self.circular_centered_mask(mask[-1]))
                 mask = torch.stack(mask, dim=0).squeeze()
                 acs_mask = torch.stack(acs_mask, dim=0).squeeze()
-                acs_mask = reshape_array_to_shape(acs_mask, shape)[None].bool()
+                acs_mask = self._reshape_and_add_coil_axis(acs_mask, shape)

                 if return_acs:
                     return acs_mask
@@ -1633,7 +1680,7 @@
             if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]:
                 acs_mask = acs_mask[np.newaxis].repeat(num_slc_or_time, axis=0)

-            acs_mask = torch.from_numpy(reshape_array_to_shape(acs_mask, shape)[np.newaxis]).bool()
+            acs_mask = self._reshape_and_add_coil_axis(acs_mask, shape)

             if return_acs:
                 return acs_mask
@@ -1656,7 +1703,7 @@
                 )
             mask = torch.stack(mask, dim=0).squeeze()

-        return reshape_array_to_shape(mask, shape)[np.newaxis].bool() | acs_mask
+        return self._reshape_and_add_coil_axis(mask, shape) | acs_mask


 class RadialMaskFunc(CIRCUSMaskFunc):
@@ -1906,14 +1953,14 @@
                 acs_mask = centered_disk_mask((num_rows, num_cols), center_fraction)
                 if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]:
                     acs_mask = acs_mask[np.newaxis].repeat(num_slc_or_time, axis=0)
-                return torch.from_numpy(reshape_array_to_shape(acs_mask, shape)[np.newaxis]).bool()
+                return self._reshape_and_add_coil_axis(acs_mask, shape)

             mask = []
             for _ in range(num_slc_or_time):
                 mask.append(self.poisson(num_rows, num_cols, center_fraction, acceleration, self.rng.randint(1e5)))
             mask = np.stack(mask, axis=0).squeeze()

-        return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis]).bool()
+        return self._reshape_and_add_coil_axis(mask, shape)

     def poisson(
         self,
@@ -2069,6 +2116,7 @@

         """
         num_cols = shape[-2]
+        num_rows = shape[-3]
         num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1

         with temp_seed(self.rng, seed):
@@ -2083,7 +2131,7 @@
                 mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0)

             if return_acs:
-                return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask))
+                return self._reshape_and_add_coil_axis(self._broadcast_mask(mask, num_rows), shape)

             # Calls cython function
             nonzero_count = int(np.round(num_cols / acceleration - num_low_freqs - 1))
@@ -2104,7 +2152,7 @@
                     nonzero_count, num_cols, num_cols // 2, 6 * np.sqrt(num_cols // 2), mask, self.rng.randint(1e5)
                 )

-        return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask).astype(bool))
+        return self._reshape_and_add_coil_axis(self._broadcast_mask(mask, num_rows), shape)


 class Gaussian2DMaskFunc(BaseMaskFunc):
@@ -2200,7 +2248,7 @@
                 mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0)

             if return_acs:
-                return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis]).bool()
+                return self._reshape_and_add_coil_axis(mask, shape)

             std = 6 * np.array([np.sqrt(num_rows // 2), np.sqrt(num_cols // 2)], dtype=float)

@@ -2225,7 +2273,7 @@
                     nonzero_count, num_rows, num_cols, num_rows // 2, num_cols // 2, std, mask, self.rng.randint(1e5)
                 )

-        return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis]).bool()
+        return self._reshape_and_add_coil_axis(mask, shape)


 class KtBaseMaskFunc(BaseMaskFunc):
@@ -2555,7 +2603,7 @@
             mask = mask + acs_mask
             mask = mask > 0

-        return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis])
+        return self._reshape_and_add_coil_axis(mask, shape)


 class KtUniformMaskFunc(KtBaseMaskFunc):
@@ -2666,7 +2714,7 @@
             mask = mask + acs_mask
             mask = mask > 0

-        return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis])
+        return self._reshape_and_add_coil_axis(mask, shape)


 class KtGaussian1DMaskFunc(KtBaseMaskFunc):
@@ -2795,7 +2843,7 @@
             mask = mask + acs_mask
             mask = mask > 0

-        return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis])
+        return self._reshape_and_add_coil_axis(mask, shape)


 def integerize_seed(seed: Union[None, tuple[int, ...], list[int]]) -> int:
diff --git a/direct/utils/__init__.py b/direct/utils/__init__.py
index 5de3c6d0..96914046 100644
--- a/direct/utils/__init__.py
+++ b/direct/utils/__init__.py
@@ -535,59 +535,3 @@ def dict_flatten(in_dict: DictOrDictConfig, dict_out: Optional[DictOrDictConfig]
             continue
         dict_out[k] = v
     return dict_out
-
-
-def reshape_array_to_shape(array: np.ndarray, requested_shape: Tuple[int, ...]) -> np.ndarray:
-    """Reshapes the given array to match the requested shape by adding dimensions of size 1 where necessary.
-
-    Parameters
-    ----------
-    array : np.ndarray
-        The input array to be reshaped.
-    requested_shape tuple of ints
-        The desired shape of the output array.
-
-    Returns
-    -------
-    np.ndarray
-        The reshaped array with the requested shape.
-
-    Example
-    -------
-    >>> array1 = np.random.rand(4, 5)
-    >>> requested_shape1 = (4, 5, 1)
-    >>> result1 = reshape_array_to_shape(array1, requested_shape1)
-    >>> print(result1.shape)  # Output: (4, 5, 1)
-
-    >>> array2 = np.random.rand(4, 5)
-    >>> requested_shape2 = (1, 4, 5, 1)
-    >>> result2 = reshape_array_to_shape(array2, requested_shape2)
-    >>> print(result2.shape)  # Output: (1, 4, 5, 1)
-
-    >>> array3 = np.random.rand(2, 4, 5)
-    >>> requested_shape3 = (2, 4, 5, 1)
-    >>> result3 = reshape_array_to_shape(array3, requested_shape3)
-    >>> print(result3.shape)  # Output: (2, 4, 5, 1)
-    """
-
-    # Get the current shape of the array
-    current_shape = array.shape
-
-    # Check if the current shape already matches the requested shape
-    if current_shape == requested_shape:
-        return array
-
-    # Initialize a new shape list with ones
-    new_shape = [1] * len(requested_shape)
-
-    # Fill in the new shape list with dimensions from the current shape where appropriate
-    j = 0  # Index for current shape
-    for i, dim in enumerate(requested_shape):
-        if j < len(current_shape) and dim == current_shape[j]:
-            new_shape[i] = current_shape[j]
-            j += 1
-
-    # Reshape the array to the new shape
-    reshaped_array = np.reshape(array, new_shape)
-
-    return reshaped_array
diff --git a/tests/tests_utils/test_utils.py b/tests/tests_utils/test_utils.py
index 1a6be27c..40135271 100644
--- a/tests/tests_utils/test_utils.py
+++ b/tests/tests_utils/test_utils.py
@@ -9,7 +9,7 @@
 import pytest
 import torch

-from direct.utils import is_power_of_two, normalize_image, remove_keys, reshape_array_to_shape, set_all_seeds
+from direct.utils import is_power_of_two, normalize_image, remove_keys, set_all_seeds
 from direct.utils.asserts import assert_complex
 from direct.utils.bbox import crop_to_largest
 from direct.utils.dataset import get_filenames_for_datasets_from_config
@@ -126,20 +126,3 @@ def test_normalize_image(shape, eps):
     img = np.random.randn(*shape)
     normalized_img = normalize_image(img, eps)
     assert normalized_img.min() >= 0.0 and normalized_img.max() <= 1.0
-
-
-@pytest.mark.parametrize(
-    "array, requested_shape, expected_shape",
-    [
-        (np.random.rand(4, 5), (4, 5, 1), (4, 5, 1)),
-        (np.random.rand(4, 5), (1, 4, 5, 1), (1, 4, 5, 1)),
-        (np.random.rand(2, 4, 5), (2, 4, 5, 1), (2, 4, 5, 1)),
-        (np.random.rand(3, 3), (1, 3, 1, 3, 1), (1, 3, 1, 3, 1)),
-        (np.random.rand(2, 3), (2, 1, 3), (2, 1, 3)),
-        (np.random.rand(4), (1, 1, 4, 1), (1, 1, 4, 1)),
-        (np.random.rand(6), (1, 6, 1), (1, 6, 1)),
-    ],
-)
-def test_reshape_array_to_shape(array, requested_shape, expected_shape):
-    result = reshape_array_to_shape(array, requested_shape)
-    assert result.shape == expected_shape
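
Editor's note, not part of the patch above: the snippet below is a minimal, standalone sketch of what the two new helpers do for a 1D Cartesian line mask in MaskFuncMode.STATIC, assuming the shape convention stated in the docstrings, (\*, num_rows, num_cols, channels). The names kspace_shape and line_mask are hypothetical and only for illustration.

import numpy as np
import torch

num_rows, num_cols = 6, 8
kspace_shape = (num_rows, num_cols, 2)  # hypothetical k-space shape: (num_rows, num_cols, channels)

# A 1D column ("line") mask, as produced by the Cartesian mask functions.
line_mask = np.zeros(num_cols, dtype=bool)
line_mask[::4] = True

# _broadcast_mask for a 1D (STATIC) mask: tile the line over all rows.
broadcast = np.tile(line_mask, (num_rows, 1))  # shape (num_rows, num_cols)

# _reshape_and_add_coil_axis: pad with singleton dims to match kspace_shape, then prepend a coil axis.
mask_shape = [1 for _ in kspace_shape]
mask_shape[-2] = num_cols
mask_shape[-3] = num_rows
mask = torch.from_numpy(broadcast).reshape(*mask_shape).bool()[None, ...]

print(mask.shape)  # torch.Size([1, 6, 8, 1]) -> (coil, num_rows, num_cols, channels)

Compared to the removed reshape_array_to_shape, which inferred where to place singleton dimensions by matching sizes, the replacement makes the placement explicit from the k-space shape, and the mask functions now always return a boolean torch.Tensor with the coil axis prepended.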