Skip to content

Commit

Permalink
NKI breast dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Apr 17, 2024
1 parent 17a12bf commit 9ef573d
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions direct/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from omegaconf import DictConfig
from torch.utils.data import Dataset, IterableDataset

from direct.common.subsample import centered_disk_mask
from direct.data.fake import FakeMRIData
from direct.data.h5_data import H5SliceData
from direct.data.sens import simulate_sensitivity_maps
Expand Down Expand Up @@ -87,6 +88,51 @@ def _et_query(
return str(value.text)


class NKIKSpaceBreastDataset(H5SliceData):
def __init__(
self,
data_root: pathlib.Path,
transform: Optional[Callable] = None,
filenames_filter: Optional[list[PathOrString]] = None,
filenames_lists: Union[list[PathOrString], None] = None,
filenames_lists_root: Union[PathOrString, None] = None,
slice_data: Optional[tuple[int, int]] = None,
acs_ratio: float = 0.1,
**kwargs,
) -> None:
super().__init__(
root=data_root,
filenames_filter=filenames_filter,
filenames_lists=filenames_lists,
filenames_lists_root=filenames_lists_root,
regex_filter=None,
metadata=None,
pass_attrs=False,
text_description=kwargs.get("text_description", None),
pass_h5s=None,
pass_dictionaries=kwargs.get("pass_dictionaries", None),
sensitivity_maps=kwargs.get("sensitivity_maps", None),
slice_data=slice(slice_data[0], slice_data[1]) if slice_data is not None else None,
)
self.acs_ratio = acs_ratio
self.transform = transform

def __getitem__(self, idx: int) -> dict[str, Any]:
sample = super().__getitem__(idx)

sample["sampling_mask"] = sample["kspace"].sum(0) != 0
sample["acs_mask"] = (
centered_disk_mask(sample["sampling_mask"].squeeze().shape, self.acs_ratio) * sample["sampling_mask"]
)

sample["sampling_mask"] = sample["sampling_mask"][None, ..., None]
sample["acs_mask"] = sample["acs_mask"][None, ..., None]
if self.transform:
sample = self.transform(sample)

return sample


class FakeMRIBlobsDataset(Dataset):
"""A PyTorch Dataset class which outputs random fake k-space images which reconstruct into Gaussian blobs.
Expand Down

0 comments on commit 9ef573d

Please sign in to comment.