Skip to content

Commit

Permalink
Add support for level in PatchWSIDataset (#4036)
Browse files Browse the repository at this point in the history
* Add support for level as input

Signed-off-by: Behrooz <[email protected]>

* Add unittests for levels

Signed-off-by: Behrooz <[email protected]>

* Update docstring

Signed-off-by: Behrooz <[email protected]>

* Add kwargs for WSIReader in all datasets

Signed-off-by: Behrooz <[email protected]>

* Update docstring

Signed-off-by: Behrooz <[email protected]>
  • Loading branch information
bhashemian authored Apr 2, 2022
1 parent 541d7ab commit de97391
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 7 deletions.
17 changes: 12 additions & 5 deletions monai/apps/pathology/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class PatchWSIDataset(Dataset):
transform: transforms to be executed on input data.
image_reader_name: the name of library to be used for loading whole slide imaging, either CuCIM or OpenSlide.
Defaults to CuCIM.
kwargs: additional parameters for ``WSIReader``
Note:
The input data has the following form as an example:
Expand All @@ -56,6 +57,7 @@ def __init__(
patch_size: Union[int, Tuple[int, int]],
transform: Optional[Callable] = None,
image_reader_name: str = "cuCIM",
**kwargs,
):
super().__init__(data, transform)

Expand All @@ -65,7 +67,7 @@ def __init__(

self.image_path_list = list({x["image"] for x in self.data})
self.image_reader_name = image_reader_name.lower()
self.image_reader = WSIReader(image_reader_name)
self.image_reader = WSIReader(backend=image_reader_name, **kwargs)
self.wsi_object_dict = None
if self.image_reader_name != "openslide":
# OpenSlide causes memory issue if we prefetch image objects
Expand Down Expand Up @@ -119,17 +121,18 @@ class SmartCachePatchWSIDataset(SmartCacheDataset):
will take the minimum of (cache_num, data_length x cache_rate, data_length).
num_init_workers: the number of worker threads to initialize the cache for first epoch.
If num_init_workers is None then the number returned by os.cpu_count() is used.
If a value less than 1 is speficied, 1 will be used instead.
If a value less than 1 is specified, 1 will be used instead.
num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch.
If num_replace_workers is None then the number returned by os.cpu_count() is used.
If a value less than 1 is speficied, 1 will be used instead.
If a value less than 1 is specified, 1 will be used instead.
progress: whether to display a progress bar when caching for the first epoch.
copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
default to `True`. if the random transforms don't modify the cache content
or every cache item is only used once in a `multi-processing` environment,
may set `copy=False` for better performance.
as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
it may help improve the performance of following logic.
kwargs: additional parameters for ``WSIReader``
"""

Expand All @@ -149,13 +152,15 @@ def __init__(
progress: bool = True,
copy_cache: bool = True,
as_contiguous: bool = True,
**kwargs,
):
patch_wsi_dataset = PatchWSIDataset(
data=data,
region_size=region_size,
grid_shape=grid_shape,
patch_size=patch_size,
image_reader_name=image_reader_name,
**kwargs,
)
super().__init__(
data=patch_wsi_dataset, # type: ignore
Expand Down Expand Up @@ -183,7 +188,8 @@ class MaskedInferenceWSIDataset(Dataset):
patch_size: the size of patches to be extracted from the whole slide image for inference.
transform: transforms to be executed on extracted patches.
image_reader_name: the name of library to be used for loading whole slide imaging, either CuCIM or OpenSlide.
Defaults to CuCIM.
Defaults to CuCIM.
kwargs: additional parameters for ``WSIReader``
Note:
The resulting output (probability maps) after performing inference using this dataset is
Expand All @@ -196,14 +202,15 @@ def __init__(
patch_size: Union[int, Tuple[int, int]],
transform: Optional[Callable] = None,
image_reader_name: str = "cuCIM",
**kwargs,
) -> None:
super().__init__(data, transform)

self.patch_size = ensure_tuple_rep(patch_size, 2)

# set up whole slide image reader
self.image_reader_name = image_reader_name.lower()
self.image_reader = WSIReader(image_reader_name)
self.image_reader = WSIReader(backend=image_reader_name, **kwargs)

# process data and create a list of dictionaries containing all required data and metadata
self.data = self._prepare_data(data)
Expand Down
120 changes: 118 additions & 2 deletions tests/test_patch_wsi_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,31 @@
[{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}],
]

# cuCIM reader: single 1x1 patch at location (0, 0), read at pyramid level 1.
# Expected patch has value 239 in each of the 3 channels, label 1.
TEST_CASE_0_L1 = [
    dict(
        data=[{"image": FILE_PATH, "location": [0, 0], "label": [1]}],
        region_size=(1, 1),
        grid_shape=(1, 1),
        patch_size=1,
        level=1,
        image_reader_name="cuCIM",
    ),
    [{"image": np.full((3, 1, 1), 239, dtype=np.uint8), "label": np.array([[[1]]])}],
]

# cuCIM reader: single 1x1 patch at location (0, 0), read at pyramid level 2.
# NOTE(fix): the original set "level": 1 here, which made this case a byte-for-byte
# duplicate of TEST_CASE_0_L1 and never exercised level 2. Changed to 2 to match the
# constant's name; the parallel TEST_CASE_OPENSLIDE_0_L2 uses level 2 with the same
# location and the same expected output (239 per channel).
TEST_CASE_0_L2 = [
    {
        "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}],
        "region_size": (1, 1),
        "grid_shape": (1, 1),
        "patch_size": 1,
        "level": 2,  # was 1 — inconsistent with the "_L2" name
        "image_reader_name": "cuCIM",
    },
    [{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}],
]


TEST_CASE_1 = [
{
"data": [{"image": FILE_PATH, "location": [10004, 20004], "label": [0, 0, 0, 1]}],
Expand All @@ -57,6 +82,41 @@
],
]


# cuCIM reader: a 2x2 grid of 1x1 patches from an 8x8 region at level 0.
# Each entry pairs a per-channel pixel triple with its grid-cell label.
TEST_CASE_1_L0 = [
    dict(
        data=[{"image": FILE_PATH, "location": [10004, 20004], "label": [0, 0, 0, 1]}],
        region_size=(8, 8),
        grid_shape=(2, 2),
        patch_size=1,
        level=0,
        image_reader_name="cuCIM",
    ),
    [
        {"image": np.array(rgb, dtype=np.uint8).reshape(3, 1, 1), "label": np.array([[[lbl]]])}
        for rgb, lbl in [
            ((247, 245, 248), 0),
            ((245, 247, 244), 0),
            ((246, 246, 246), 0),
            ((246, 246, 246), 1),
        ]
    ],
]


# cuCIM reader: a 2x2 grid of 1x1 patches from an 8x8 region at level 1.
# Each entry pairs a per-channel pixel triple with its grid-cell label.
TEST_CASE_1_L1 = [
    dict(
        data=[{"image": FILE_PATH, "location": [10004, 20004], "label": [0, 0, 0, 1]}],
        region_size=(8, 8),
        grid_shape=(2, 2),
        patch_size=1,
        level=1,
        image_reader_name="cuCIM",
    ),
    [
        {"image": np.array(rgb, dtype=np.uint8).reshape(3, 1, 1), "label": np.array([[[lbl]]])}
        for rgb, lbl in [
            ((248, 246, 249), 0),
            ((196, 187, 192), 0),
            ((245, 243, 244), 0),
            ((246, 242, 243), 1),
        ]
    ],
]
TEST_CASE_2 = [
{
"data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}],
Expand Down Expand Up @@ -90,6 +150,43 @@
[{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}],
]

# OpenSlide reader: single 1x1 patch at location (0, 0), read at pyramid level 0.
# Expected patch has value 239 in each of the 3 channels, label 1.
TEST_CASE_OPENSLIDE_0_L0 = [
    dict(
        data=[{"image": FILE_PATH, "location": [0, 0], "label": [1]}],
        region_size=(1, 1),
        grid_shape=(1, 1),
        patch_size=1,
        level=0,
        image_reader_name="OpenSlide",
    ),
    [{"image": np.full((3, 1, 1), 239, dtype=np.uint8), "label": np.array([[[1]]])}],
]

# OpenSlide reader: single 1x1 patch at location (0, 0), read at pyramid level 1.
# Expected patch has value 239 in each of the 3 channels, label 1.
TEST_CASE_OPENSLIDE_0_L1 = [
    dict(
        data=[{"image": FILE_PATH, "location": [0, 0], "label": [1]}],
        region_size=(1, 1),
        grid_shape=(1, 1),
        patch_size=1,
        level=1,
        image_reader_name="OpenSlide",
    ),
    [{"image": np.full((3, 1, 1), 239, dtype=np.uint8), "label": np.array([[[1]]])}],
]


# OpenSlide reader: single 1x1 patch at location (0, 0), read at pyramid level 2.
# Expected patch has value 239 in each of the 3 channels, label 1.
TEST_CASE_OPENSLIDE_0_L2 = [
    dict(
        data=[{"image": FILE_PATH, "location": [0, 0], "label": [1]}],
        region_size=(1, 1),
        grid_shape=(1, 1),
        patch_size=1,
        level=2,
        image_reader_name="OpenSlide",
    ),
    [{"image": np.full((3, 1, 1), 239, dtype=np.uint8), "label": np.array([[[1]]])}],
]

TEST_CASE_OPENSLIDE_1 = [
{
"data": [{"image": FILE_PATH, "location": [10004, 20004], "label": [0, 0, 0, 1]}],
Expand All @@ -113,7 +210,18 @@ def setUp(self):
hash_val = testing_data_config("images", FILE_KEY, "hash_val")
download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val)

@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
@parameterized.expand(
[
TEST_CASE_0,
TEST_CASE_0_L1,
TEST_CASE_0_L2,
TEST_CASE_1,
TEST_CASE_1_L0,
TEST_CASE_1_L1,
TEST_CASE_2,
TEST_CASE_3,
]
)
@skipUnless(has_cim, "Requires CuCIM")
def test_read_patches_cucim(self, input_parameters, expected):
dataset = PatchWSIDataset(**input_parameters)
Expand All @@ -124,7 +232,15 @@ def test_read_patches_cucim(self, input_parameters, expected):
self.assertIsNone(assert_array_equal(samples[i]["label"], expected[i]["label"]))
self.assertIsNone(assert_array_equal(samples[i]["image"], expected[i]["image"]))

@parameterized.expand([TEST_CASE_OPENSLIDE_0, TEST_CASE_OPENSLIDE_1])
@parameterized.expand(
[
TEST_CASE_OPENSLIDE_0,
TEST_CASE_OPENSLIDE_0_L0,
TEST_CASE_OPENSLIDE_0_L1,
TEST_CASE_OPENSLIDE_0_L2,
TEST_CASE_OPENSLIDE_1,
]
)
@skipUnless(has_osl, "Requires OpenSlide")
def test_read_patches_openslide(self, input_parameters, expected):
dataset = PatchWSIDataset(**input_parameters)
Expand Down

0 comments on commit de97391

Please sign in to comment.