diff --git a/direct/data/datasets.py b/direct/data/datasets.py index 11e82aa8..bd178901 100644 --- a/direct/data/datasets.py +++ b/direct/data/datasets.py @@ -18,6 +18,7 @@ from omegaconf import DictConfig from torch.utils.data import Dataset, IterableDataset +from direct.common.subsample import centered_disk_mask from direct.data.fake import FakeMRIData from direct.data.h5_data import H5SliceData from direct.data.sens import simulate_sensitivity_maps @@ -87,6 +88,51 @@ def _et_query( return str(value.text) +class NKIKSpaceBreastDataset(H5SliceData): + def __init__( + self, + data_root: pathlib.Path, + transform: Optional[Callable] = None, + filenames_filter: Optional[list[PathOrString]] = None, + filenames_lists: Union[list[PathOrString], None] = None, + filenames_lists_root: Union[PathOrString, None] = None, + slice_data: Optional[tuple[int, int]] = None, + acs_ratio: float = 0.1, + **kwargs, + ) -> None: + super().__init__( + root=data_root, + filenames_filter=filenames_filter, + filenames_lists=filenames_lists, + filenames_lists_root=filenames_lists_root, + regex_filter=None, + metadata=None, + pass_attrs=False, + text_description=kwargs.get("text_description", None), + pass_h5s=None, + pass_dictionaries=kwargs.get("pass_dictionaries", None), + sensitivity_maps=kwargs.get("sensitivity_maps", None), + slice_data=slice(slice_data[0], slice_data[1]) if slice_data is not None else None, + ) + self.acs_ratio = acs_ratio + self.transform = transform + + def __getitem__(self, idx: int) -> dict[str, Any]: + sample = super().__getitem__(idx) + + sample["sampling_mask"] = sample["kspace"].sum(0) != 0 + sample["acs_mask"] = ( + centered_disk_mask(sample["sampling_mask"].squeeze().shape, self.acs_ratio) * sample["sampling_mask"] + ) + + sample["sampling_mask"] = sample["sampling_mask"][None, ..., None] + sample["acs_mask"] = sample["acs_mask"][None, ..., None] + if self.transform: + sample = self.transform(sample) + + return sample + + class FakeMRIBlobsDataset(Dataset): """A PyTorch Dataset class which outputs random fake k-space images which reconstruct into Gaussian blobs.