diff --git a/tests/conftest.py b/tests/conftest.py index 2b7de0fd6..cdc53dee0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -531,6 +531,7 @@ def sample_wsi_dict(remote_sample: Callable) -> dict: "wsi4_4k_4k_svs", "wsi3_20k_20k_pred", "wsi4_4k_4k_pred", + "wsi4_1k_1k_svs", ] return {name: remote_sample(name) for name in file_names} diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 0cbff82c1..44e6e1953 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -521,7 +521,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) eng = TestEngineABC(model=model) kwargs = { - "patch_input_shape": [512, 512], + "patch_input_shape": [224, 224], "input_resolutions": [{"units": "mpp", "resolution": 1.75}], } with caplog.at_level(logging.WARNING): @@ -537,7 +537,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) # test providing config / full input info for non pretrained models ioconfig = ModelIOConfigABC( - patch_input_shape=(512, 512), + patch_input_shape=(224, 224), stride_shape=(256, 256), input_resolutions=[{"resolution": 1.35, "units": "mpp"}], ) @@ -547,7 +547,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) save_dir=f"{tmp_path}/dump", ioconfig=ioconfig, ) - assert eng._ioconfig.patch_input_shape == (512, 512) + assert eng._ioconfig.patch_input_shape == (224, 224) assert eng._ioconfig.stride_shape == (256, 256) assert eng._ioconfig.input_resolutions == [{"resolution": 1.35, "units": "mpp"}] shutil.rmtree(tmp_path / "dump", ignore_errors=True) @@ -558,15 +558,15 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) save_dir=f"{tmp_path}/dump", **kwargs, ) - assert eng._ioconfig.patch_input_shape == [512, 512] - assert eng._ioconfig.stride_shape == [512, 512] + assert eng._ioconfig.patch_input_shape == [224, 224] + assert eng._ioconfig.stride_shape == [224, 224] assert eng._ioconfig.input_resolutions == [{"resolution": 1.75, "units": "mpp"}] shutil.rmtree(tmp_path / "dump", ignore_errors=True) # test overwriting pretrained ioconfig eng = TestEngineABC(model="alexnet-kather100k") eng.run( - images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + images=np.zeros((10, 300, 300, 3), dtype=np.uint8), patch_input_shape=(300, 300), stride_shape=(300, 300), input_resolutions=[{"units": "baseline", "resolution": 1.99}], @@ -580,7 +580,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) shutil.rmtree(tmp_path / "dump", ignore_errors=True) eng.run( - images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + images=np.zeros((10, 300, 300, 3), dtype=np.uint8), patch_input_shape=(300, 300), stride_shape=(300, 300), input_resolutions=None, diff --git a/tests/engines/test_semantic_segmentor.py b/tests/engines/test_semantic_segmentor.py new file mode 100644 index 000000000..3a2eae82f --- /dev/null +++ b/tests/engines/test_semantic_segmentor.py @@ -0,0 +1,256 @@ +"""Test SemanticSegmentor.""" + +from __future__ import annotations + +import json +import sqlite3 +from pathlib import Path +from typing import Callable + +import numpy as np +import torch +import zarr + +from tiatoolbox.annotation import SQLiteStore +from tiatoolbox.models.engine.semantic_segmentor_new import SemanticSegmentor +from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils.misc import imread + +device = "cuda" if toolbox_env.has_gpu() else "cpu" + + +def test_semantic_segmentor_init() -> 
None: + """Tests SemanticSegmentor initialization.""" + segmentor = SemanticSegmentor(model="fcn-tissue_mask", device=device) + + assert isinstance(segmentor, SemanticSegmentor) + assert isinstance(segmentor.model, torch.nn.Module) + + +def test_semantic_segmentor_patches(remote_sample: Callable, tmp_path: Path) -> None: + """Tests SemanticSegmentor on image patches.""" + segmentor = SemanticSegmentor( + model="fcn-tissue_mask", batch_size=32, verbose=False, device=device + ) + + sample_image = remote_sample("thumbnail-1k-1k") + + inputs = [sample_image, sample_image] + + assert segmentor.cache_mode is False + + output = segmentor.run( + images=inputs, + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=True, + ) + + assert 0.62 < np.mean(output["predictions"][:]) < 0.66 + assert 0.48 < np.mean(output["probabilities"][:]) < 0.52 + + assert ( + tuple(segmentor._ioconfig.patch_output_shape) + == output["probabilities"][0].shape[:-1] + ) + + assert ( + tuple(segmentor._ioconfig.patch_output_shape) == output["predictions"][0].shape + ) + + output = segmentor.run( + images=inputs, + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=True, + cache_mode=True, + save_dir=tmp_path / "output0", + ) + + assert output == tmp_path / "output0" / "output.zarr" + + output = zarr.open(output, mode="r") + assert 0.62 < np.mean(output["predictions"][:]) < 0.66 + assert 0.48 < np.mean(output["probabilities"][:]) < 0.52 + + output = segmentor.run( + images=inputs, + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=True, + cache_mode=True, + output_type="zarr", + save_dir=tmp_path / "output1", + ) + + assert output == tmp_path / "output1" / "output.zarr" + + output = zarr.open(output, mode="r") + assert 0.62 < np.mean(output["predictions"][:]) < 0.66 + assert "probabilities" not in output.keys() # noqa: SIM118 + + output = segmentor.run( + images=inputs, + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=True, + cache_mode=False, + save_dir=tmp_path / "output2", + output_type="zarr", + ) + + assert output == tmp_path / "output2" / "output.zarr" + + output = zarr.open(output, mode="r") + assert 0.62 < np.mean(output["predictions"][:]) < 0.66 + assert "probabilities" not in output + assert "predictions" in output + + +def _test_store_output_patch(output: Path) -> None: + """Helper method to test annotation store output for a patch.""" + store_ = SQLiteStore.open(output) + annotations_ = store_.values() + annotations_geometry_type = [ + str(annotation_.geometry_type) for annotation_ in annotations_ + ] + assert "Polygon" in annotations_geometry_type + + con = sqlite3.connect(output) + cur = con.cursor() + annotations_properties = list(cur.execute("SELECT properties FROM annotations")) + + out = [] + + for item in annotations_properties: + for json_str in item: + probs = json.loads(json_str) + if "type" in probs: + out.append(probs.pop("type")) + + assert "mask" in out + + assert annotations_properties is not None + + +def test_save_annotation_store(remote_sample: Callable, tmp_path: Path) -> None: + """Test for saving output as annotation store.""" + segmentor = SemanticSegmentor( + model="fcn-tissue_mask", batch_size=32, verbose=False, device=device + ) + + sample_image = remote_sample("thumbnail-1k-1k") + + inputs = [sample_image] + + output = segmentor.run( + images=inputs, + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=True, + cache_mode=False, + 
save_dir=tmp_path / "output1", + output_type="annotationstore", + ) + + assert output[0] == tmp_path / "output1" / (sample_image.stem + ".db") + assert len(output) == 1 + _test_store_output_patch(output[0]) + + +def test_save_annotation_store_nparray(remote_sample: Callable, tmp_path: Path) -> None: + """Test for saving output as annotation store using a numpy array.""" + segmentor = SemanticSegmentor( + model="fcn-tissue_mask", batch_size=32, verbose=False, device=device + ) + + sample_image = remote_sample("thumbnail-1k-1k") + + input_image = imread(sample_image) + inputs_list = [input_image, input_image] + + output = segmentor.run( + images=inputs_list, + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=True, + cache_mode=True, + save_dir=tmp_path / "output1", + output_type="annotationstore", + ) + + assert output[0] == tmp_path / "output1" / "0.db" + assert output[1] == tmp_path / "output1" / "1.db" + + assert output[2] == tmp_path / "output1.zarr" + + _test_store_output_patch(output[0]) + _test_store_output_patch(output[1]) + + output = segmentor.run( + images=inputs_list, + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=True, + cache_mode=True, + save_dir=tmp_path / "output2", + output_type="annotationstore", + ) + + assert output[0] == tmp_path / "output2" / "0.db" + assert output[1] == tmp_path / "output2" / "1.db" + + assert len(output) == 2 + + _test_store_output_patch(output[0]) + _test_store_output_patch(output[1]) + + +def test_wsi_segmentor_zarr( + remote_sample: Callable, sample_svs: Path, tmp_path: Path +) -> None: + """Test SemanticSegmentor for WSIs with zarr output.""" + wsi_with_artifacts = Path(remote_sample("wsi3_20k_20k_svs")) + + segmentor = SemanticSegmentor( + model="fcn-tissue_mask", + batch_size=32, + verbose=False, + ) + # Return Probabilities is False + output = segmentor.run( + images=[wsi_with_artifacts], + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=False, + save_dir=tmp_path / "wsi_out_check", + output_type="zarr", + ) + + output_ = zarr.open(output[wsi_with_artifacts], mode="r") + assert 0.48 < np.mean(output_["predictions"][:]) < 0.49 + assert "probabilities" not in output_ + + # Return Probabilities is True + # Using small image for faster run + output = segmentor.run( + images=[sample_svs], + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=False, + save_dir=tmp_path / "wsi_out_check_prob", + output_type="zarr", + ) + + output_ = zarr.open(output[sample_svs], mode="r") + assert 0.27 < np.mean(output_["predictions"][:]) < 0.28 + assert 0.49 < np.mean(output_["probabilities"][:]) < 0.51 diff --git a/tests/models/test_arch_unet.py b/tests/models/test_arch_unet.py index 2ac231c7c..263a6b2c5 100644 --- a/tests/models/test_arch_unet.py +++ b/tests/models/test_arch_unet.py @@ -50,7 +50,7 @@ def test_functional_unet(remote_sample: Callable) -> None: pretrained = torch.load(pretrained_weights, map_location="cpu") model.load_state_dict(pretrained) output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) - _ = output[0] + _ = output["probabilities"][0] # run untrained network to test for architecture model = UNetModel( diff --git a/tests/models/test_dataset.py b/tests/models/test_dataset.py index ab9a6033f..22a7c6108 100644 --- a/tests/models/test_dataset.py +++ b/tests/models/test_dataset.py @@ -20,6 +20,7 @@ ) from tiatoolbox.utils import download_data, imread, imwrite, unzip_data from tiatoolbox.utils import 
env_detection as toolbox_env +from tiatoolbox.utils.exceptions import DimensionMismatchError from tiatoolbox.wsicore import WSIReader RNG = np.random.default_rng() # Numpy Random Generator @@ -120,7 +121,9 @@ def test_kather_dataset(tmp_path: Path) -> None: assert len(dataset.inputs) == len(dataset.labels) # to actually get the image, we feed it to PatchDataset - actual_ds = PatchDataset(dataset.inputs, dataset.labels) + actual_ds = PatchDataset( + dataset.inputs, dataset.labels, patch_input_shape=(224, 224) + ) sample_patch = actual_ds[89] assert isinstance(sample_patch["image"], np.ndarray) assert sample_patch["label"] is not None @@ -129,6 +132,18 @@ def test_kather_dataset(tmp_path: Path) -> None: shutil.rmtree(save_dir_path, ignore_errors=True) +def test_incorrect_input_shape() -> None: + """Incorrect input patch dimensions should raise DimensionMismatchError.""" + size = (5, 5, 3) + img = RNG.integers(low=0, high=255, size=size) + list_imgs = [img, img, img] + dataset = PatchDataset(list_imgs, patch_input_shape=(100, 100)) + with pytest.raises( + DimensionMismatchError, match=r".*\(100, 100\), but got \(5, 5\).*" + ): + _ = dataset[0] + + def test_patch_dataset_path_imgs( sample_patch1: str | Path, sample_patch2: str | Path, @@ -136,7 +151,9 @@ def test_patch_dataset_path_imgs( """Test for patch dataset with a list of file paths as input.""" size = (224, 224, 3) - dataset = PatchDataset([Path(sample_patch1), Path(sample_patch2)]) + dataset = PatchDataset( + [Path(sample_patch1), Path(sample_patch2)], patch_input_shape=size[:-1] + ) for _, sample_data in enumerate(dataset): sampled_img_shape = sample_data["image"].shape @@ -152,7 +169,7 @@ def test_patch_dataset_list_imgs(tmp_path: Path) -> None: size = (5, 5, 3) img = RNG.integers(low=0, high=255, size=size) list_imgs = [img, img, img] - dataset = PatchDataset(list_imgs) + dataset = PatchDataset(list_imgs, patch_input_shape=size[:-1]) dataset.preproc_func = lambda x: x @@ -197,14 +214,14 @@ def test_patch_datasetarray_imgs() -> None: array_imgs = np.array(list_imgs) # test different setter for label - dataset = PatchDataset(array_imgs, labels=labels) + dataset = PatchDataset(array_imgs, labels=labels, patch_input_shape=(5, 5)) an_item = dataset[2] assert an_item["label"] == 3 - dataset = PatchDataset(array_imgs, labels=None) + dataset = PatchDataset(array_imgs, labels=None, patch_input_shape=(5, 5)) an_item = dataset[2] assert "label" not in an_item - dataset = PatchDataset(array_imgs) + dataset = PatchDataset(array_imgs, patch_input_shape=size[:-1]) for _, sample_data in enumerate(dataset): sampled_img_shape = sample_data["image"].shape assert sampled_img_shape[0] == size[0] @@ -329,7 +346,6 @@ def test_wsi_patch_dataset( # noqa: PLR0915 """A test for creation and bare output.""" # convert to pathlib Path to prevent wsireader complaint mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) - mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) def reuse_init(img_path: Path = mini_wsi_svs, **kwargs: dict) -> WSIPatchDataset: @@ -338,7 +354,7 @@ def reuse_init(img_path: Path = mini_wsi_svs, **kwargs: dict) -> WSIPatchDataset def reuse_init_wsi(**kwargs: dict) -> WSIPatchDataset: """Testing function.""" - return reuse_init(mode="wsi", **kwargs) + return reuse_init(**kwargs) # test for ABC validate # intentionally created to check error @@ -363,7 +379,6 @@ def __getitem__(self: Proto, idx: int) -> object: with pytest.raises(ValueError, match=r".*`img_path` must be a valid file 
path.*"): WSIPatchDataset( img_path="aaaa", - mode="wsi", patch_input_shape=[512, 512], stride_shape=[256, 256], auto_get_mask=False, @@ -374,7 +389,6 @@ def __getitem__(self: Proto, idx: int) -> object: WSIPatchDataset( img_path=mini_wsi_svs, mask_path="aaaa", - mode="wsi", patch_input_shape=[512, 512], stride_shape=[256, 256], resolution=1.0, @@ -382,10 +396,6 @@ def __getitem__(self: Proto, idx: int) -> object: auto_get_mask=False, ) - # invalid mode - with pytest.raises(ValueError, match="`X` is not supported."): - reuse_init(mode="X") - # invalid patch with pytest.raises(ValueError, match="Invalid `patch_input_shape` value None."): reuse_init() @@ -468,9 +478,8 @@ def __getitem__(self: Proto, idx: int) -> object: ds = WSIPatchDataset( img_path=mini_wsi_svs, mask_path=mini_wsi_msk, - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], + patch_input_shape=(512, 512), + stride_shape=(256, 256), auto_get_mask=False, resolution=1.0, units="mpp", @@ -483,41 +492,13 @@ def __getitem__(self: Proto, idx: int) -> object: ds = WSIPatchDataset( img_path=mini_wsi_svs, mask_path=negative_mask_path, - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], + patch_input_shape=(512, 512), + stride_shape=(256, 256), auto_get_mask=False, resolution=1.0, units="mpp", ) - # * for tile - reader = WSIReader.open(mini_wsi_jpg) - tile_ds = WSIPatchDataset( - img_path=mini_wsi_jpg, - mode="tile", - patch_input_shape=patch_size, - stride_shape=stride_size, - auto_get_mask=False, - ) - step_idx = 3 # manually calibrate - start = (step_idx * stride_size[1], 0) - end = (start[0] + patch_size[0], start[1] + patch_size[1]) - roi2 = reader.read_bounds( - start + end, - resolution=1.0, - units="baseline", - coord_space="resolution", - ) - roi1 = tile_ds[3]["image"] # match with step_index - correlation = np.corrcoef( - cv2.cvtColor(roi1, cv2.COLOR_RGB2GRAY).flatten(), - cv2.cvtColor(roi2, cv2.COLOR_RGB2GRAY).flatten(), - ) - assert roi1.shape[0] == roi2.shape[0] - assert roi1.shape[1] == roi2.shape[1] - assert np.min(correlation) > 0.9, correlation - def test_patch_dataset_abc() -> None: """Test for ABC methods. 
diff --git a/tests/test_utils.py b/tests/test_utils.py index 77b5b6646..b14787e4d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1672,7 +1672,7 @@ def test_patch_pred_store() -> None: "other": "other", } - store = misc.dict_to_store(patch_output, (1.0, 1.0)) + store = misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0)) # Check that it is an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) @@ -1685,7 +1685,7 @@ def test_patch_pred_store() -> None: patch_output.pop("coordinates") # check correct error is raised if coordinates are missing with pytest.raises(ValueError, match="coordinates"): - misc.dict_to_store(patch_output, (1.0, 1.0)) + misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0)) patch_output = { "predictions": [1, 0, 1], @@ -1693,7 +1693,7 @@ def test_patch_pred_store() -> None: "other": "other", } - store = misc.dict_to_store(patch_output, (1.0, 1.0)) + store = misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0)) # Check that it is an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) @@ -1710,7 +1710,9 @@ def test_patch_pred_store_cdict() -> None: "other": "other", } class_dict = {0: "class0", 1: "class1"} - store = misc.dict_to_store(patch_output, (1.0, 1.0), class_dict=class_dict) + store = misc.dict_to_store_patch_predictions( + patch_output, (1.0, 1.0), class_dict=class_dict + ) # Check that it is an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) @@ -1731,7 +1733,7 @@ def test_patch_pred_store_sf() -> None: "probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]], "labels": [1, 0, 1], } - store = misc.dict_to_store(patch_output, (2.0, 2.0)) + store = misc.dict_to_store_patch_predictions(patch_output, (2.0, 2.0)) # Check that its an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) @@ -1788,7 +1790,9 @@ def test_patch_pred_store_persist(tmp_path: pytest.TempPathFactory) -> None: } save_path = tmp_path / "patch_output" / "output.db" - store_path = misc.dict_to_store(patch_output, (1.0, 1.0), save_path=save_path) + store_path = misc.dict_to_store_patch_predictions( + patch_output, (1.0, 1.0), save_path=save_path + ) print("Annotation store path: ", store_path) assert Path.exists(store_path), "Annotation Store output file does not exist" @@ -1806,7 +1810,7 @@ def test_patch_pred_store_persist(tmp_path: pytest.TempPathFactory) -> None: patch_output.pop("coordinates") # check correct error is raised if coordinates are missing with pytest.raises(ValueError, match="coordinates"): - misc.dict_to_store(patch_output, (1.0, 1.0)) + misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0)) def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None: @@ -1822,7 +1826,9 @@ def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None: # sends the path of a jpeg source image, expects .db file in the same directory save_path = tmp_path / "patch_output" / "output.jpeg" - store_path = misc.dict_to_store(patch_output, (1.0, 1.0), save_path=save_path) + store_path = misc.dict_to_store_patch_predictions( + patch_output, (1.0, 1.0), save_path=save_path + ) print("Annotation store path: ", store_path) assert Path.exists(store_path), "Annotation Store output file does not exist" @@ -1840,7 +1846,7 @@ def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None: patch_output.pop("coordinates") # check correct error is raised if coordinates are missing 
with pytest.raises(ValueError, match="coordinates"): - misc.dict_to_store(patch_output, (1.0, 1.0)) + misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0)) def test_torch_compile_already_compiled() -> None: diff --git a/tiatoolbox/data/remote_samples.yaml b/tiatoolbox/data/remote_samples.yaml index 1b7bf2bf1..44e7d3492 100644 --- a/tiatoolbox/data/remote_samples.yaml +++ b/tiatoolbox/data/remote_samples.yaml @@ -21,6 +21,8 @@ files: extract: True svs-1-small: url: [*wsis, "CMU-1-Small-Region.svs"] + thumbnail-1k-1k: + url: [*wsis, "CMU-2_1k_1k-thumbnail.png"] tiled-tiff-1-small-jpeg: url: [*wsis, "CMU-1-Small-Region.jpeg.tiff"] tiled-tiff-1-small-jp2k: diff --git a/tiatoolbox/models/architecture/unet.py b/tiatoolbox/models/architecture/unet.py index 6385e7587..0f44b5d84 100644 --- a/tiatoolbox/models/architecture/unet.py +++ b/tiatoolbox/models/architecture/unet.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any import torch import torch.nn.functional as F # noqa: N812 @@ -10,9 +10,16 @@ from torchvision.models.resnet import Bottleneck as ResNetBottleneck from torchvision.models.resnet import ResNet -from tiatoolbox.models.architecture.utils import UpSample2x, centre_crop +from tiatoolbox.models.architecture.utils import ( + UpSample2x, + argmax_last_axis, + centre_crop, +) from tiatoolbox.models.models_abc import ModelABC +if TYPE_CHECKING: # pragma: no cover + import numpy as np + class ResNetEncoder(ResNet): """A subclass of ResNet defined in torch. @@ -416,7 +423,7 @@ def infer_batch( batch_data: torch.Tensor, *, device: str, - ) -> list: + ) -> dict[str, np.ndarray]: """Run inference on an input batch. This contains logic for forward operation as well as i/o @@ -432,9 +439,8 @@ def infer_batch( Transfers model to the specified device. Default is "cpu". Returns: - list: - List of network output head, each output is an - :class:`numpy.ndarray`. + dict: + A dict with "probabilities" key and a :class:`numpy.ndarray` as output. """ model.eval() @@ -457,7 +463,14 @@ def infer_batch( align_corners=False, ) probs = centre_crop(probs, crop_shape) - probs = probs.permute(0, 2, 3, 1) # to NHWC + output = probs.permute(0, 2, 3, 1) # to NHWC + + return {"probabilities": output.cpu().numpy()} - probs = probs.cpu().numpy() - return [probs] + def postproc(self: UNetModel, image: np.ndarray) -> np.ndarray: + """Define post-processing of this class of model. + + This simply applies argmax along last axis of the input. + + """ + return argmax_last_axis(image=image) diff --git a/tiatoolbox/models/architecture/utils.py b/tiatoolbox/models/architecture/utils.py index e9560e59e..63cd055c1 100644 --- a/tiatoolbox/models/architecture/utils.py +++ b/tiatoolbox/models/architecture/utils.py @@ -234,3 +234,20 @@ def forward(self: UpSample2x, x: torch.Tensor) -> torch.Tensor: ret = torch.tensordot(x, mat, dims=1) # bxcxhxwxshxsw ret = ret.permute(0, 1, 2, 4, 3, 5) return ret.reshape((-1, input_shape[1], input_shape[2] * 2, input_shape[3] * 2)) + + +def argmax_last_axis(image: np.ndarray) -> np.ndarray: + """Define the post-processing of this class of model. + + This simply applies argmax along last axis of the input. + + Args: + image (np.ndarray): + The input image array. + + Returns: + np.ndarray: + The post-processed image array. 
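+
+    Example:
+        >>> import numpy as np
+        >>> probs = np.zeros((2, 2, 3))
+        >>> probs[..., 1] = 1.0
+        >>> argmax_last_axis(image=probs)
+        array([[1, 1],
+               [1, 1]])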
+ + """ + return np.argmax(image, axis=-1) diff --git a/tiatoolbox/models/architecture/vanilla.py b/tiatoolbox/models/architecture/vanilla.py index b471d026c..efc642a17 100644 --- a/tiatoolbox/models/architecture/vanilla.py +++ b/tiatoolbox/models/architecture/vanilla.py @@ -4,16 +4,17 @@ from typing import TYPE_CHECKING -import numpy as np import timm import torch import torchvision.models as torch_models from timm.layers import SwiGLUPacked from torch import nn +from tiatoolbox.models.architecture.utils import argmax_last_axis from tiatoolbox.models.models_abc import ModelABC if TYPE_CHECKING: # pragma: no cover + import numpy as np from torchvision.models import WeightsEnum @@ -205,23 +206,6 @@ def _get_timm_architecture( raise ValueError(msg) -def _postproc(image: np.ndarray) -> np.ndarray: - """Define the post-processing of this class of model. - - This simply applies argmax along last axis of the input. - - Args: - image (np.ndarray): - The input image array. - - Returns: - np.ndarray: - The post-processed image array. - - """ - return np.argmax(image, axis=-1) - - def _infer_batch( model: nn.Module, batch_data: torch.Tensor, @@ -339,7 +323,7 @@ def postproc(image: np.ndarray) -> np.ndarray: The post-processed image array. """ - return _postproc(image=image) + return argmax_last_axis(image=image) @staticmethod def infer_batch( @@ -463,7 +447,7 @@ def postproc(image: np.ndarray) -> np.ndarray: The post-processed image array. """ - return _postproc(image=image) + return argmax_last_axis(image=image) @staticmethod def infer_batch( diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index 045bb39b7..644271526 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -15,6 +15,7 @@ from tiatoolbox import logger from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils import imread +from tiatoolbox.utils.exceptions import DimensionMismatchError from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIMeta, WSIReader if TYPE_CHECKING: # pragma: no cover @@ -22,7 +23,7 @@ from multiprocessing.managers import Namespace from tiatoolbox.models.engine.io_config import IOSegmentorConfig - from tiatoolbox.typing import IntPair, Resolution, Units + from tiatoolbox.type_hints import IntPair, Resolution, Units try: from typing import TypeGuard @@ -361,12 +362,12 @@ class WSIPatchDataset(PatchDatasetABC): """ - def __init__( # skipcq: PY-R1000 # noqa: PLR0915 + def __init__( # skipcq: PY-R1000 self: WSIPatchDataset, img_path: str | Path, - mode: str = "wsi", mask_path: str | Path | None = None, patch_input_shape: IntPair = None, + patch_output_shape: IntPair = None, stride_shape: IntPair = None, resolution: Resolution = None, units: Units = None, @@ -378,10 +379,6 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 """Create a WSI-level patch dataset. Args: - mode (str): - Can be either `wsi` or `tile` to denote the image to - read is either a whole-slide image or a large image - tile. img_path (str or Path): Valid to pyramidal whole-slide image or large tile to read. @@ -393,6 +390,12 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 and `units`. Expected to be positive and of (height, width). Note, this is not at `resolution` coordinate space. + patch_output_shape: + A tuple (int, int) or ndarray of shape (2,). Expected + output shape from the model at requested `resolution` + and `units`. Expected to be positive and of (height, + width). 
Note, this is not at `resolution` coordinate + space. stride_shape: A tuple (int, int) or ndarray of shape (2,). Expected stride shape to read at requested `resolution` and @@ -422,7 +425,6 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 >>> # preprocessing function >>> ds = WSIPatchDataset( ... img_path='/A/B/C/wsi.svs', - ... mode="wsi", ... patch_input_shape=[512, 512], ... stride_shape=[256, 256], ... auto_get_mask=False, @@ -436,9 +438,6 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 if not Path.is_file(Path(img_path)): msg = "`img_path` must be a valid file path." raise ValueError(msg) - if mode not in ["wsi", "tile"]: - msg = f"`{mode}` is not supported." - raise ValueError(msg) patch_input_shape = np.array(patch_input_shape) stride_shape = np.array(stride_shape) @@ -459,47 +458,26 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 self.preproc_func = preproc_func img_path = Path(img_path) - if mode == "wsi": - self.reader = WSIReader.open(img_path) - else: - logger.warning( - "WSIPatchDataset only reads image tile at " - '`units="baseline"` and `resolution=1.0`.', - stacklevel=2, - ) - img = imread(img_path) - axes = "YXS"[: len(img.shape)] - # initialise metadata for VirtualWSIReader. - # here, we simulate a whole-slide image, but with a single level. - # ! should we expose this so that use can provide their metadata ? - metadata = WSIMeta( - mpp=np.array([1.0, 1.0]), - axes=axes, - objective_power=10, - slide_dimensions=np.array(img.shape[:2][::-1]), - level_downsamples=[1.0], - level_dimensions=[np.array(img.shape[:2][::-1])], - ) - # infer value such that read if mask provided is through - # 'mpp' or 'power' as varying 'baseline' is locked atm - units = "mpp" - resolution = 1.0 - self.reader = VirtualWSIReader( - img, - info=metadata, - ) + self.reader = WSIReader.open(img_path) # may decouple into misc ? # the scaling factor will scale base level to requested read resolution/units wsi_shape = self.reader.slide_dimensions(resolution=resolution, units=units) # use all patches, as long as it overlaps source image - self.inputs = PatchExtractor.get_coordinates( - image_shape=wsi_shape, - patch_input_shape=patch_input_shape[::-1], - stride_shape=stride_shape[::-1], - input_within_bound=False, - ) + if patch_output_shape is not None: + self.inputs, self.outputs = PatchExtractor.get_coordinates( + image_shape=wsi_shape, + patch_input_shape=patch_input_shape[::-1], + stride_shape=stride_shape[::-1], + patch_output_shape=patch_output_shape, + ) + else: + self.inputs = PatchExtractor.get_coordinates( + image_shape=wsi_shape, + patch_input_shape=patch_input_shape[::-1], + stride_shape=stride_shape[::-1], + ) mask_reader = None if mask_path is not None: @@ -513,7 +491,7 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 mask_reader = VirtualWSIReader(mask) mask_reader.info = self.reader.info - elif auto_get_mask and mode == "wsi" and mask_path is None: + elif auto_get_mask and mask_path is None: # if no mask provided and `wsi` mode, generate basic tissue # mask on the fly mask_reader = self.reader.tissue_mask(resolution=1.25, units="power") @@ -528,6 +506,8 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 min_mask_ratio=min_mask_ratio, ) self.inputs = self.inputs[selected] + if hasattr(self, "outputs"): + self.outputs = self.outputs[selected] if len(self.inputs) == 0: msg = "No patch coordinates remain after filtering." @@ -572,6 +552,10 @@ class PatchDataset(PatchDatasetABC): labels (list): List of labels for sample at the same index in `inputs`. Default is `None`. 
+ patch_input_shape (tuple): + Size of patches input to the model. Patches are at + requested read resolution, not with respect to level 0, + and must be positive. Examples: >>> # A user defined preproc func and expected behavior @@ -581,6 +565,7 @@ class PatchDataset(PatchDatasetABC): >>> ds = PatchDataset( ... inputs=['/A/B/C/img1.png', '/A/B/C/img2.png'], ... labels=["labels1", "labels2"], + ... patch_input_shape=(224, 224), ... ) """ @@ -589,6 +574,7 @@ def __init__( self: PatchDataset, inputs: np.ndarray | list, labels: list | None = None, + patch_input_shape: IntPair | None = None, ) -> None: """Initialize :class:`PatchDataset`.""" super().__init__() @@ -597,6 +583,7 @@ def __init__( self.inputs = inputs self.labels = labels + self.patch_input_shape = patch_input_shape # perform check on the input self._check_input_integrity(mode="patch") @@ -609,6 +596,18 @@ def __getitem__(self: PatchDataset, idx: int) -> dict: if not self.data_is_npy_alike: patch = self.load_img(patch) + if patch.shape[:-1] != tuple(self.patch_input_shape): + msg = ( + f"Patch size is not compatible with the model. " + f"Expected dimensions {tuple(self.patch_input_shape)}, but got " + f"{patch.shape[:-1]}." + ) + logger.error(msg=msg) + raise DimensionMismatchError( + expected_dims=tuple(self.patch_input_shape), + actual_dims=patch.shape[:-1], + ) + # Apply preprocessing to selected patch patch = self._preproc(patch) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 6f6f9afab..734bea042 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -21,7 +21,7 @@ from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset from tiatoolbox.models.models_abc import load_torch_model from tiatoolbox.utils.misc import ( - dict_to_store, + dict_to_store_patch_predictions, dict_to_zarr, write_to_zarr_in_cache_mode, ) @@ -233,6 +233,8 @@ class EngineABC(ABC): # noqa: B024 be downloaded. However, you can override with your own set of weights via the `weights` argument. Argument is case-insensitive. + dataloader (torch.utils.data.DataLoader): + :class:`torch.utils.data.DataLoader` used during inference. ioconfig (ModelIOConfigABC): Input IO configuration of type :class:`ModelIOConfigABC` to run the Engine. _ioconfig (ModelIOConfigABC): @@ -364,7 +366,8 @@ def __init__( self.input_resolutions: list[dict[Units, Resolution]] | None = None self.return_labels: bool = False self.stride_shape: IntPair | None = None - self.verbose = verbose + self.verbose: bool = verbose + self.dataloader: DataLoader | None = None @staticmethod def _initialize_model_ioconfig( @@ -460,7 +463,6 @@ def get_dataloader( if not patch_mode: dataset = WSIPatchDataset( img_path=images, - mode="wsi", mask_path=masks, patch_input_shape=ioconfig.patch_input_shape, stride_shape=ioconfig.stride_shape, @@ -479,7 +481,10 @@ def get_dataloader( shuffle=False, ) - dataset = PatchDataset(inputs=images, labels=labels) + dataset = PatchDataset( + inputs=images, labels=labels, patch_input_shape=ioconfig.patch_input_shape + ) + dataset.preproc_func = self.model.preproc_func # preprocessing must be defined with the dataset @@ -630,7 +635,7 @@ def save_predictions( processed_predictions: dict | Path, output_type: str, save_dir: Path | None = None, - **kwargs: dict, + **kwargs: EngineABCRunParams, ) -> dict | AnnotationStore | Path: """Save model predictions. 
@@ -676,7 +681,7 @@ def save_predictions( processed_predictions_path = processed_predictions processed_predictions = zarr.open(processed_predictions, mode="r") - out_file = dict_to_store( + out_file = dict_to_store_patch_predictions( processed_predictions, scale_factor, class_dict, @@ -966,14 +971,15 @@ def _run_patch_mode( duplicate_filter = DuplicateFilter() logger.addFilter(duplicate_filter) - dataloader = self.get_dataloader( + self.dataloader = self.get_dataloader( images=self.images, masks=self.masks, labels=self.labels, patch_mode=True, + ioconfig=self._ioconfig, ) raw_predictions = self.infer_patches( - dataloader=dataloader, + dataloader=self.dataloader, save_path=save_path, return_coordinates=output_type == "annotationstore", ) @@ -1067,17 +1073,17 @@ def _run_wsi_mode( duplicate_filter = DuplicateFilter() logger.addFilter(duplicate_filter) mask = self.masks[image_num] if self.masks is not None else None - dataloader = self.get_dataloader( + self.dataloader = self.get_dataloader( images=image, masks=mask, patch_mode=False, ioconfig=self._ioconfig, ) - scale_factor = self._calculate_scale_factor(dataloader=dataloader) + scale_factor = self._calculate_scale_factor(dataloader=self.dataloader) raw_predictions = self.infer_wsi( - dataloader=dataloader, + dataloader=self.dataloader, save_path=save_path[image], **kwargs, ) diff --git a/tiatoolbox/models/engine/semantic_segmentor_new.py b/tiatoolbox/models/engine/semantic_segmentor_new.py new file mode 100644 index 000000000..30859e6d8 --- /dev/null +++ b/tiatoolbox/models/engine/semantic_segmentor_new.py @@ -0,0 +1,635 @@ +"""Defines SemanticSegmentor Engine.""" + +from __future__ import annotations + +import shutil +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np +import torch +import zarr +from typing_extensions import Unpack + +from tiatoolbox import logger +from tiatoolbox.models.dataset.dataset_abc import WSIPatchDataset +from tiatoolbox.utils.misc import ( + dict_to_store_semantic_segmentor, +) + +from .patch_predictor import PatchPredictor, PredictorRunParams + +if TYPE_CHECKING: # pragma: no cover + import os + + from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.models.engine.io_config import IOSegmentorConfig + from tiatoolbox.models.models_abc import ModelABC + from tiatoolbox.type_hints import Resolution + from tiatoolbox.wsicore import WSIReader + + +class SemanticSegmentorRunParams(PredictorRunParams): + """Class describing the input parameters for the :func:`EngineABC.run()` method. + + Attributes: + batch_size (int): + Number of image patches to feed to the model in a forward pass. + cache_mode (bool): + Whether to run the Engine in cache_mode. For large datasets, + we recommend to set this to True to avoid out of memory errors. + For smaller datasets, the cache_mode is set to False as + the results can be saved in memory. + cache_size (int): + Specifies how many image patches to process in a batch when + cache_mode is set to True. If cache_size is less than the batch_size + batch_size is set to cache_size. + class_dict (dict): + Optional dictionary mapping classification outputs to class names. + device (str): + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. + ioconfig (ModelIOConfigABC): + Input IO configuration (:class:`ModelIOConfigABC`) to run the Engine. + return_labels (bool): + Whether to return the labels with the predictions. 
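+        patch_output_shape (tuple):
+            Expected shape of the model output patch at the requested `resolution`
+            and `units`, as a tuple of height and width (HW). Must be positive.
+        output_resolutions (Resolution):
+            Resolution of the model output. Please see
+            :class:`WSIReader` for details.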
+        num_loader_workers (int):
+            Number of workers used in :class:`torch.utils.data.DataLoader`.
+        num_post_proc_workers (int):
+            Number of workers to postprocess the results of the model.
+        output_file (str):
+            Output file name to save "zarr" or "db". If None, path to output is
+            returned by the engine.
+        patch_input_shape (tuple):
+            Shape of patches input to the model as tuple of height and width (HW).
+            Patches are at requested read resolution, not with respect to level 0,
+            and must be positive.
+        input_resolutions (list(dict(Units, Resolution))):
+            Resolution used for reading the image. Please see
+            :class:`WSIReader` for details.
+        return_probabilities (bool):
+            Whether to return per-class probabilities.
+        scale_factor (tuple[float, float]):
+            The scale factor to use when loading the
+            annotations. All coordinates will be multiplied by this factor to allow
+            conversion of annotations saved at non-baseline resolution to baseline.
+            Should be model_mpp/slide_mpp.
+        stride_shape (tuple):
+            Stride used during WSI processing. Stride is
+            at requested read resolution, not with respect to
+            level 0, and must be positive. If not provided,
+            `stride_shape=patch_input_shape`.
+        verbose (bool):
+            Whether to output logging information.
+
+    """
+
+    patch_output_shape: tuple
+    output_resolutions: Resolution
+
+
+class SemanticSegmentor(PatchPredictor):
+    r"""Semantic Segmentor Engine for processing digital histology images.
+
+    The tiatoolbox model should produce the following results on the BCSS dataset
+    using fcn_resnet50_unet-bcss.
+
+    .. list-table:: Semantic segmentation performance on the BCSS dataset
+       :widths: 15 15 15 15 15 15 15
+       :header-rows: 1
+
+       * -
+         - Tumour
+         - Stroma
+         - Inflammatory
+         - Necrosis
+         - Other
+         - All
+       * - Amgad et al.
+         - 0.851
+         - 0.800
+         - 0.712
+         - 0.723
+         - 0.666
+         - 0.750
+       * - TIAToolbox
+         - 0.885
+         - 0.825
+         - 0.761
+         - 0.765
+         - 0.581
+         - 0.763
+
+    Args:
+        model (str | ModelABC):
+            A PyTorch model or name of a pretrained model.
+            The user can request pretrained models from the toolbox model zoo using
+            the list of pretrained models available at this `link
+            <https://tia-toolbox.readthedocs.io/en/latest/pretrained.html>`_.
+            By default, the corresponding pretrained weights will also
+            be downloaded. However, you can override with your own set
+            of weights using the `weights` parameter. Default is `None`.
+        batch_size (int):
+            Number of image patches fed into the model each time in a
+            forward/backward pass. Default value is 8.
+        num_loader_workers (int):
+            Number of workers to load the data using :class:`torch.utils.data.Dataset`.
+            Please note that they will also perform preprocessing. Default value is 0.
+        num_post_proc_workers (int):
+            Number of workers to postprocess the results of the model.
+            Default value is 0.
+        weights (str or Path):
+            Path to the weight of the corresponding `model`.
+
+            >>> engine = SemanticSegmentor(
+            ...     model="pretrained-model",
+            ...     weights="/path/to/pretrained-local-weights.pth"
+            ... )
+
+        device (str):
+            Select the device to run the model. Please see
+            https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
+            for more details on input parameters for device. Default is "cpu".
+        verbose (bool):
+            Whether to output logging information. Default value is True.
+
+    Attributes:
+        images (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`):
+            A list of image patches in NHWC format as a numpy array
+            or a list of str/paths to WSIs.
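+            When `patch_mode` is False, the images are expected to be a
+            list of file paths to WSIs.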
+        masks (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`):
+            A list of tissue masks or binary masks corresponding to processing area of
+            input images. These can be a list of numpy arrays or paths to
+            the saved image masks. These are only utilized when patch_mode is False.
+            Patches are only generated within a masked area.
+            If not provided, then a tissue mask will be automatically
+            generated for whole slide images.
+        patch_mode (bool):
+            Whether to treat input images as a set of image patches. TIAToolbox defines
+            an image as a patch if HWC of the input image matches with the HWC expected
+            by the model. If HWC of the input image does not match with the HWC expected
+            by the model, then the patch_mode must be set to False which will allow the
+            engine to extract patches from the input image.
+            In this case, when the patch_mode is False the input images are treated
+            as WSIs. Default value is True.
+        model (str | ModelABC):
+            A PyTorch model or a name of an existing model from the TIAToolbox model zoo
+            for processing the data. For a full list of pretrained models,
+            refer to the `docs
+            <https://tia-toolbox.readthedocs.io/en/latest/pretrained.html>`_.
+            By default, the corresponding pretrained weights will also
+            be downloaded. However, you can override with your own set
+            of weights via the `weights` argument. Argument
+            is case-insensitive.
+        ioconfig (IOSegmentorConfig):
+            Input IO configuration of type :class:`IOSegmentorConfig` to run the Engine.
+        _ioconfig (IOSegmentorConfig):
+            Runtime ioconfig.
+        units (Units):
+            Units of resolution used for reading the image. Choose
+            from either `level`, `power` or `mpp`. Please see
+            :obj:`WSIReader` for details.
+        patch_input_shape (tuple):
+            Shape of patches input to the model as a tuple of height and width (HW).
+            Patches are at requested read resolution, not with respect to level 0,
+            and must be positive.
+        stride_shape (tuple):
+            Stride used during WSI processing. Stride is
+            at requested read resolution, not with respect to
+            level 0, and must be positive. If not provided,
+            `stride_shape=patch_input_shape`.
+        batch_size (int):
+            Number of images fed into the model each time.
+        cache_mode (bool):
+            Whether to run the Engine in cache_mode. For large datasets,
+            we recommend to set this to True to avoid out of memory errors.
+            For smaller datasets, the cache_mode is set to False as
+            the results can be saved in memory. cache_mode is always True when
+            processing WSIs i.e., when `patch_mode` is False. Default value is False.
+        cache_size (int):
+            Specifies how many image patches to process in a batch when
+            cache_mode is set to True. If cache_size is less than the batch_size,
+            batch_size is set to cache_size. Default value is 10,000.
+        labels (list | None):
+            List of labels. Only a single label per image is supported.
+        device (str):
+            :class:`torch.device` to run the model.
+            Select the device to run the model. Please see
+            https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
+            for more details on input parameters for device. Default value is "cpu".
+        num_loader_workers (int):
+            Number of workers used in :class:`torch.utils.data.DataLoader`.
+        num_post_proc_workers (int):
+            Number of workers to postprocess the results of the model.
+        return_labels (bool):
+            Whether to return the output labels. Default value is False.
+        input_resolutions (Resolution):
+            Resolution used for reading the image. Please see
+            :class:`WSIReader` for details.
+            When `patch_mode` is True, the input image patches are expected to be at
+            the correct resolution and units. When `patch_mode` is False, the patches
+            are extracted at the requested resolution and units. Default value is 1.0.
+        output_locations (list | None):
+            A list of coordinates in `[start_x, start_y, end_x, end_y]` format for the
+            model output patches, used to merge patch predictions back onto the WSI.
+        verbose (bool):
+            Whether to output logging information. Default value is True.
+
+    Examples:
+        >>> # list of 2 wsi files as input
+        >>> wsis = ['path/img.svs', 'path/img.svs']
+        >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask")
+        >>> output = segmentor.run(wsis, patch_mode=False)
+
+        >>> # list of 2 image patches (numpy arrays) as input
+        >>> image_patches = [np.ndarray, np.ndarray]
+        >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask")
+        >>> output = segmentor.run(image_patches, patch_mode=True)
+
+        >>> # list of 2 image files as input
+        >>> data = ['path/img.png', 'path/img.png']
+        >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask")
+        >>> output = segmentor.run(data, patch_mode=False)
+
+        >>> # list of 2 image tile files as input
+        >>> tile_file = ['path/tile1.png', 'path/tile2.png']
+        >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask")
+        >>> output = segmentor.run(tile_file, patch_mode=False)
+
+    References:
+        [1] Amgad M, Elfandy H, ..., Gutman DA, Cooper LAD. Structured crowdsourcing
+        enables convolutional segmentation of histology images. Bioinformatics 2019.
+        doi: 10.1093/bioinformatics/btz083
+
+    """
+
+    def __init__(
+        self: SemanticSegmentor,
+        model: str | ModelABC,
+        batch_size: int = 8,
+        num_loader_workers: int = 0,
+        num_post_proc_workers: int = 0,
+        weights: str | Path | None = None,
+        *,
+        device: str = "cpu",
+        verbose: bool = True,
+    ) -> None:
+        """Initialize :class:`SemanticSegmentor`."""
+        super().__init__(
+            model=model,
+            batch_size=batch_size,
+            num_loader_workers=num_loader_workers,
+            num_post_proc_workers=num_post_proc_workers,
+            weights=weights,
+            device=device,
+            verbose=verbose,
+        )
+        self.output_locations: list | None = None
+
+    def get_dataloader(
+        self: SemanticSegmentor,
+        images: str | Path | list[str | Path] | np.ndarray,
+        masks: Path | None = None,
+        labels: list | None = None,
+        ioconfig: IOSegmentorConfig | None = None,
+        *,
+        patch_mode: bool = True,
+    ) -> torch.utils.data.DataLoader:
+        """Pre-process images and masks and return dataloader for inference.
+
+        Args:
+            images (list of str or :class:`Path` or :class:`numpy.ndarray`):
+                A list of image patches in NHWC format as a numpy array
+                or a list of str/paths to WSIs. When `patch_mode` is False
+                the function expects list of str/paths to WSIs.
+            masks (list | None):
+                List of masks. Only utilised when patch_mode is False.
+                Patches are only generated within a masked area.
+                If not provided, then a tissue mask will be automatically
+                generated for whole slide images.
+            labels (list | None):
+                List of labels. Only a single label per image is supported.
+            ioconfig (IOSegmentorConfig):
+                An :class:`IOSegmentorConfig` object.
+            patch_mode (bool):
+                Whether to treat input image as a patch or WSI.
+
+        Returns:
+            torch.utils.data.DataLoader:
+                :class:`torch.utils.data.DataLoader` for inference.
+
+        """
+        # Overwrite when patch_mode is False.
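+        # When `patch_output_shape` is provided, `WSIPatchDataset` also returns
+        # the output patch coordinates; these are stored on `self.output_locations`
+        # and later used by `post_process_cache_mode` to merge patch-level
+        # predictions back into a single WSI-level array.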
+ if not patch_mode: + dataset = WSIPatchDataset( + img_path=images, + mask_path=masks, + patch_input_shape=ioconfig.patch_input_shape, + patch_output_shape=ioconfig.patch_output_shape, + stride_shape=ioconfig.stride_shape, + resolution=ioconfig.input_resolutions[0]["resolution"], + units=ioconfig.input_resolutions[0]["units"], + ) + + dataset.preproc_func = self.model.preproc_func + self.output_locations = dataset.outputs + + # preprocessing must be defined with the dataset + return torch.utils.data.DataLoader( + dataset, + num_workers=self.num_loader_workers, + batch_size=self.batch_size, + drop_last=False, + shuffle=False, + ) + + return super().get_dataloader( + images=images, + masks=masks, + labels=labels, + ioconfig=ioconfig, + patch_mode=patch_mode, + ) + + def post_process_cache_mode( + self: SemanticSegmentor, + raw_predictions: Path, + **kwargs: Unpack[PredictorRunParams], + ) -> Path: + """Returns an array from raw predictions. + + Merges raw predictions from individual patches into a single prediction array if + patch_mode is False. + + """ + if self.patch_mode: + return super().post_process_cache_mode( + raw_predictions=raw_predictions, + **kwargs, + ) + + return_probabilities = kwargs.get("return_probabilities") + merged_resolution = self.ioconfig.highest_input_resolution + + zarr_group = zarr.open(str(raw_predictions), mode="r+") + + # Calculate canvas parameters + wsi_reader = self.dataloader.dataset.reader + ioconfig = self.ioconfig + in_out_ratio = np.array(ioconfig.patch_input_shape) / np.array( + ioconfig.patch_output_shape + ) + padding_size = ( + np.array(ioconfig.patch_input_shape) + - np.array(ioconfig.stride_shape) * 4 / in_out_ratio + ) + slide_dimensions = wsi_reader.slide_dimensions(**merged_resolution) + merged_shape = [ + *slide_dimensions + padding_size.astype(int), + zarr_group["probabilities"].shape[3], + ] + + # create dataset for merged probabilities + merged_probabilities = zarr_group.create_dataset( + name="merged_probabilities", + shape=merged_shape, + compressor=zarr_group["probabilities"].compressor, + ) + + merged_weights = np.zeros_like(merged_probabilities) + + for idx, location in enumerate(self.output_locations): + start_x, start_y, end_x, end_y = location + merged_probabilities[start_y:end_y, start_x:end_x, :] += zarr_group[ + "probabilities" + ][idx][0 : end_y - start_y, 0 : end_x - start_x, :] + merged_weights[start_y:end_y, start_x:end_x] += 1 + + # Normalize + merged_weights[merged_weights == 0] = 1 + merged_probabilities[:] = merged_probabilities[:] / merged_weights[:] + + # save merged probabilities as single output probabilities + zarr_group["probabilities"] = merged_probabilities + del zarr_group["merged_probabilities"] + + zarr_group["predictions"] = self.model.postproc_func( + zarr_group["probabilities"], + ) + + zarr_group["predictions"] = zarr_group["predictions"][ + 0 : slide_dimensions[0], + 0 : slide_dimensions[1], + ] + + if not return_probabilities: + del zarr_group["probabilities"] + return raw_predictions + + zarr_group["probabilities"] = zarr_group["probabilities"][ + 0 : slide_dimensions[0], + 0 : slide_dimensions[1], + :, + ] + return raw_predictions + + def save_predictions( + self: PatchPredictor, + processed_predictions: dict | Path, + output_type: str, + save_dir: Path | None = None, + **kwargs: SemanticSegmentorRunParams, + ) -> dict | AnnotationStore | Path | list[Path]: + """Save semantic segmentation predictions to disk. 
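+
+        When `output_type` is "annotationstore", each image's prediction map is
+        converted to polygon annotations and written to a separate `.db` file;
+        all other output types are handled by :meth:`PatchPredictor.save_predictions`.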
+
+        Args:
+            processed_predictions (dict | Path):
+                A dictionary or path to zarr with model prediction information.
+            save_dir (Path):
+                Optional output path to directory to save the patch dataset output to a
+                `.zarr` or `.db` file, provided `patch_mode` is True. If the
+                `patch_mode` is False then `save_dir` is required.
+            output_type (str):
+                The desired output type for resulting patch dataset.
+            **kwargs (SemanticSegmentorRunParams):
+                Keyword Args required to save the output.
+
+        Returns:
+            dict or Path or :class:`AnnotationStore`:
+                If the `output_type` is "AnnotationStore", the function returns
+                the patch predictor output as an SQLiteStore containing Annotations
+                for each image, or the Path to a `.db` file depending on whether a
+                save_dir Path is provided. Otherwise, the function defaults to
+                returning patch predictor output, either as a dict or the Path to a
+                `.zarr` file depending on whether a save_dir Path is provided.
+
+        """
+        # Conversion to annotationstore uses a different function for SemanticSegmentor
+        if output_type.lower() != "annotationstore":
+            return super().save_predictions(
+                processed_predictions, output_type, save_dir, **kwargs
+            )
+
+        save_path = Path(kwargs.get("output_file", save_dir))
+        return_probabilities = kwargs.get("return_probabilities", False)
+
+        # scale_factor set from kwargs
+        scale_factor = kwargs.get("scale_factor", (1.0, 1.0))
+        # class_dict set from kwargs
+        class_dict = kwargs.get("class_dict")
+
+        processed_predictions_path: str | Path | None = None
+
+        # Need to add support for zarr conversion.
+        if self.cache_mode:
+            processed_predictions_path = processed_predictions
+            processed_predictions = zarr.open(processed_predictions, mode="r+")
+
+        save_paths = []
+
+        for i, predictions in enumerate(processed_predictions["predictions"]):
+            if isinstance(self.images[i], Path):
+                output_path = save_path / (self.images[i].stem + ".db")
+            else:
+                output_path = save_path / (str(i) + ".db")
+
+            out_file = dict_to_store_semantic_segmentor(
+                patch_output={"predictions": predictions},
+                scale_factor=scale_factor,
+                class_dict=class_dict,
+                save_path=output_path,
+            )
+
+            save_paths.append(out_file)
+
+        processed_predictions.pop("predictions")
+
+        if return_probabilities and self.cache_mode:
+            new_zarr_name = out_file.parent.with_suffix(".zarr")
+            processed_predictions_path.rename(new_zarr_name)
+            msg = (
+                "Probability maps cannot be saved as an AnnotationStore. "
+                "To visualise heatmaps in the TIAToolbox Visualization tool, "
+                f"convert the heatmaps in {new_zarr_name} to ome.tiff using "
+                "tiatoolbox.utils.misc.write_probability_heatmap_as_ome_tiff."
+            )
+            logger.info(msg)
+            save_paths.append(new_zarr_name)
+
+        if processed_predictions_path and processed_predictions_path.exists():
+            shutil.rmtree(processed_predictions_path)
+
+        return save_paths
+
+    def run(
+        self: SemanticSegmentor,
+        images: list[os | Path | WSIReader] | np.ndarray,
+        masks: list[os | Path] | np.ndarray | None = None,
+        labels: list | None = None,
+        ioconfig: IOSegmentorConfig | None = None,
+        *,
+        patch_mode: bool = True,
+        save_dir: os | Path | None = None,  # None will not save output
+        overwrite: bool = False,
+        output_type: str = "dict",
+        **kwargs: Unpack[SemanticSegmentorRunParams],
+    ) -> AnnotationStore | Path | str | dict | list[Path]:
+        """Run the engine on input images.
+
+        Args:
+            images (list, ndarray):
+                List of inputs to process. When using `patch` mode, the
+                input must be either a list of images, a list of image
+                file paths or a numpy array of an image list.
+            masks (list | None):
+                List of masks. Only utilised when patch_mode is False.
+                Patches are only generated within a masked area.
+                If not provided, then a tissue mask will be automatically
+                generated for whole slide images.
+            labels (list | None):
+                List of labels. Only a single label per image is supported.
+            patch_mode (bool):
+                Whether to treat input image as a patch or WSI.
+                Default value is True.
+            ioconfig (IOSegmentorConfig):
+                IO configuration.
+            save_dir (str or pathlib.Path):
+                Output directory to save the results.
+                If save_dir is not provided when patch_mode is False,
+                then for a single image the output is created in the current directory.
+                If there are multiple WSIs as input then the user must provide
+                path to save directory otherwise an OSError will be raised.
+            overwrite (bool):
+                Whether to overwrite the results. Default = False.
+            output_type (str):
+                The format of the output type. "output_type" can be
+                "dict", "zarr" or "AnnotationStore". Default value is "dict".
+                When saving in the zarr format the output is saved using the
+                `python zarr library <https://zarr.readthedocs.io/en/stable/>`__
+                as a zarr group. If the required output type is an "AnnotationStore"
+                then the output will be intermediately saved as zarr but converted
+                to :class:`AnnotationStore` and saved as a `.db` file
+                at the end of the loop.
+            **kwargs (SemanticSegmentorRunParams):
+                Keyword Args to update :class:`EngineABC` attributes during runtime.
+
+        Returns:
+            (:class:`numpy.ndarray`, dict, list):
+                Model predictions of the input dataset. If multiple
+                whole slide images are provided as input,
+                or `save_dir` is provided, then results are saved to
+                `save_dir` and a dictionary indicating save location for
+                each input is returned.
+
+                The dict has the following format:
+
+                - img_path: path of the input image.
+                - raw: path to save location for raw prediction,
+                  saved as `.zarr`.
+                - list: List of image paths to the output files.
+
+        Examples:
+            >>> wsis = ['wsi1.svs', 'wsi2.svs']
+            >>> image_patches = [np.ndarray, np.ndarray]
+            >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask")
+            >>> output = segmentor.run(image_patches, patch_mode=True)
+            >>> output
+            ... "/path/to/Output.db"
+            >>> output = segmentor.run(
+            >>>     image_patches,
+            >>>     patch_mode=True,
+            >>>     output_type="zarr")
+            >>> output
+            ... "/path/to/Output.zarr"
+            >>> output = segmentor.run(wsis, patch_mode=False)
+            >>> output.keys()
+            ... ['wsi1.svs', 'wsi2.svs']
+            >>> output['wsi1.svs']
+            ... {'/path/to/wsi1.db'}
+
+        """
+        return super().run(
+            images=images,
+            masks=masks,
+            labels=labels,
+            ioconfig=ioconfig,
+            patch_mode=patch_mode,
+            save_dir=save_dir,
+            overwrite=overwrite,
+            output_type=output_type,
+            **kwargs,
+        )
diff --git a/tiatoolbox/utils/exceptions.py b/tiatoolbox/utils/exceptions.py
index db74af710..2f9f2a126 100644
--- a/tiatoolbox/utils/exceptions.py
+++ b/tiatoolbox/utils/exceptions.py
@@ -33,3 +33,30 @@ def __init__(
     ) -> None:
         """Initialize :class:`MethodNotSupportedError`."""
         super().__init__(message)
+
+
+class DimensionMismatchError(Exception):
+    """Raise dimension mismatch error.
+
+    Args:
+        expected_dims (list or tuple): Expected dimensions.
+        actual_dims (list or tuple): Actual dimensions.
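+
+    Example:
+        >>> from tiatoolbox.utils.exceptions import DimensionMismatchError
+        >>> raise DimensionMismatchError((224, 224), (5, 5))
+        Traceback (most recent call last):
+            ...
+        tiatoolbox.utils.exceptions.DimensionMismatchError: Expected dimensions (224, 224), but got (5, 5).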
+ + """ + + def __init__( + self: DimensionMismatchError, + expected_dims: list | tuple, + actual_dims: list | tuple, + ) -> None: + """Initialize :class:`DimensionMismatchError`.""" + self.expected_dims = expected_dims + self.actual_dims = actual_dims + super().__init__(f"Expected dimensions {expected_dims}, but got {actual_dims}.") diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 6a77af183..83994001a 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -163,7 +163,7 @@ def imwrite(image_path: PathLike, img: np.ndarray) -> None: def imread(image_path: PathLike, *, as_uint8: bool | None = None) -> np.ndarray: - """Read an image as a NumPy array. + """Read an image as :class:`numpy.ndarray`. Args: image_path (PathLike): @@ -1408,13 +1408,13 @@ def dict_to_store_semantic_segmentor( return store -def dict_to_store( +def dict_to_store_patch_predictions( patch_output: dict | zarr.group, scale_factor: tuple[float, float], class_dict: dict | None = None, save_path: Path | None = None, ) -> AnnotationStore | Path: - """Converts (and optionally saves) output of TIAToolbox engines as AnnotationStore. + """Converts output of TIAToolbox PatchPredictor engine to AnnotationStore. Args: patch_output (dict | zarr.Group): @@ -1630,21 +1630,21 @@ def dict_to_zarr( compressor = ( kwargs["compressor"] if "compressor" in kwargs else numcodecs.Zstd(level=1) ) - chunks = kwargs.get("chunks", 10000) # ensure proper zarr extension save_path = save_path.parent.absolute() / (save_path.stem + ".zarr") - # save to zarr - probabilities_array = np.array(raw_predictions["probabilities"]) - z = zarr.open( - str(save_path), - mode="w", - shape=probabilities_array.shape, - chunks=chunks, - compressor=compressor, - ) - z[:] = probabilities_array + z = zarr.open(str(save_path), mode="w") + + for key, value in raw_predictions.items(): + # save to zarr + array = np.array(value) + z.create_dataset( + name=key, + data=value, + compression=compressor, + shape=array.shape, + ) return save_path diff --git a/tiatoolbox/utils/transforms.py b/tiatoolbox/utils/transforms.py index 8c2817b75..9f1f901c6 100644 --- a/tiatoolbox/utils/transforms.py +++ b/tiatoolbox/utils/transforms.py @@ -95,7 +95,7 @@ def imresize( img: np.ndarray, scale_factor: float | tuple[float, float] | None = None, output_size: int | tuple[int, int] | None = None, - interpolation: str = "optimise", + interpolation: str | int = "optimise", ) -> np.ndarray: """Resize input image.