Skip to content

Commit

Permalink
4258 Add TCIA dataset (#4610)
Browse files Browse the repository at this point in the history
* Add TCIA dataset

Signed-off-by: Yiheng Wang <[email protected]>
  • Loading branch information
yiheng-wang-nv authored Jul 19, 2022
1 parent 9fb9d98 commit 422cc6d
Show file tree
Hide file tree
Showing 9 changed files with 777 additions and 31 deletions.
3 changes: 3 additions & 0 deletions docs/source/apps.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ Applications
.. autoclass:: DecathlonDataset
:members:

.. autoclass:: TciaDataset
:members:

.. autoclass:: CrossValidation
:members:

Expand Down
2 changes: 1 addition & 1 deletion monai/apps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset
from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset, TciaDataset
from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar
from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger
280 changes: 277 additions & 3 deletions monai/apps/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import sys
import warnings
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Union
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np

from monai.apps.tcia import (
download_tcia_series_instance,
get_tcia_metadata,
get_tcia_ref_uid,
match_tcia_ref_uid_in_study,
)
from monai.apps.utils import download_and_extract
from monai.config.type_definitions import PathLike
from monai.data import (
CacheDataset,
PydicomReader,
load_decathlon_datalist,
load_decathlon_properties,
partition_dataset,
Expand All @@ -27,7 +37,7 @@
from monai.transforms import LoadImaged, Randomizable
from monai.utils import ensure_tuple

__all__ = ["MedNISTDataset", "DecathlonDataset", "CrossValidation"]
__all__ = ["MedNISTDataset", "DecathlonDataset", "CrossValidation", "TciaDataset"]


class MedNISTDataset(Randomizable, CacheDataset):
Expand Down Expand Up @@ -194,8 +204,8 @@ class DecathlonDataset(Randomizable, CacheDataset):
for further usage, use `AddChanneld` or `AsChannelFirstd` to convert the shape to [C, H, W, D].
download: whether to download and extract the Decathlon from resource link, default is False.
if expected file already exists, skip downloading even set it to True.
val_frac: percentage of of validation fraction in the whole dataset, default is 0.2.
user can manually copy tar file or dataset folder to the root directory.
val_frac: fraction of the whole dataset that is used for validation, default is 0.2.
seed: random seed to randomly shuffle the datalist before splitting into training and validation, default is 0.
note to set same seed for `training` and `validation` sections.
cache_num: number of items to be cached. Default is `sys.maxsize`.
Expand Down Expand Up @@ -379,6 +389,270 @@ def _split_datalist(self, datalist: List[Dict]) -> List[Dict]:
return [datalist[i] for i in self.indices]


class TciaDataset(Randomizable, CacheDataset):
    """
    The Dataset to automatically download the data from a public The Cancer Imaging Archive (TCIA) dataset
    and generate items for training, validation or test.

    The Highdicom library is used to load dicom data with modality "SEG", but only a part of collections are
    supported, such as: "C4KC-KiTS", "NSCLC-Radiomics", "NSCLC-Radiomics-Interobserver1",
    "QIN-PROSTATE-Repeatability" and "PROSTATEx". Therefore, if "seg" is included in `keys` of the `LoadImaged`
    transform and loading some other collections, errors may be raised. For supported collections, the original
    "SEG" information may not always be consistent for each dicom file. Therefore, to avoid creating different
    format of labels, please use the `label_dict` argument of `PydicomReader` when calling the `LoadImaged`
    transform. The prepared label dicts of collections that are mentioned above are also saved in:
    `monai.apps.tcia.TCIA_LABEL_DICT`. You can also refer to the second example below.

    This class is based on :py:class:`monai.data.CacheDataset` to accelerate the training process.

    Args:
        root_dir: user's local directory for caching and loading the TCIA dataset.
            It must be an existing directory, otherwise a `ValueError` is raised.
        collection: name of a TCIA collection.
            a TCIA dataset is defined as a collection. Please check the following list to browse
            the collection list (only public collections can be downloaded):
            https://www.cancerimagingarchive.net/collections/
        section: expected data section, can be: `training`, `validation` or `test`.
            `test` returns the whole data list without splitting. Note that any value other than
            `training` and `test` is currently treated as `validation`.
        transform: transforms to execute operations on input data.
            for further usage, use `AddChanneld` or `AsChannelFirstd` to convert the shape to [C, H, W, D].
            If not specified, `LoadImaged(reader="PydicomReader", keys=["image"])` will be used as the default
            transform. In addition, we suggest to set the argument `label_dict` for `PydicomReader` if
            segmentations are needed to be loaded. The original labels for each dicom series may be different,
            using this argument is able to unify the format of labels.
        download: whether to download and extract the dataset, default is False.
            if expected file already exists, skip downloading even set it to True.
            user can manually copy tar file or dataset folder to the root directory.
        download_len: number of series that will be downloaded, the value should be larger than 0 or -1,
            where -1 means all series will be downloaded. Default is -1.
            (any value that is not larger than 0 results in downloading all series.)
        seg_type: modality type of segmentation that is used to do the first step download. Default is "SEG".
        modality_tag: tag of modality. Default is (0x0008, 0x0060).
        ref_series_uid_tag: tag of referenced Series Instance UID. Default is (0x0020, 0x000e).
        ref_sop_uid_tag: tag of referenced SOP Instance UID. Default is (0x0008, 0x1155).
        specific_tags: tags that will be loaded for "SEG" series. This argument will be used in
            `monai.data.PydicomReader`. Default is [(0x0008, 0x1115), (0x0008,0x1140), (0x3006, 0x0010),
            (0x0020,0x000D), (0x0010,0x0010), (0x0010,0x0020), (0x0020,0x0011), (0x0020,0x0012)].
            `modality_tag` is always appended to these tags.
        seed: random seed to randomly shuffle the datalist before splitting into training and validation,
            default is 0. note to set same seed for `training` and `validation` sections.
        val_frac: fraction of the whole dataset that is used for validation, default is 0.2.
        cache_num: number of items to be cached. Default is `sys.maxsize`.
            will take the minimum of (cache_num, data_length x cache_rate, data_length).
        cache_rate: percentage of cached data in total, default is 0.0 (no cache).
            will take the minimum of (cache_num, data_length x cache_rate, data_length).
        num_workers: the number of worker threads to use.
            If num_workers is None then the number returned by os.cpu_count() is used.
            If a value less than 1 is specified, 1 will be used instead.
        progress: whether to display a progress bar when downloading dataset and computing the transform
            cache content.
        copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
            default to `True`. if the random transforms don't modify the cached content
            (for example, randomly crop from the cached image and deepcopy the crop region)
            or if every cache item is only used once in a `multi-processing` environment,
            may set `copy_cache=False` for better performance.
        as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
            it may help improve the performance of following logic.

    Raises:
        ValueError: when `root_dir` is not a directory, or when `download` is True and no series can be
            found for the given `collection` and `seg_type`.
        RuntimeError: when the dataset directory `root_dir/collection` does not exist.

    Example::

        # collection is "Pancreatic-CT-CBCT-SEG", seg_type is "RTSTRUCT"
        data = TciaDataset(
            root_dir="./",
            collection="Pancreatic-CT-CBCT-SEG",
            section="training",
            seg_type="RTSTRUCT",
            download=True,
        )

        # collection is "C4KC-KiTS", seg_type is "SEG", and load both images and segmentations
        from monai.apps.tcia import TCIA_LABEL_DICT
        transform = Compose(
            [
                LoadImaged(reader="PydicomReader", keys=["image", "seg"], label_dict=TCIA_LABEL_DICT["C4KC-KiTS"]),
                EnsureChannelFirstd(keys=["image", "seg"]),
                ResampleToMatchd(keys="image", key_dst="seg"),
            ]
        )
        data = TciaDataset(
            root_dir="./",
            collection="C4KC-KiTS",
            section="validation",
            transform=transform,
            seed=12345,
            download=True,
        )
        print(data[0]["seg"].shape)

    """

    def __init__(
        self,
        root_dir: PathLike,
        collection: str,
        section: str,
        transform: Union[Sequence[Callable], Callable] = (),
        download: bool = False,
        download_len: int = -1,
        seg_type: str = "SEG",
        modality_tag: Tuple = (0x0008, 0x0060),
        ref_series_uid_tag: Tuple = (0x0020, 0x000E),
        ref_sop_uid_tag: Tuple = (0x0008, 0x1155),
        specific_tags: Tuple = (
            (0x0008, 0x1115),  # Referenced Series Sequence
            (0x0008, 0x1140),  # Referenced Image Sequence
            (0x3006, 0x0010),  # Referenced Frame of Reference Sequence
            (0x0020, 0x000D),  # Study Instance UID
            (0x0010, 0x0010),  # Patient's Name
            (0x0010, 0x0020),  # Patient ID
            (0x0020, 0x0011),  # Series Number
            (0x0020, 0x0012),  # Acquisition Number
        ),
        seed: int = 0,
        val_frac: float = 0.2,
        cache_num: int = sys.maxsize,
        cache_rate: float = 0.0,
        num_workers: int = 1,
        progress: bool = True,
        copy_cache: bool = True,
        as_contiguous: bool = True,
    ) -> None:
        root_dir = Path(root_dir)
        if not root_dir.is_dir():
            raise ValueError("Root directory root_dir must be a directory.")

        self.section = section
        self.val_frac = val_frac
        self.seg_type = seg_type
        self.modality_tag = modality_tag
        self.ref_series_uid_tag = ref_series_uid_tag
        self.ref_sop_uid_tag = ref_sop_uid_tag

        self.set_random_state(seed=seed)
        download_dir = os.path.join(root_dir, collection)
        # the modality tag is always loaded, it is required to recognise the segmentation files
        load_tags = list(specific_tags)
        load_tags += [modality_tag]
        self.load_tags = load_tags
        if download:
            seg_series_list = get_tcia_metadata(
                query=f"getSeries?Collection={collection}&Modality={seg_type}", attribute="SeriesInstanceUID"
            )
            # a non-positive `download_len` keeps the whole list
            if download_len > 0:
                seg_series_list = seg_series_list[:download_len]
            if len(seg_series_list) == 0:
                raise ValueError(f"Cannot find data with collection: {collection} seg_type: {seg_type}")
            for series_uid in seg_series_list:
                self._download_series_reference_data(series_uid, download_dir)

        if not os.path.exists(download_dir):
            raise RuntimeError(f"Cannot find dataset directory: {download_dir}.")

        self.indices: np.ndarray = np.array([])
        self.datalist = self._generate_data_list(download_dir)

        if transform == ():
            transform = LoadImaged(reader="PydicomReader", keys=["image"])
        CacheDataset.__init__(
            self,
            data=self.datalist,
            transform=transform,
            cache_num=cache_num,
            cache_rate=cache_rate,
            num_workers=num_workers,
            progress=progress,
            copy_cache=copy_cache,
            as_contiguous=as_contiguous,
        )

    def get_indices(self) -> np.ndarray:
        """
        Get the indices of datalist used in this dataset.

        The array is empty for the `test` section because no split is computed in that case.

        """
        return self.indices

    def randomize(self, data: np.ndarray) -> None:
        """
        Shuffle `data` in place with the random state of this dataset.

        """
        self.R.shuffle(data)

    def _download_series_reference_data(self, series_uid: str, download_dir: str) -> None:
        """
        First of all, download a series from TCIA according to `series_uid`.
        Then find all referenced series and download.

        The segmentation series is stored in `download_dir/raw/series_uid` and copied to
        `download_dir/<patient id>/<series number>/<seg_type in lower case>`, the first referenced series
        that is found is stored in `download_dir/<patient id>/<series number>/image`.

        Args:
            series_uid: Series Instance UID of the segmentation series.
            download_dir: directory of the collection on the local disk.

        """
        seg_first_dir = os.path.join(download_dir, "raw", series_uid)
        download_tcia_series_instance(
            series_uid=series_uid, download_dir=download_dir, output_dir=seg_first_dir, check_md5=False
        )
        # sorted, so that "the first dicom file" does not depend on the directory listing order
        dicom_files = sorted(f for f in os.listdir(seg_first_dir) if f.endswith(".dcm"))
        if not dicom_files:
            # nothing can be organised for this series, skip it instead of failing with an IndexError
            warnings.warn(f"Cannot find any dicom file of series: {series_uid}, skip this series.")
            return

        # achieve series number and patient id from the first dicom file
        dcm_path = os.path.join(seg_first_dir, dicom_files[0])
        ds = PydicomReader(stop_before_pixels=True, specific_tags=self.load_tags).read(dcm_path)
        # (0x0010,0x0020) and (0x0010,0x0010), better to be contained in `specific_tags`.
        # the attributes are missing if the tags are not loaded, thus `getattr` with a default is used.
        # `str` is applied as the patient name may not be a plain string and the value is used in a path.
        patient_id = str(getattr(ds, "PatientID", "") or getattr(ds, "PatientName", "") or "")
        if not patient_id:
            warnings.warn(f"unable to find patient name of dicom file: {dcm_path}, use 'patient' instead.")
            patient_id = "patient"
        # (0x0020,0x0011) and (0x0020,0x0012), better to be contained in `specific_tags`
        series_num = getattr(ds, "SeriesNumber", None) or getattr(ds, "AcquisitionNumber", None)
        if not series_num:
            warnings.warn(f"unable to find series number of dicom file: {dcm_path}, use '0' instead.")
            series_num = 0

        series_num = str(series_num)
        seg_dir = os.path.join(download_dir, patient_id, series_num, self.seg_type.lower())
        dcm_dir = os.path.join(download_dir, patient_id, series_num, "image")

        # get ref uuid
        ref_uid_list = []
        for dcm_file in dicom_files:
            dcm_path = os.path.join(seg_first_dir, dcm_file)
            ds = PydicomReader(stop_before_pixels=True, specific_tags=self.load_tags).read(dcm_path)
            if ds[self.modality_tag].value == self.seg_type:
                ref_uid = get_tcia_ref_uid(
                    ds, find_sop=False, ref_series_uid_tag=self.ref_series_uid_tag, ref_sop_uid_tag=self.ref_sop_uid_tag
                )
                if ref_uid == "":
                    # no referenced series uid in the file, try to match the referenced SOP uid in the study
                    ref_sop_uid = get_tcia_ref_uid(
                        ds,
                        find_sop=True,
                        ref_series_uid_tag=self.ref_series_uid_tag,
                        ref_sop_uid_tag=self.ref_sop_uid_tag,
                    )
                    ref_uid = match_tcia_ref_uid_in_study(ds.StudyInstanceUID, ref_sop_uid)
                if ref_uid != "":
                    ref_uid_list.append(ref_uid)
        if not ref_uid_list:
            warnings.warn(f"Cannot find the referenced Series Instance UID from series: {series_uid}.")
        else:
            # only the first referenced series is downloaded as the image
            download_tcia_series_instance(
                series_uid=ref_uid_list[0], download_dir=download_dir, output_dir=dcm_dir, check_md5=False
            )
        # the segmentation is kept even if no referenced image is found
        if not os.path.exists(seg_dir):
            shutil.copytree(seg_first_dir, seg_dir)

    def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]:
        """
        Collect the items of the expected section from the `<patient id>/<series number>` folders
        in `dataset_dir`. Every item contains the path of the segmentation folder, and the path of the
        image folder if it exists.

        """
        # the types of the item in data list should be compatible with the dataloader
        dataset_dir = Path(dataset_dir)
        seg_key = self.seg_type.lower()
        datalist = []
        # `os.scandir` yields entries in arbitrary order, sort the names to ensure that
        # the same seed always produces the same training / validation split.
        patient_list = sorted(f.name for f in os.scandir(dataset_dir) if f.is_dir() and f.name != "raw")
        for patient_id in patient_list:
            series_list = sorted(f.name for f in os.scandir(os.path.join(dataset_dir, patient_id)) if f.is_dir())
            for series_num in series_list:
                image_path = os.path.join(dataset_dir, patient_id, series_num, "image")
                mask_path = os.path.join(dataset_dir, patient_id, series_num, seg_key)

                if os.path.exists(image_path):
                    datalist.append({"image": image_path, seg_key: mask_path})
                else:
                    datalist.append({seg_key: mask_path})

        return self._split_datalist(datalist)

    def _split_datalist(self, datalist: List[Dict]) -> List[Dict]:
        """
        Select the items of `self.section` from `datalist`.

        `test` returns `datalist` unchanged. Otherwise the indices are shuffled with the random state of
        this dataset, the first `int(len(datalist) * val_frac)` shuffled indices are used for validation
        and the remaining ones for training. The selected indices are stored in `self.indices`.

        """
        if self.section == "test":
            return datalist
        length = len(datalist)
        indices = np.arange(length)
        self.randomize(indices)

        val_length = int(length * self.val_frac)
        if self.section == "training":
            self.indices = indices[val_length:]
        else:
            self.indices = indices[:val_length]

        return [datalist[i] for i in self.indices]


class CrossValidation:
"""
Cross validation dataset based on the general dataset which must have `_split_datalist` API.
Expand Down
13 changes: 13 additions & 0 deletions monai/apps/tcia/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .label_desc import TCIA_LABEL_DICT
from .utils import download_tcia_series_instance, get_tcia_metadata, get_tcia_ref_uid, match_tcia_ref_uid_in_study
49 changes: 49 additions & 0 deletions monai/apps/tcia/label_desc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict

__all__ = ["TCIA_LABEL_DICT"]


# Label dictionaries of the TCIA collections whose "SEG" data is supported.
# For every collection the label names are listed in order, and the position of a name
# in its tuple is the integer id that the name is mapped to.
TCIA_LABEL_DICT: Dict[str, Dict[str, int]] = {
    collection_name: {label_name: label_id for label_id, label_name in enumerate(label_names)}
    for collection_name, label_names in (
        ("C4KC-KiTS", ("Kidney", "Renal Tumor")),
        (
            "NSCLC-Radiomics",
            ("Esophagus", "GTV-1", "Lungs-Total", "Spinal-Cord", "Lung-Left", "Lung-Right", "Heart"),
        ),
        (
            "NSCLC-Radiomics-Interobserver1",
            (
                "GTV-1auto-1",
                "GTV-1auto-2",
                "GTV-1auto-3",
                "GTV-1auto-4",
                "GTV-1auto-5",
                "GTV-1vis-1",
                "GTV-1vis-2",
                "GTV-1vis-3",
                "GTV-1vis-4",
                "GTV-1vis-5",
            ),
        ),
        ("QIN-PROSTATE-Repeatability", ("NormalROI_PZ_1", "TumorROI_PZ_1", "PeripheralZone", "WholeGland")),
        (
            "PROSTATEx",
            (
                "Prostate",
                "Peripheral zone of prostate",
                "Transition zone of prostate",
                "Distal prostatic urethra",
                "Anterior fibromuscular stroma of prostate",
            ),
        ),
    )
}
Loading

0 comments on commit 422cc6d

Please sign in to comment.