Skip to content

Commit

Permalink
Change functionality for mask reshaping, make simple
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Jul 1, 2024
1 parent 1ed8960 commit 59e6d63
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 114 deletions.
128 changes: 88 additions & 40 deletions direct/common/subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from direct.common._gaussian import gaussian_mask_1d, gaussian_mask_2d # pylint: disable=no-name-in-module
from direct.common._poisson import poisson as _poisson # pylint: disable=no-name-in-module
from direct.environment import DIRECT_CACHE_DIR
from direct.types import DirectEnum, MaskFuncMode, Number
from direct.utils import reshape_array_to_shape, str_to_class
from direct.types import DirectEnum, MaskFuncMode, Number, TensorOrNdarray
from direct.utils import str_to_class
from direct.utils.io import download_url

# pylint: disable=arguments-differ
Expand Down Expand Up @@ -184,6 +184,46 @@ def mask_func(self, *args, **kwargs) -> torch.Tensor:
"""
raise NotImplementedError("This method should be implemented by a child class.")

def _reshape_and_add_coil_axis(self, mask: TensorOrNdarray, shape: tuple[int, ...]) -> torch.Tensor:
"""Reshape the mask with ones to match shape and add a coil axis.
Parameters
----------
mask : np.ndarray or torch.Tensor
Input mask of shape (num_rows, num_cols) if mode is MaskFuncMode.STATIC, and
(nt or num_slices, num_rows, num_cols) if mode is MaskFuncMode.DYNAMIC or
MaskFuncMode.MULTISLICE to be reshaped.
shape : tuple of ints
Shape of the output array after reshaping. Expects shape to be (\*, num_rows, num_cols, channels) for
mode MaskFuncMode.STATIC, and (\*, nt or num_slices, num_rows, num_cols, channels) for mode
MaskFuncMode.DYNAMIC where \* is any number of dimensions.
Returns
-------
toch.Tensor
Reshaped mask tensor with ones with an added coil axis.
"""
num_cols = shape[-2]
num_rows = shape[-3]

mask_shape = [1 for _ in shape]
mask_shape[-2] = num_cols
mask_shape[-3] = num_rows

# If mode is dynamic or multislice dim should not be 1
if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]:
mask_shape[-4] = shape[-4]

if isinstance(mask, np.ndarray):
mask = torch.from_numpy(mask)

# Reshape the mask to match the shape and make boolean
mask = mask.reshape(*mask_shape).bool()
# Add coil axis
mask = mask[None, ...]

return mask

def __call__(self, shape: tuple[int, ...], *args, **kwargs) -> torch.Tensor:
"""Calls the mask function.
Expand Down Expand Up @@ -290,36 +330,40 @@ def center_mask_func(num_cols: int, num_low_freqs: int) -> np.ndarray:

return mask

def _reshape_and_broadcast_mask(self, shape: tuple[int, ...], mask: np.ndarray) -> np.ndarray:
"""Broadcasts and reshapes the mask to the shape of the input k-space data.
@staticmethod
def _broadcast_mask(mask: np.ndarray, num_rows: int) -> np.ndarray:
"""Broadcast the input mask array to match a specified number of rows.
Useful for line masks that need to be broadcasted to match the number of rows in the k-space.
Parameters
----------
shape : tuple of ints
Shape of the k-space data.
mask : np.ndarray
Mask to be reshaped and broadcasted.
Input mask array to be broadcasted.
num_rows : int
Number of rows to which the mask array will be broadcasted.
Returns
-------
np.ndarray
Reshaped and broadcasted mask.
"""
num_cols = shape[-2]
num_rows = shape[-3]

# Reshape the mask
mask_shape = [1 for _ in shape]
mask_shape[-2] = num_cols
if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]:
mask_shape[-4] = shape[-4]
mask = mask.reshape(*mask_shape).astype(bool)
mask_shape[-3] = num_rows
Broadcasted mask array.
# Add coil axis, make array writable.
mask = np.broadcast_to(mask, mask_shape)[np.newaxis, ...].copy()
Raises
------
ValueError
If the input mask array has an unsupported number of dimensions.
"""
if mask.ndim == 1:
broadcast_mask = np.tile(mask, (num_rows, 1))
elif mask.ndim == 2:
broadcast_mask = np.tile(mask[:, np.newaxis, :], (1, num_rows, 1))
else:
raise ValueError(
f"Mask should have 1 dimension for mode STATIC "
f"and 2 dimensions for mode DYNAMIC or MULTISLICE. Got mask of shape {mask.shape}."
)

return mask
return broadcast_mask


class RandomMaskFunc(CartesianVerticalMaskFunc):
Expand Down Expand Up @@ -420,6 +464,7 @@ def mask_func(
The sampling mask.
"""
num_cols = shape[-2]
num_rows = shape[-3]
num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1

with temp_seed(self.rng, seed):
Expand All @@ -437,15 +482,15 @@ def mask_func(
mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0)

if return_acs:
return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask))
return self._reshape_and_add_coil_axis(self._broadcast_mask(mask, num_rows), shape)

# Create the mask
mask = mask.reshape(num_slc_or_time, -1) # In case mode != MaskFuncMode.STATIC:
for i in range(num_slc_or_time):
prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs)
mask[i] = mask[i] | (self.rng.uniform(size=num_cols) < prob)

return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask))
return self._reshape_and_add_coil_axis(self._broadcast_mask(mask, num_rows), shape)


class FastMRIRandomMaskFunc(RandomMaskFunc):
Expand Down Expand Up @@ -682,6 +727,7 @@ def mask_func(
The sampling mask.
"""
num_cols = shape[-2]
num_rows = shape[-3]
num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1

with temp_seed(self.rng, seed):
Expand All @@ -699,7 +745,7 @@ def mask_func(
mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0)

if return_acs:
return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask))
return self._reshape_and_add_coil_axis(self._broadcast_mask(mask, num_rows), shape)

# determine acceleration rate by adjusting for the number of low frequencies
adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols)
Expand All @@ -711,7 +757,7 @@ def mask_func(
accel_samples = np.around(accel_samples).astype(np.uint)
mask[i, accel_samples] = True

return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask))
return self._reshape_and_add_coil_axis(self._broadcast_mask(mask, num_rows), shape)


class FastMRIEquispacedMaskFunc(EquispacedMaskFunc):
Expand Down Expand Up @@ -941,6 +987,7 @@ def mask_func(
The sampling mask.
"""
num_cols = shape[-2]
num_rows = shape[-3]
num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1

with temp_seed(self.rng, seed):
Expand All @@ -965,7 +1012,7 @@ def mask_func(
acs_mask = acs_mask[np.newaxis].repeat(num_slc_or_time, axis=0)

if return_acs:
return torch.from_numpy(self._reshape_and_broadcast_mask(shape, acs_mask))
return self._reshape_and_add_coil_axis(self._broadcast_mask(acs_mask, num_rows), shape)

# adjust acceleration rate based on target acceleration.
adjusted_target_cols_to_sample = target_cols_to_sample - num_low_freqs
Expand Down Expand Up @@ -998,7 +1045,7 @@ def mask_func(
mask[i] = np.logical_or(mask[i], acs_mask[i])
mask = np.stack(mask, axis=0).squeeze()

return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask))
return self._reshape_and_add_coil_axis(self._broadcast_mask(mask, num_rows), shape)


class FastMRIMagicMaskFunc(MagicMaskFunc):
Expand Down Expand Up @@ -1619,7 +1666,7 @@ def mask_func(
acs_mask.append(self.circular_centered_mask(mask[-1]))
mask = torch.stack(mask, dim=0).squeeze()
acs_mask = torch.stack(acs_mask, dim=0).squeeze()
acs_mask = reshape_array_to_shape(acs_mask, shape)[None].bool()
acs_mask = self._reshape_and_add_coil_axis(acs_mask, shape)
if return_acs:
return acs_mask

Expand All @@ -1633,7 +1680,7 @@ def mask_func(
if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]:
acs_mask = acs_mask[np.newaxis].repeat(num_slc_or_time, axis=0)

acs_mask = torch.from_numpy(reshape_array_to_shape(acs_mask, shape)[np.newaxis]).bool()
acs_mask = self._reshape_and_add_coil_axis(acs_mask, shape)

if return_acs:
return acs_mask
Expand All @@ -1656,7 +1703,7 @@ def mask_func(
)

mask = torch.stack(mask, dim=0).squeeze()
return reshape_array_to_shape(mask, shape)[np.newaxis].bool() | acs_mask
return self._reshape_and_add_coil_axis(mask, shape) | acs_mask


class RadialMaskFunc(CIRCUSMaskFunc):
Expand Down Expand Up @@ -1906,14 +1953,14 @@ def mask_func(
acs_mask = centered_disk_mask((num_rows, num_cols), center_fraction)
if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]:
acs_mask = acs_mask[np.newaxis].repeat(num_slc_or_time, axis=0)
return torch.from_numpy(reshape_array_to_shape(acs_mask, shape)[np.newaxis]).bool()
return self._reshape_and_add_coil_axis(acs_mask, shape)

mask = []
for _ in range(num_slc_or_time):
mask.append(self.poisson(num_rows, num_cols, center_fraction, acceleration, self.rng.randint(1e5)))
mask = np.stack(mask, axis=0).squeeze()

return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis]).bool()
return self._reshape_and_add_coil_axis(mask, shape)

def poisson(
self,
Expand Down Expand Up @@ -2069,6 +2116,7 @@ def mask_func(
"""

num_cols = shape[-2]
num_rows = shape[-3]
num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1

with temp_seed(self.rng, seed):
Expand All @@ -2083,7 +2131,7 @@ def mask_func(
mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0)

if return_acs:
return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask))
return self._reshape_and_add_coil_axis(self._broadcast_mask(mask, num_rows), shape)

# Calls cython function
nonzero_count = int(np.round(num_cols / acceleration - num_low_freqs - 1))
Expand All @@ -2104,7 +2152,7 @@ def mask_func(
nonzero_count, num_cols, num_cols // 2, 6 * np.sqrt(num_cols // 2), mask, self.rng.randint(1e5)
)

return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask).astype(bool))
return self._reshape_and_add_coil_axis(self._broadcast_mask(mask, num_rows), shape)


class Gaussian2DMaskFunc(BaseMaskFunc):
Expand Down Expand Up @@ -2200,7 +2248,7 @@ def mask_func(
mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0)

if return_acs:
return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis]).bool()
return self._reshape_and_add_coil_axis(mask, shape)

std = 6 * np.array([np.sqrt(num_rows // 2), np.sqrt(num_cols // 2)], dtype=float)

Expand All @@ -2225,7 +2273,7 @@ def mask_func(
nonzero_count, num_rows, num_cols, num_rows // 2, num_cols // 2, std, mask, self.rng.randint(1e5)
)

return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis]).bool()
return self._reshape_and_add_coil_axis(mask, shape)


class KtBaseMaskFunc(BaseMaskFunc):
Expand Down Expand Up @@ -2555,7 +2603,7 @@ def mask_func(
mask = mask + acs_mask
mask = mask > 0

return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis])
return self._reshape_and_add_coil_axis(mask, shape)


class KtUniformMaskFunc(KtBaseMaskFunc):
Expand Down Expand Up @@ -2666,7 +2714,7 @@ def mask_func(
mask = mask + acs_mask
mask = mask > 0

return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis])
return self._reshape_and_add_coil_axis(mask, shape)


class KtGaussian1DMaskFunc(KtBaseMaskFunc):
Expand Down Expand Up @@ -2795,7 +2843,7 @@ def mask_func(
mask = mask + acs_mask
mask = mask > 0

return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis])
return self._reshape_and_add_coil_axis(mask, shape)


def integerize_seed(seed: Union[None, tuple[int, ...], list[int]]) -> int:
Expand Down
56 changes: 0 additions & 56 deletions direct/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,59 +535,3 @@ def dict_flatten(in_dict: DictOrDictConfig, dict_out: Optional[DictOrDictConfig]
continue
dict_out[k] = v
return dict_out


def reshape_array_to_shape(array: np.ndarray, requested_shape: Tuple[int, ...]) -> np.ndarray:
"""Reshapes the given array to match the requested shape by adding dimensions of size 1 where necessary.
Parameters
----------
array : np.ndarray
The input array to be reshaped.
requested_shape tuple of ints
The desired shape of the output array.
Returns
-------
np.ndarray
The reshaped array with the requested shape.
Example
-------
>>> array1 = np.random.rand(4, 5)
>>> requested_shape1 = (4, 5, 1)
>>> result1 = reshape_array_to_shape(array1, requested_shape1)
>>> print(result1.shape) # Output: (4, 5, 1)
>>> array2 = np.random.rand(4, 5)
>>> requested_shape2 = (1, 4, 5, 1)
>>> result2 = reshape_array_to_shape(array2, requested_shape2)
>>> print(result2.shape) # Output: (1, 4, 5, 1)
>>> array3 = np.random.rand(2, 4, 5)
>>> requested_shape3 = (2, 4, 5, 1)
>>> result3 = reshape_array_to_shape(array3, requested_shape3)
>>> print(result3.shape) # Output: (2, 4, 5, 1)
"""

# Get the current shape of the array
current_shape = array.shape

# Check if the current shape already matches the requested shape
if current_shape == requested_shape:
return array

# Initialize a new shape list with ones
new_shape = [1] * len(requested_shape)

# Fill in the new shape list with dimensions from the current shape where appropriate
j = 0 # Index for current shape
for i, dim in enumerate(requested_shape):
if j < len(current_shape) and dim == current_shape[j]:
new_shape[i] = current_shape[j]
j += 1

# Reshape the array to the new shape
reshaped_array = np.reshape(array, new_shape)

return reshaped_array
Loading

0 comments on commit 59e6d63

Please sign in to comment.