Skip to content

Commit

Permalink
Add cmr dataset test
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Sep 21, 2023
1 parent 627d37b commit 9275f38
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 1 deletion.
2 changes: 1 addition & 1 deletion direct/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ class CMRxReconDataset(Dataset):
.. [1] https://cmrxrecon.github.io/Challenge.html
"""

# pylint: disable=too-many-arguments
def __init__(
self,
data_root: pathlib.Path,
Expand All @@ -423,7 +424,6 @@ def __init__(
compute_mask: bool = False,
kspace_context: Optional[str] = None,
) -> None:
# pylint: disable=too-many-arguments
"""Inits :class:`CMRxReconDataset`.
Parameters
Expand Down
89 changes: 89 additions & 0 deletions tests/tests_data/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from direct.data.datasets import (
CalgaryCampinasDataset,
CMRxReconDataset,
ConcatDataset,
FakeMRIBlobsDataset,
FastMRIDataset,
Expand Down Expand Up @@ -279,3 +280,91 @@ def test_ConcatDataset(num_samples, shapes):

with pytest.raises(ValueError):
dataset[-(np.cumsum(num_samples) + 1)]


@pytest.mark.parametrize(
"num_samples",
[3],
)
@pytest.mark.parametrize(
"shape",
[(3, 9, 10, 100, 130)],
)
@pytest.mark.parametrize(
"kspace_context",
[None, "time", "slice"],
)
@pytest.mark.parametrize("extra_keys, compute_mask", [[["mask"], False], [None, True], [None, False]])
@pytest.mark.parametrize(
"transform",
[None, lambda x: x],
)
@pytest.mark.parametrize(
"filter",
[None, ["file0.mat", "file1.mat"]],
)
def test_CMRxReconDataset(num_samples, shape, kspace_context, compute_mask, extra_keys, transform, filter):
with tempfile.TemporaryDirectory() as tempdir:
for _ in range(num_samples):
kspace = np.random.rand(*shape) + 1j * np.random.rand(*shape)
ny, nx = shape[-2:]
if compute_mask or extra_keys is not None:
mask = np.zeros((ny, nx), dtype=bool)
mask[:, ny // 2 - 12 : ny // 2 + 12] = True
mask[:, np.random.randint(0, ny, 30)] = True

if compute_mask:
kspace = mask[None, None, None] * kspace

with h5py.File(pathlib.Path(tempdir) / f"file{_}.mat", "w") as f:
dtype = np.dtype([("real", np.float32), ("imag", np.float32)])
# Reshape the complex data into a shape that matches the compound datatype
compound_data = np.empty(shape, dtype=dtype)
compound_data["real"] = np.real(kspace)
compound_data["imag"] = np.imag(kspace)
f.create_dataset("kspace_full", data=compound_data)
if extra_keys is not None:
f.create_dataset("mask", data=mask, dtype=bool)

dataset = CMRxReconDataset(
pathlib.Path(tempdir),
transform=transform,
kspace_context=kspace_context,
compute_mask=compute_mask,
extra_keys=extra_keys,
filenames_filter=[pathlib.Path(pathlib.Path(tempdir) / f) for f in filter] if filter else None,
)
sample = dataset[0]
assert "kspace" in sample

if kspace_context is None:
assert dataset.ndim == 2
assert len(dataset) == np.prod(shape[:2]) * (num_samples if not filter else len(filter))
assert sample["kspace"].shape == (shape[2],) + shape[3:][::-1]
elif kspace_context == "time":
assert dataset.ndim == 3
assert len(dataset) == shape[1] * (num_samples if not filter else len(filter))
assert (
sample["kspace"].shape
== (
shape[2],
shape[0],
)
+ shape[3:][::-1]
)
else:
assert dataset.ndim == 3
assert len(dataset) == shape[0] * (num_samples if not filter else len(filter))
assert (
sample["kspace"].shape
== (
shape[2],
shape[1],
)
+ shape[3:][::-1]
)

if compute_mask or extra_keys is not None:
assert "sampling_mask" in sample
assert "acs_mask" in sample
np.allclose(sample["sampling_mask"], mask.T[None, ..., None])

0 comments on commit 9275f38

Please sign in to comment.