Add CMRxRecon Challenge 23 code
georgeyiasemis committed Apr 4, 2024
1 parent 79543f0 commit a17970c
Showing 28 changed files with 11,474 additions and 1 deletion.
269 changes: 268 additions & 1 deletion direct/data/datasets.py
@@ -1,7 +1,9 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

"""DIRECT datasets module."""

from __future__ import annotations

import bisect
import contextlib
import logging
@@ -11,6 +13,7 @@
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import h5py
import numpy as np
from omegaconf import DictConfig
from torch.utils.data import Dataset, IterableDataset
@@ -20,13 +23,15 @@
from direct.data.sens import simulate_sensitivity_maps
from direct.types import PathOrString
from direct.utils import remove_keys, str_to_class
from direct.utils.dataset import get_filenames_for_datasets

logger = logging.getLogger(__name__)

__all__ = [
"build_dataset_from_input",
"CalgaryCampinasDataset",
"ConcatDataset",
"CMRxReconDataset",
"FastMRIDataset",
"FakeMRIBlobsDataset",
"SheppLoganDataset",
@@ -395,6 +400,268 @@ def __broadcast_mask(self, kspace_shape, mask):
return mask


class CMRxReconDataset(Dataset):
"""CMRxRecon Challenge Dataset [1]_.
References
----------
.. [1] https://cmrxrecon.github.io/Challenge.html
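
    Examples
    --------
    A minimal sketch; the data path below is hypothetical and any transform is omitted:

    >>> dataset = CMRxReconDataset(
    ...     data_root=pathlib.Path("CMRxRecon/MultiCoil/Cine/TrainingSet/FullSample"),
    ...     kspace_key="kspace_full",
    ...     kspace_context="time",
    ... )
    >>> sample = dataset[0]  # dict with "kspace", "filename", "slice_no", ...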
"""

# pylint: disable=too-many-arguments
def __init__(
self,
data_root: pathlib.Path,
transform: Optional[Callable[[tuple[Any, ...]], dict]] = None,
filenames_filter: Optional[list[PathOrString]] = None,
filenames_lists: Optional[list[PathOrString]] = None,
filenames_lists_root: Optional[PathOrString] = None,
kspace_key: str = "kspace_full",
extra_keys: Optional[tuple] = None,
text_description: Optional[str] = None,
compute_mask: bool = False,
kspace_context: Optional[str] = None,
) -> None:
"""Inits :class:`CMRxReconDataset`.
Parameters
----------
data_root : pathlib.Path
Root directory to data.
        transform : Callable, optional
            A callable (e.g. a composition of transforms) applied to the generated samples. Default: None.
filenames_filter : list[PathOrString], optional
List of filenames to include in the dataset, should be the same as the ones that can be derived from a glob
on the root. If set, will skip searching for files in the root. Default: None.
        filenames_lists : list[PathOrString], optional
            List of paths pointing to `.lst` file(s) that contain file-names in `root` to filter.
            Should be the same as the ones that can be derived from a glob on the root. Ignored
            if `filenames_filter` is not None. Default: None.
filenames_lists_root : PathOrString, optional
Root of `filenames_lists`. Ignored if `filename_lists` is None. Default: None.
kspace_key : str
Key to load the k-space. Typically, `kspace_full` for fully-sampled data, or `kspace_subxx` for
sub-sampled data. Default: `kspace_full`.
        extra_keys : tuple of str, optional
            Extra keys in the mat/h5 file to add to the output. May be used to load sampling masks, e.g. `maskxx`.
            Default: None.
        text_description : str, optional
            Description of dataset, can be useful for logging. Default: None.
compute_mask : bool
            If True, the sampling mask is computed from the data. This should typically be True at inference,
            where the data are already undersampled. This will also compute `acs_mask`, which by default
            covers the 24 center lines. Default: False.
kspace_context : str, optional
            Can be either None, `time` or `slice`. If None, data will be loaded per slice or time-frame (2D data).
            If `time`, all time frames (phases) per slice will be loaded (3D data). If `slice`, all slices per
            time frame will be loaded (3D data). Default: None.
"""
self.logger = logging.getLogger(type(self).__name__)

self.root = pathlib.Path(data_root)
self.filenames_filter = filenames_filter

self.text_description = text_description

self.kspace_key = kspace_key

self.data: list[tuple] = []

self.volume_indices: Dict[pathlib.Path, range] = {}

if kspace_context not in [None, "slice", "time"]:
raise ValueError(
f"Attribute `kspace_context` can be None for 2D data or `slice` or `time` for 3D. "
f"Received {kspace_context}."
)

self.kspace_context = kspace_context

self.ndim = 2 if self.kspace_context is None else 3
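        # Samples are 2D (a single slice/frame) without context, 3D (a stack along slice or time) with it.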

        # If both filenames_filter and filenames_lists are given, the files in filenames_filter
        # are loaded and filenames_lists is ignored.
if filenames_filter is None:
if filenames_lists is not None:
if filenames_lists_root is None:
e = "`filenames_lists` is passed but `filenames_lists_root` is None."
self.logger.error(e)
raise ValueError(e)
filenames = get_filenames_for_datasets(
lists=filenames_lists, files_root=filenames_lists_root, data_root=data_root
)
self.logger.info("Attempting to load %s filenames from list(s).", len(filenames))
else:
self.logger.info("Parsing directory %s for mat files.", self.root)
filenames = list(self.root.glob("*.mat"))
else:
self.logger.info("Attempting to load %s filenames.", len(filenames_filter))
filenames = filenames_filter

filenames = [pathlib.Path(_) for _ in filenames]

if len(filenames) == 0:
warn = (
f"Found 0 mat files in directory {self.root}."
if not self.text_description
else f"Found 0 mat files in directory {self.root} for dataset {self.text_description}."
)
self.logger.warning(warn)
else:
self.logger.info("Using %s mat files in %s.", len(filenames), self.root)

        self.parse_filenames_data(filenames, extra_mats=None)  # Collect volume info and build the slice index.
self.extra_keys = extra_keys

self.compute_mask = compute_mask

self.transform = transform

if self.text_description:
self.logger.info("Dataset description: %s.", self.text_description)

def parse_filenames_data(self, filenames, extra_mats=None):
current_slice_number = 0 # This is required to keep track of where a volume is in the dataset

for idx, filename in enumerate(filenames):
if len(filenames) < 5 or idx % (len(filenames) // 5) == 0 or len(filenames) == (idx + 1):
self.logger.info("Parsing: {:.2f}%.".format((idx + 1) / len(filenames) * 100))
try:
if not filename.exists():
raise OSError(f"{filename} does not exist.")
kspace_shape = h5py.File(filename, "r")[self.kspace_key].shape
self.verify_extra_mat_integrity(filename, extra_mats=extra_mats)
except FileNotFoundError as exc:
self.logger.warning("%s not found. Failed with: %s. Skipping...", filename, exc)
continue
except OSError as exc:
self.logger.warning("%s failed with OSError: %s. Skipping...", filename, exc)
continue

if self.kspace_context is None:
num_slices = np.prod(kspace_shape[:2])
            elif self.kspace_context == "slice":
                # All slices per time frame: one sample per time frame (axis 0).
                num_slices = kspace_shape[0]
            else:
                # All time frames per slice: one sample per slice (axis 1).
                num_slices = kspace_shape[1]
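            # Example with a hypothetical k-space of shape (time, slice, coil, ky, kx) = (12, 10, 10, 204, 448):
            # kspace_context=None yields 12 * 10 = 120 2D samples, "slice" yields 12 samples (one per time
            # frame), and "time" yields 10 samples (one per slice).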

self.data += [(filename, slc) for slc in range(num_slices)]

self.volume_indices[filename] = range(
current_slice_number,
current_slice_number + num_slices,
)

current_slice_number += num_slices

@staticmethod
def verify_extra_mat_integrity(image_fn, extra_mats):
if not extra_mats:
return

for key in extra_mats:
mat_key, path = extra_mats[key]
extra_fn = path / image_fn.name
with h5py.File(extra_fn, "r") as file:
_ = file[mat_key].shape
return

def __len__(self):
return len(self.data)

def get_slice_data(self, filename, slice_no, key, extra_keys=None):
data = h5py.File(filename, "r")
shape = data[key].shape

if self.kspace_context is None:
            # Map the flat sample index to a (time, slice) index pair.
            inds = dict(enumerate((k, l) for k in range(shape[0]) for l in range(shape[1])))
            ind = inds[slice_no]
            curr_data = np.array(data[key][ind[0]][ind[1]])
        elif self.kspace_context == "slice":
            # One time frame, all slices: index along the time axis.
            curr_data = np.array(data[key][slice_no])
        else:
            # One slice, all time frames: index along the slice axis.
            curr_data = np.array(data[key][:, slice_no])

extra_data = {}

        if extra_keys:
            for extra_key in extra_keys:
extra_data[extra_key] = data[extra_key][()]
data.close()
return curr_data, extra_data

def get_num_slices(self, filename):
num_slices = self.volume_indices[filename].stop - self.volume_indices[filename].start
return num_slices

def __getitem__(self, idx: int) -> Dict[str, Any]: # pylint: disable=too-many-locals
filename, slice_no = self.data[idx]
filename = pathlib.Path(filename)

kspace, extra_data = self.get_slice_data(filename, slice_no, key=self.kspace_key, extra_keys=self.extra_keys)

        # MATLAB v7.3 (HDF5) stores complex arrays as a compound dtype with "real" and "imag"
        # fields; recombine into a complex array. h5py reads MATLAB arrays with reversed axes,
        # so swap the two trailing axes back.
        kspace = kspace["real"] + 1j * kspace["imag"]
        kspace = np.swapaxes(kspace, -1, -2)

        if kspace.ndim == 2:  # Single-coil data: add a coil axis.
kspace = kspace[np.newaxis, ...]

sample = {"kspace": kspace, "filename": str(filename), "slice_no": slice_no}

if self.compute_mask:
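            # Infer the sampling mask from the data: a k-space position counts as sampled
            # if any coil/frame has non-zero signal there.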
nx, ny = kspace.shape[-2:]
sampling_mask = np.abs(kspace).sum(tuple(range(len(kspace.shape) - 2))) != 0
assert tuple(sampling_mask.shape) == (nx, ny)
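            # ACS mask: the 24 center lines along the second spatial axis (the docstring default).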
acs_mask = np.zeros((nx, ny), dtype=bool)
acs_mask[:, ny // 2 - 12 : ny // 2 + 12] = True

sample["sampling_mask"] = sampling_mask[np.newaxis, ..., np.newaxis]
sample["acs_mask"] = acs_mask[np.newaxis, ..., np.newaxis]

elif any("mask" in key for key in extra_data):
mask_keys = [key for key in extra_data if "mask" in key]
# This will load up randomly a mask if more than one keys
mask_key = np.random.choice(mask_keys)

sampling_mask = np.array(extra_data[mask_key]).astype(bool)
for key in mask_keys:
del extra_data[key]

ny, nx = sampling_mask.shape
sampling_mask = np.swapaxes(sampling_mask, -1, -2)

acs_mask = np.zeros((nx, ny), dtype=bool)
acs_mask[:, ny // 2 - 12 : ny // 2 + 12] = True

sample["sampling_mask"] = sampling_mask[np.newaxis, ..., np.newaxis]
sample["acs_mask"] = acs_mask[np.newaxis, ..., np.newaxis]

if self.kspace_context and "sampling_mask" in sample:
sample["sampling_mask"] = sample["sampling_mask"][np.newaxis]
sample["acs_mask"] = sample["acs_mask"][np.newaxis]

sample.update(extra_data)

shape = kspace.shape
sample["reconstruction_size"] = (int(np.round(shape[-2] / 3)), int(np.round(shape[-1] / 2)), 1)
        if self.kspace_context:
            # Add the context dimension to the reconstruction size without any crop.
            context_size = shape[0]
            sample["reconstruction_size"] = (context_size,) + sample["reconstruction_size"]
            # With a context dimension, move the coil axis to the front.
            sample["kspace"] = np.swapaxes(sample["kspace"], 0, 1)

if self.transform:
sample = self.transform(sample)

return sample


class CalgaryCampinasDataset(H5SliceData):
"""Calgary-Campinas challenge dataset."""

13 changes: 13 additions & 0 deletions direct/data/datasets_config.py
@@ -81,6 +81,19 @@ class H5SliceConfig(DatasetConfig):
filenames_lists_root: Optional[str] = None


@dataclass
class CMRxReconConfig(DatasetConfig):
regex_filter: Optional[str] = None
data_root: Optional[str] = None
filenames_filter: Optional[List[str]] = None
filenames_lists: Optional[List[str]] = None
filenames_lists_root: Optional[str] = None
kspace_key: str = "kspace_full"
compute_mask: bool = False
extra_keys: Optional[List[str]] = None
kspace_context: Optional[str] = None
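
    # A hedged sketch of how these fields might be set in a DIRECT experiment config
    # (values below are hypothetical):
    #
    #   datasets:
    #     - name: CMRxRecon
    #       data_root: /data/CMRxRecon/MultiCoil/Cine/TrainingSet/FullSample
    #       kspace_key: kspace_full
    #       kspace_context: time
    #       compute_mask: false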


@dataclass
class FastMRIConfig(H5SliceConfig):
pass_attrs: bool = True