diff --git a/tests/conftest.py b/tests/conftest.py index e2d5bab6a..c21dcbdc5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -520,6 +520,7 @@ def sample_wsi_dict(remote_sample: Callable) -> dict: "wsi4_4k_4k_svs", "wsi3_20k_20k_pred", "wsi4_4k_4k_pred", + "wsi4_1k_1k_svs", ] return {name: remote_sample(name) for name in file_names} diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 2f1dec7f8..e50ab7856 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -520,7 +520,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) eng = TestEngineABC(model=model) kwargs = { - "patch_input_shape": [512, 512], + "patch_input_shape": [224, 224], "input_resolutions": [{"units": "mpp", "resolution": 1.75}], } with caplog.at_level(logging.WARNING): @@ -536,7 +536,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) # test providing config / full input info for non pretrained models ioconfig = ModelIOConfigABC( - patch_input_shape=(512, 512), + patch_input_shape=(224, 224), stride_shape=(256, 256), input_resolutions=[{"resolution": 1.35, "units": "mpp"}], ) @@ -546,7 +546,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) save_dir=f"{tmp_path}/dump", ioconfig=ioconfig, ) - assert eng._ioconfig.patch_input_shape == (512, 512) + assert eng._ioconfig.patch_input_shape == (224, 224) assert eng._ioconfig.stride_shape == (256, 256) assert eng._ioconfig.input_resolutions == [{"resolution": 1.35, "units": "mpp"}] shutil.rmtree(tmp_path / "dump", ignore_errors=True) @@ -557,15 +557,15 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) save_dir=f"{tmp_path}/dump", **kwargs, ) - assert eng._ioconfig.patch_input_shape == [512, 512] - assert eng._ioconfig.stride_shape == [512, 512] + assert eng._ioconfig.patch_input_shape == [224, 224] + assert eng._ioconfig.stride_shape == [224, 224] assert eng._ioconfig.input_resolutions == [{"resolution": 1.75, "units": "mpp"}] shutil.rmtree(tmp_path / "dump", ignore_errors=True) # test overwriting pretrained ioconfig eng = TestEngineABC(model="alexnet-kather100k") eng.run( - images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + images=np.zeros((10, 300, 300, 3), dtype=np.uint8), patch_input_shape=(300, 300), stride_shape=(300, 300), input_resolutions=[{"units": "baseline", "resolution": 1.99}], @@ -579,7 +579,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) shutil.rmtree(tmp_path / "dump", ignore_errors=True) eng.run( - images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + images=np.zeros((10, 300, 300, 3), dtype=np.uint8), patch_input_shape=(300, 300), stride_shape=(300, 300), input_resolutions=None, diff --git a/tests/engines/test_semantic_segmentor.py b/tests/engines/test_semantic_segmentor.py new file mode 100644 index 000000000..2932bbfb5 --- /dev/null +++ b/tests/engines/test_semantic_segmentor.py @@ -0,0 +1,197 @@ +"""Test SemanticSegmentor.""" + +from __future__ import annotations + +import json +import sqlite3 +from typing import TYPE_CHECKING, Callable + +import numpy as np +import torch +import zarr + +from tiatoolbox.annotation import SQLiteStore +from tiatoolbox.models.engine.semantic_segmentor_new import SemanticSegmentor +from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils.misc import imread + +if TYPE_CHECKING: + from pathlib import Path + +device = "cuda" if toolbox_env.has_gpu() else "cpu" + + 
+def test_semantic_segmentor_init() -> None: + """Tests SemanticSegmentor initialization.""" + segmentor = SemanticSegmentor(model="fcn-tissue_mask", device=device) + + assert isinstance(segmentor, SemanticSegmentor) + assert isinstance(segmentor.model, torch.nn.Module) + + +def test_semantic_segmentor_patches(remote_sample: Callable, tmp_path: Path) -> None: + """Tests SemanticSegmentor on image patches.""" + segmentor = SemanticSegmentor( + model="fcn-tissue_mask", batch_size=32, verbose=False, device=device + ) + + sample_image = remote_sample("thumbnail-1k-1k") + + inputs = [sample_image, sample_image] + + assert segmentor.cache_mode is False + + output = segmentor.run( + images=inputs, + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=True, + ) + + assert 0.15 < np.mean(output["predictions"][:]) < 0.18 + assert 0.120 < np.mean(output["probabilities"][:]) < 0.130 + + assert ( + tuple(segmentor._ioconfig.patch_input_shape) + == output["probabilities"][0].shape[:-1] + ) + + assert ( + tuple(segmentor._ioconfig.patch_input_shape) == output["predictions"][0].shape + ) + + output = segmentor.run( + images=inputs, + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=True, + cache_mode=True, + save_dir=tmp_path / "output0", + ) + + assert output == tmp_path / "output0" / "output.zarr" + + output = zarr.open(output, mode="r") + assert 0.15 < np.mean(output["predictions"][:]) < 0.18 + assert 0.120 < np.mean(output["probabilities"][:]) < 0.130 + + output = segmentor.run( + images=inputs, + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=True, + cache_mode=True, + output_type="zarr", + save_dir=tmp_path / "output1", + ) + + assert output == tmp_path / "output1" / "output.zarr" + + output = zarr.open(output, mode="r") + assert 0.15 < np.mean(output["predictions"][:]) < 0.18 + assert "probabilities" not in output.keys() # noqa: SIM118 + + output = segmentor.run( + images=inputs, + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=True, + cache_mode=False, + save_dir=tmp_path / "output2", + output_type="zarr", + ) + + assert output == tmp_path / "output2" / "output.zarr" + + output = zarr.open(output, mode="r") + assert 0.15 < np.mean(output["predictions"][:]) < 0.18 + assert "probabilities" not in output + assert "predictions" in output + + +def _test_store_output_patch(output: Path) -> None: + """Helper method to test annotation store output for a patch.""" + store_ = SQLiteStore.open(output) + annotations_ = store_.values() + annotations_geometry_type = [ + str(annotation_.geometry_type) for annotation_ in annotations_ + ] + assert "Polygon" in annotations_geometry_type + + con = sqlite3.connect(output) + cur = con.cursor() + annotations_properties = list(cur.execute("SELECT properties FROM annotations")) + + out = [] + + for item in annotations_properties: + for json_str in item: + probs = json.loads(json_str) + if "type" in probs: + out.append(probs.pop("type")) + + assert "mask" in out + + assert annotations_properties is not None + + +def test_save_annotation_store(remote_sample: Callable, tmp_path: Path) -> None: + """Test for saving output as annotation store.""" + segmentor = SemanticSegmentor( + model="fcn-tissue_mask", batch_size=32, verbose=False, device=device + ) + + sample_image = remote_sample("thumbnail-1k-1k") + + inputs = [sample_image] + + output = segmentor.run( + images=inputs, + return_probabilities=False, + return_labels=False, + device=device, + 
patch_mode=True, + cache_mode=False, + save_dir=tmp_path / "output1", + output_type="annotationstore", + ) + + assert output[0] == tmp_path / "output1" / (sample_image.stem + ".db") + assert len(output) == 1 + _test_store_output_patch(output[0]) + + +def test_save_annotation_store_nparray(remote_sample: Callable, tmp_path: Path) -> None: + """Test for saving output as annotation store using a numpy array.""" + segmentor = SemanticSegmentor( + model="fcn-tissue_mask", batch_size=32, verbose=False, device=device + ) + + sample_image = remote_sample("thumbnail-1k-1k") + + input_image = imread(sample_image) + inputs_list = [input_image, input_image] + + output = segmentor.run( + images=inputs_list, + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=True, + cache_mode=True, + save_dir=tmp_path / "output1", + output_type="annotationstore", + ) + + assert output[0] == tmp_path / "output1" / "0.db" + assert output[2] == tmp_path / "output1" / "1.db" + + assert output[1] == tmp_path / "output1" / "0.tif" + assert output[3] == tmp_path / "output1" / "1.tif" + + _test_store_output_patch(output[0]) + _test_store_output_patch(output[2]) diff --git a/tests/models/test_arch_unet.py b/tests/models/test_arch_unet.py index 2ac231c7c..263a6b2c5 100644 --- a/tests/models/test_arch_unet.py +++ b/tests/models/test_arch_unet.py @@ -50,7 +50,7 @@ def test_functional_unet(remote_sample: Callable) -> None: pretrained = torch.load(pretrained_weights, map_location="cpu") model.load_state_dict(pretrained) output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) - _ = output[0] + _ = output["probabilities"][0] # run untrained network to test for architecture model = UNetModel( diff --git a/tests/models/test_dataset.py b/tests/models/test_dataset.py index ab9a6033f..a19cb5043 100644 --- a/tests/models/test_dataset.py +++ b/tests/models/test_dataset.py @@ -20,6 +20,7 @@ ) from tiatoolbox.utils import download_data, imread, imwrite, unzip_data from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils.exceptions import DimensionMismatchError from tiatoolbox.wsicore import WSIReader RNG = np.random.default_rng() # Numpy Random Generator @@ -120,7 +121,9 @@ def test_kather_dataset(tmp_path: Path) -> None: assert len(dataset.inputs) == len(dataset.labels) # to actually get the image, we feed it to PatchDataset - actual_ds = PatchDataset(dataset.inputs, dataset.labels) + actual_ds = PatchDataset( + dataset.inputs, dataset.labels, patch_input_shape=(224, 224) + ) sample_patch = actual_ds[89] assert isinstance(sample_patch["image"], np.ndarray) assert sample_patch["label"] is not None @@ -129,6 +132,18 @@ def test_kather_dataset(tmp_path: Path) -> None: shutil.rmtree(save_dir_path, ignore_errors=True) +def test_incorrect_input_shape() -> None: + """Incorrect input patch dimensions should raise DimensionMismatchError.""" + size = (5, 5, 3) + img = RNG.integers(low=0, high=255, size=size) + list_imgs = [img, img, img] + dataset = PatchDataset(list_imgs, patch_input_shape=(100, 100)) + with pytest.raises( + DimensionMismatchError, match=r".*\(100, 100\), but got \(5, 5\).*" + ): + _ = dataset[0] + + def test_patch_dataset_path_imgs( sample_patch1: str | Path, sample_patch2: str | Path, @@ -136,7 +151,9 @@ def test_patch_dataset_path_imgs( """Test for patch dataset with a list of file paths as input.""" size = (224, 224, 3) - dataset = PatchDataset([Path(sample_patch1), Path(sample_patch2)]) + dataset = PatchDataset( + [Path(sample_patch1), 
Path(sample_patch2)], patch_input_shape=size[:-1] + ) for _, sample_data in enumerate(dataset): sampled_img_shape = sample_data["image"].shape @@ -152,7 +169,7 @@ def test_patch_dataset_list_imgs(tmp_path: Path) -> None: size = (5, 5, 3) img = RNG.integers(low=0, high=255, size=size) list_imgs = [img, img, img] - dataset = PatchDataset(list_imgs) + dataset = PatchDataset(list_imgs, patch_input_shape=size[:-1]) dataset.preproc_func = lambda x: x @@ -197,14 +214,14 @@ def test_patch_datasetarray_imgs() -> None: array_imgs = np.array(list_imgs) # test different setter for label - dataset = PatchDataset(array_imgs, labels=labels) + dataset = PatchDataset(array_imgs, labels=labels, patch_input_shape=(5, 5)) an_item = dataset[2] assert an_item["label"] == 3 - dataset = PatchDataset(array_imgs, labels=None) + dataset = PatchDataset(array_imgs, labels=None, patch_input_shape=(5, 5)) an_item = dataset[2] assert "label" not in an_item - dataset = PatchDataset(array_imgs) + dataset = PatchDataset(array_imgs, patch_input_shape=size[:-1]) for _, sample_data in enumerate(dataset): sampled_img_shape = sample_data["image"].shape assert sampled_img_shape[0] == size[0] diff --git a/tests/test_utils.py b/tests/test_utils.py index 2ec93acec..b4267546e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1668,7 +1668,7 @@ def test_patch_pred_store() -> None: "other": "other", } - store = misc.dict_to_store(patch_output, (1.0, 1.0)) + store = misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0)) # Check that it is an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) @@ -1681,7 +1681,7 @@ def test_patch_pred_store() -> None: patch_output.pop("coordinates") # check correct error is raised if coordinates are missing with pytest.raises(ValueError, match="coordinates"): - misc.dict_to_store(patch_output, (1.0, 1.0)) + misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0)) patch_output = { "predictions": [1, 0, 1], @@ -1689,7 +1689,7 @@ def test_patch_pred_store() -> None: "other": "other", } - store = misc.dict_to_store(patch_output, (1.0, 1.0)) + store = misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0)) # Check that it is an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) @@ -1706,7 +1706,9 @@ def test_patch_pred_store_cdict() -> None: "other": "other", } class_dict = {0: "class0", 1: "class1"} - store = misc.dict_to_store(patch_output, (1.0, 1.0), class_dict=class_dict) + store = misc.dict_to_store_patch_predictions( + patch_output, (1.0, 1.0), class_dict=class_dict + ) # Check that it is an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) @@ -1727,7 +1729,7 @@ def test_patch_pred_store_sf() -> None: "probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]], "labels": [1, 0, 1], } - store = misc.dict_to_store(patch_output, (2.0, 2.0)) + store = misc.dict_to_store_patch_predictions(patch_output, (2.0, 2.0)) # Check that its an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) @@ -1784,7 +1786,9 @@ def test_patch_pred_store_persist(tmp_path: pytest.TempPathFactory) -> None: } save_path = tmp_path / "patch_output" / "output.db" - store_path = misc.dict_to_store(patch_output, (1.0, 1.0), save_path=save_path) + store_path = misc.dict_to_store_patch_predictions( + patch_output, (1.0, 1.0), save_path=save_path + ) print("Annotation store path: ", store_path) assert Path.exists(store_path), "Annotation Store output file does not exist" @@ 
-1802,7 +1806,7 @@ def test_patch_pred_store_persist(tmp_path: pytest.TempPathFactory) -> None:
     patch_output.pop("coordinates")
     # check correct error is raised if coordinates are missing
     with pytest.raises(ValueError, match="coordinates"):
-        misc.dict_to_store(patch_output, (1.0, 1.0))
+        misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0))


 def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None:
@@ -1818,7 +1822,9 @@ def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None

     # sends the path of a jpeg source image, expects .db file in the same directory
     save_path = tmp_path / "patch_output" / "output.jpeg"
-    store_path = misc.dict_to_store(patch_output, (1.0, 1.0), save_path=save_path)
+    store_path = misc.dict_to_store_patch_predictions(
+        patch_output, (1.0, 1.0), save_path=save_path
+    )

     print("Annotation store path: ", store_path)
     assert Path.exists(store_path), "Annotation Store output file does not exist"
@@ -1836,7 +1842,7 @@ def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None
     patch_output.pop("coordinates")
     # check correct error is raised if coordinates are missing
     with pytest.raises(ValueError, match="coordinates"):
-        misc.dict_to_store(patch_output, (1.0, 1.0))
+        misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0))


 def test_torch_compile_already_compiled() -> None:
@@ -1876,6 +1882,79 @@ def test_torch_compile_compatibility(caplog: pytest.LogCaptureFixture) -> None:
     assert "torch.compile" in caplog.text


+def test_dict_to_store_semantic_segment() -> None:
+    """Tests geometry types produced by dict_to_store_semantic_segmentor."""
+    test_pred = np.zeros(shape=(224, 224))
+
+    patch_output = {"predictions": test_pred}
+
+    store_ = misc.dict_to_store_semantic_segmentor(
+        patch_output=patch_output,
+        scale_factor=(1.0, 1.0),
+        class_dict=None,
+        save_path=None,
+    )
+    assert not store_.values()
+
+    patch_output["predictions"][100, 100] = 1
+
+    store_ = misc.dict_to_store_semantic_segmentor(
+        patch_output=patch_output,
+        scale_factor=(1.0, 1.0),
+        class_dict=None,
+        save_path=None,
+    )
+    assert len(store_) == 1
+
+    annotations_ = store_.values()
+
+    annotations_geometry_type = [
+        str(annotation_.geometry_type) for annotation_ in annotations_
+    ]
+
+    assert "Point" in annotations_geometry_type
+    assert "Polygon" not in annotations_geometry_type
+
+    patch_output["predictions"][110:155, 110:115] = 1
+
+    store_ = misc.dict_to_store_semantic_segmentor(
+        patch_output=patch_output,
+        scale_factor=(1.0, 1.0),
+        class_dict=None,
+        save_path=None,
+    )
+    assert len(store_) == 2
+
+    annotations_ = store_.values()
+
+    annotations_geometry_type = [
+        str(annotation_.geometry_type) for annotation_ in annotations_
+    ]
+
+    assert "Point" in annotations_geometry_type
+    assert "Polygon" in annotations_geometry_type
+
+    patch_output["predictions"][50, 50] = 1
+    patch_output["predictions"][50, 51] = 1
+
+    store_ = misc.dict_to_store_semantic_segmentor(
+        patch_output=patch_output,
+        scale_factor=(1.0, 1.0),
+        class_dict=None,
+        save_path=None,
+    )
+    assert len(store_) == 3
+    annotations_ = store_.values()
+
+    annotations_geometry_type = [
+        str(annotation_.geometry_type) for annotation_ in annotations_
+    ]
+
+    assert "Point" in annotations_geometry_type
+    assert "Polygon" in annotations_geometry_type
+    assert "Line String" in annotations_geometry_type
+
+
 # Tests for OME tiff writer
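The test above pins down the contour-to-geometry rule in `dict_to_store_semantic_segmentor`: a single foreground pixel becomes a Point, a two-pixel contour becomes a LineString, and anything larger becomes a Polygon. A minimal sketch of the same behaviour outside the test harness (mask values are arbitrary):

```python
import numpy as np

from tiatoolbox.utils import misc

predictions = np.zeros(shape=(224, 224))
predictions[100, 100] = 1          # one pixel -> Point annotation
predictions[110:155, 110:115] = 1  # a filled blob -> Polygon annotation

store = misc.dict_to_store_semantic_segmentor(
    patch_output={"predictions": predictions},
    scale_factor=(1.0, 1.0),
    class_dict=None,
    save_path=None,  # None keeps the SQLiteStore in memory
)
print(sorted({str(a.geometry_type) for a in store.values()}))
```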
diff --git a/tiatoolbox/data/remote_samples.yaml b/tiatoolbox/data/remote_samples.yaml
index 941200212..a3f29f168 100644
--- a/tiatoolbox/data/remote_samples.yaml
+++ b/tiatoolbox/data/remote_samples.yaml
@@ -21,6 +21,8 @@ files:
     extract: True
   svs-1-small:
     url: [*wsis, "CMU-1-Small-Region.svs"]
+  thumbnail-1k-1k:
+    url: [*wsis, "CMU-2_1k_1k-thumbnail.png"]
   tiled-tiff-1-small-jpeg:
     url: [*wsis, "CMU-1-Small-Region.jpeg.tiff"]
   tiled-tiff-1-small-jp2k:
diff --git a/tiatoolbox/models/architecture/unet.py b/tiatoolbox/models/architecture/unet.py
index 6385e7587..a4b2fd0b1 100644
--- a/tiatoolbox/models/architecture/unet.py
+++ b/tiatoolbox/models/architecture/unet.py
@@ -4,13 +4,18 @@

 from typing import Any

+import numpy as np
 import torch
 import torch.nn.functional as F  # noqa: N812
 from torch import nn
 from torchvision.models.resnet import Bottleneck as ResNetBottleneck
 from torchvision.models.resnet import ResNet

-from tiatoolbox.models.architecture.utils import UpSample2x, centre_crop
+from tiatoolbox.models.architecture.utils import (
+    UpSample2x,
+    argmax_last_axis,
+    centre_crop,
+)
 from tiatoolbox.models.models_abc import ModelABC


@@ -416,7 +421,7 @@ def infer_batch(
         batch_data: torch.Tensor,
         *,
         device: str,
-    ) -> list:
+    ) -> dict[str, np.ndarray]:
        """Run inference on an input batch.

        This contains logic for forward operation as well as i/o
@@ -432,9 +437,8 @@ def infer_batch(
                Transfers model to the specified device. Default is "cpu".

        Returns:
-            list:
-                List of network output head, each output is an
-                :class:`numpy.ndarray`.
+            dict:
+                A dict with a "probabilities" key mapping to a :class:`numpy.ndarray`.

        """
        model.eval()
@@ -457,7 +461,18 @@ def infer_batch(
            align_corners=False,
        )
        probs = centre_crop(probs, crop_shape)
-        probs = probs.permute(0, 2, 3, 1)  # to NHWC
+        output = probs.permute(0, 2, 3, 1)  # to NHWC
+
+        output = np.pad(
+            output.cpu().numpy(), pad_width=((0, 0), (256, 256), (256, 256), (0, 0))
+        )
+
+        return {"probabilities": output}

-        probs = probs.cpu().numpy()
-        return [probs]
+    def postproc(self: UNetModel, image: np.ndarray) -> np.ndarray:
+        """Define post-processing of this class of model.
+
+        This simply applies argmax along the last axis of the input.
+
+        """
+        return argmax_last_axis(image=image)
diff --git a/tiatoolbox/models/architecture/utils.py b/tiatoolbox/models/architecture/utils.py
index e9560e59e..63cd055c1 100644
--- a/tiatoolbox/models/architecture/utils.py
+++ b/tiatoolbox/models/architecture/utils.py
@@ -234,3 +234,20 @@ def forward(self: UpSample2x, x: torch.Tensor) -> torch.Tensor:
        ret = torch.tensordot(x, mat, dims=1)  # bxcxhxwxshxsw
        ret = ret.permute(0, 1, 2, 4, 3, 5)
        return ret.reshape((-1, input_shape[1], input_shape[2] * 2, input_shape[3] * 2))
+
+
+def argmax_last_axis(image: np.ndarray) -> np.ndarray:
+    """Apply argmax along the last axis of the input.
+
+    This is the shared post-processing step for the toolbox models.
+
+    Args:
+        image (np.ndarray):
+            The input image array.
+
+    Returns:
+        np.ndarray:
+            The post-processed image array.
+ + """ + return np.argmax(image, axis=-1) diff --git a/tiatoolbox/models/architecture/vanilla.py b/tiatoolbox/models/architecture/vanilla.py index b471d026c..efc642a17 100644 --- a/tiatoolbox/models/architecture/vanilla.py +++ b/tiatoolbox/models/architecture/vanilla.py @@ -4,16 +4,17 @@ from typing import TYPE_CHECKING -import numpy as np import timm import torch import torchvision.models as torch_models from timm.layers import SwiGLUPacked from torch import nn +from tiatoolbox.models.architecture.utils import argmax_last_axis from tiatoolbox.models.models_abc import ModelABC if TYPE_CHECKING: # pragma: no cover + import numpy as np from torchvision.models import WeightsEnum @@ -205,23 +206,6 @@ def _get_timm_architecture( raise ValueError(msg) -def _postproc(image: np.ndarray) -> np.ndarray: - """Define the post-processing of this class of model. - - This simply applies argmax along last axis of the input. - - Args: - image (np.ndarray): - The input image array. - - Returns: - np.ndarray: - The post-processed image array. - - """ - return np.argmax(image, axis=-1) - - def _infer_batch( model: nn.Module, batch_data: torch.Tensor, @@ -339,7 +323,7 @@ def postproc(image: np.ndarray) -> np.ndarray: The post-processed image array. """ - return _postproc(image=image) + return argmax_last_axis(image=image) @staticmethod def infer_batch( @@ -463,7 +447,7 @@ def postproc(image: np.ndarray) -> np.ndarray: The post-processed image array. """ - return _postproc(image=image) + return argmax_last_axis(image=image) @staticmethod def infer_batch( diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index 045bb39b7..08dc32573 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -15,6 +15,7 @@ from tiatoolbox import logger from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils import imread +from tiatoolbox.utils.exceptions import DimensionMismatchError from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIMeta, WSIReader if TYPE_CHECKING: # pragma: no cover @@ -22,7 +23,7 @@ from multiprocessing.managers import Namespace from tiatoolbox.models.engine.io_config import IOSegmentorConfig - from tiatoolbox.typing import IntPair, Resolution, Units + from tiatoolbox.type_hints import IntPair, Resolution, Units try: from typing import TypeGuard @@ -572,6 +573,10 @@ class PatchDataset(PatchDatasetABC): labels (list): List of labels for sample at the same index in `inputs`. Default is `None`. + patch_input_shape (tuple): + Size of patches input to the model. Patches are at + requested read resolution, not with respect to level 0, + and must be positive. Examples: >>> # A user defined preproc func and expected behavior @@ -581,6 +586,7 @@ class PatchDataset(PatchDatasetABC): >>> ds = PatchDataset( ... inputs=['/A/B/C/img1.png', '/A/B/C/img2.png'], ... labels=["labels1", "labels2"], + ... patch_input_shape=(224, 224), ... 
)

    """

@@ -589,6 +595,7 @@ def __init__(
        self: PatchDataset,
        inputs: np.ndarray | list,
        labels: list | None = None,
+        patch_input_shape: IntPair | None = None,
    ) -> None:
        """Initialize :class:`PatchDataset`."""
        super().__init__()
@@ -597,6 +604,7 @@ def __init__(

        self.inputs = inputs
        self.labels = labels
+        self.patch_input_shape = patch_input_shape

        # perform check on the input
        self._check_input_integrity(mode="patch")
@@ -609,6 +617,20 @@ def __getitem__(self: PatchDataset, idx: int) -> dict:
        if not self.data_is_npy_alike:
            patch = self.load_img(patch)

+        if self.patch_input_shape is not None and patch.shape[:-1] != tuple(
+            self.patch_input_shape
+        ):
+            msg = (
+                f"Patch size is not compatible with the model. "
+                f"Expected dimensions {tuple(self.patch_input_shape)}, but got "
+                f"{patch.shape[:-1]}."
+            )
+            logger.error(msg=msg)
+            raise DimensionMismatchError(
+                expected_dims=tuple(self.patch_input_shape),
+                actual_dims=patch.shape[:-1],
+            )
+
        # Apply preprocessing to selected patch
        patch = self._preproc(patch)

diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py
index 6f6f9afab..e3ede3fe8 100644
--- a/tiatoolbox/models/engine/engine_abc.py
+++ b/tiatoolbox/models/engine/engine_abc.py
@@ -21,7 +21,7 @@
 from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset
 from tiatoolbox.models.models_abc import load_torch_model
 from tiatoolbox.utils.misc import (
-    dict_to_store,
+    dict_to_store_patch_predictions,
     dict_to_zarr,
     write_to_zarr_in_cache_mode,
 )
@@ -479,7 +479,10 @@ def get_dataloader(
            shuffle=False,
        )

-        dataset = PatchDataset(inputs=images, labels=labels)
+        dataset = PatchDataset(
+            inputs=images, labels=labels, patch_input_shape=ioconfig.patch_input_shape
+        )
+
        dataset.preproc_func = self.model.preproc_func

        # preprocessing must be defined with the dataset
@@ -630,7 +633,7 @@ def save_predictions(
        processed_predictions: dict | Path,
        output_type: str,
        save_dir: Path | None = None,
-        **kwargs: dict,
+        **kwargs: EngineABCRunParams,
    ) -> dict | AnnotationStore | Path:
        """Save model predictions.
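For reference, the dataset-side check added above means a shape mismatch now fails fast at indexing time instead of inside the model. A minimal sketch, mirroring `test_incorrect_input_shape` from this PR (array contents are arbitrary):

```python
import numpy as np

from tiatoolbox.models.dataset.dataset_abc import PatchDataset
from tiatoolbox.utils.exceptions import DimensionMismatchError

rng = np.random.default_rng()
patches = [rng.integers(low=0, high=255, size=(5, 5, 3)) for _ in range(3)]

# patch_input_shape now flows from the engine's ioconfig into the dataset.
dataset = PatchDataset(patches, patch_input_shape=(100, 100))
try:
    _ = dataset[0]
except DimensionMismatchError as err:
    print(err)  # Expected dimensions (100, 100), but got (5, 5).
```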
@@ -676,7 +679,7 @@ def save_predictions(
            processed_predictions_path = processed_predictions
            processed_predictions = zarr.open(processed_predictions, mode="r")

-        out_file = dict_to_store(
+        out_file = dict_to_store_patch_predictions(
            processed_predictions,
            scale_factor,
            class_dict,
@@ -971,6 +974,7 @@ def _run_patch_mode(
            masks=self.masks,
            labels=self.labels,
            patch_mode=True,
+            ioconfig=self._ioconfig,
        )
        raw_predictions = self.infer_patches(
            dataloader=dataloader,
diff --git a/tiatoolbox/models/engine/semantic_segmentor_new.py b/tiatoolbox/models/engine/semantic_segmentor_new.py
new file mode 100644
index 000000000..3ecee536d
--- /dev/null
+++ b/tiatoolbox/models/engine/semantic_segmentor_new.py
@@ -0,0 +1,494 @@
+"""Defines SemanticSegmentor Engine."""
+
+from __future__ import annotations
+
+import shutil
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import zarr
+from typing_extensions import Unpack
+
+from tiatoolbox.utils.misc import dict_to_store_semantic_segmentor, imwrite_large_tif
+
+from .patch_predictor import PatchPredictor, PredictorRunParams
+
+if TYPE_CHECKING:  # pragma: no cover
+    import os
+
+    import numpy as np
+
+    from tiatoolbox.annotation import AnnotationStore
+    from tiatoolbox.models.engine.io_config import IOSegmentorConfig
+    from tiatoolbox.models.models_abc import ModelABC
+    from tiatoolbox.type_hints import Resolution
+    from tiatoolbox.wsicore import WSIReader
+
+
+class SemanticSegmentorRunParams(PredictorRunParams):
+    """Class describing the input parameters for :func:`SemanticSegmentor.run()`.
+
+    Attributes:
+        batch_size (int):
+            Number of image patches to feed to the model in a forward pass.
+        cache_mode (bool):
+            Whether to run the Engine in cache_mode. For large datasets,
+            we recommend setting this to True to avoid out of memory errors.
+            For smaller datasets, the cache_mode is set to False as
+            the results can be saved in memory.
+        cache_size (int):
+            Specifies how many image patches to process in a batch when
+            cache_mode is set to True. If cache_size is less than the batch_size,
+            batch_size is set to cache_size.
+        class_dict (dict):
+            Optional dictionary mapping classification outputs to class names.
+        device (str):
+            Select the device to run the model. Please see
+            https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
+            for more details on input parameters for device.
+        ioconfig (ModelIOConfigABC):
+            Input IO configuration (:class:`ModelIOConfigABC`) to run the Engine.
+        return_labels (bool):
+            Whether to return the labels with the predictions.
+        num_loader_workers (int):
+            Number of workers used in :class:`torch.utils.data.DataLoader`.
+        num_post_proc_workers (int):
+            Number of workers to postprocess the results of the model.
+        output_file (str):
+            Output file name for saving "zarr" or "db" output. If None, the path
+            to the output is returned by the engine.
+        patch_input_shape (tuple):
+            Shape of patches input to the model as a tuple of height and width (HW).
+            Patches are at requested read resolution, not with respect to level 0,
+            and must be positive.
+        input_resolutions (Resolution):
+            Resolution used for reading the image. Please see
+            :class:`WSIReader` for details.
+        return_probabilities (bool):
+            Whether to return per-class probabilities.
+        scale_factor (tuple[float, float]):
+            The scale factor to use when loading the
+            annotations. All coordinates will be multiplied by this factor to allow
+            conversion of annotations saved at non-baseline resolution to baseline.
+            Should be model_mpp/slide_mpp.
+ stride_shape (tuple): + Stride used during WSI processing. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. + units (Units): + Units of resolution used for reading the image. Choose + from either `level`, `power` or `mpp`. Please see + :class:`WSIReader` for details. + verbose (bool): + Whether to output logging information. + + """ + + patch_output_shape: tuple + output_resolutions: Resolution + + +class SemanticSegmentor(PatchPredictor): + r"""Semantic Segmentor Engine for processing digital histology images. + + The tiatoolbox model should produce the following results on the BCSS dataset + using fcn_resnet50_unet-bcss. + + .. list-table:: Semantic segmentation performance on the BCSS dataset + :widths: 15 15 15 15 15 15 15 + :header-rows: 1 + + * - + - Tumour + - Stroma + - Inflammatory + - Necrosis + - Other + - All + * - Amgad et al. + - 0.851 + - 0.800 + - 0.712 + - 0.723 + - 0.666 + - 0.750 + * - TIAToolbox + - 0.885 + - 0.825 + - 0.761 + - 0.765 + - 0.581 + - 0.763 + + Args: + model (str | ModelABC): + A PyTorch model or name of pretrained model. + The user can request pretrained models from the toolbox model zoo using + the list of pretrained models available at this `link + `_ + By default, the corresponding pretrained weights will also + be downloaded. However, you can override with your own set + of weights using the `weights` parameter. Default is `None`. + batch_size (int): + Number of image patches fed into the model each time in a + forward/backward pass. Default value is 8. + num_loader_workers (int): + Number of workers to load the data using :class:`torch.utils.data.Dataset`. + Please note that they will also perform preprocessing. Default value is 0. + num_post_proc_workers (int): + Number of workers to postprocess the results of the model. + Default value is 0. + weights (str or Path): + Path to the weight of the corresponding `model`. + + >>> engine = SemanticSegmentor( + ... model="pretrained-model", + ... weights="/path/to/pretrained-local-weights.pth" + ... ) + + device (str): + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. Default is "cpu". + verbose (bool): + Whether to output logging information. Default value is False. + + Attributes: + images (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): + A list of image patches in NHWC format as a numpy array + or a list of str/paths to WSIs. + masks (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): + A list of tissue masks or binary masks corresponding to processing area of + input images. These can be a list of numpy arrays or paths to + the saved image masks. These are only utilized when patch_mode is False. + Patches are only generated within a masked area. + If not provided, then a tissue mask will be automatically + generated for whole slide images. + patch_mode (str): + Whether to treat input images as a set of image patches. TIAToolbox defines + an image as a patch if HWC of the input image matches with the HWC expected + by the model. If HWC of the input image does not match with the HWC expected + by the model, then the patch_mode must be set to False which will allow the + engine to extract patches from the input image. + In this case, when the patch_mode is False the input images are treated + as WSIs. Default value is True. 
+        model (str | ModelABC):
+            A PyTorch model or a name of an existing model from the TIAToolbox
+            model zoo for processing the data. For a full list of pretrained
+            models, refer to the `docs
+            `_
+            By default, the corresponding pretrained weights will also
+            be downloaded. However, you can override with your own set
+            of weights via the `weights` argument. Argument
+            is case-insensitive.
+        ioconfig (IOSegmentorConfig):
+            Input IO configuration of type :class:`IOSegmentorConfig` to run the
+            Engine.
+        _ioconfig (IOSegmentorConfig):
+            Runtime ioconfig.
+        patch_input_shape (tuple):
+            Shape of patches input to the model as a tuple of HW. Patches are at
+            requested read resolution, not with respect to level 0,
+            and must be positive.
+        stride_shape (tuple):
+            Stride used during WSI processing. Stride is
+            at requested read resolution, not with respect to
+            level 0, and must be positive. If not provided,
+            `stride_shape=patch_input_shape`.
+        batch_size (int):
+            Number of images fed into the model each time.
+        cache_mode (bool):
+            Whether to run the Engine in cache_mode. For large datasets,
+            we recommend setting this to True to avoid out of memory errors.
+            For smaller datasets, the cache_mode is set to False as
+            the results can be saved in memory. cache_mode is always True when
+            processing WSIs i.e., when `patch_mode` is False. Default value is False.
+        cache_size (int):
+            Specifies how many image patches to process in a batch when
+            cache_mode is set to True. If cache_size is less than the batch_size,
+            batch_size is set to cache_size. Default value is 10,000.
+        labels (list | None):
+            List of labels. Only a single label per image is supported.
+        device (str):
+            :class:`torch.device` to run the model.
+            Select the device to run the model. Please see
+            https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
+            for more details on input parameters for device. Default value is "cpu".
+        num_loader_workers (int):
+            Number of workers used in :class:`torch.utils.data.DataLoader`.
+        num_post_proc_workers (int):
+            Number of workers to postprocess the results of the model.
+        return_labels (bool):
+            Whether to return the output labels. Default value is False.
+        input_resolutions (Resolution):
+            Resolution used for reading the image. Please see
+            :class:`WSIReader` for details.
+            When `patch_mode` is True, the input image patches are expected to be at
+            the correct resolution and units. When `patch_mode` is False, the patches
+            are extracted at the requested resolution and units. Default value is 1.0.
+        units (Units):
+            Units of resolution used for reading the image. Choose
+            from either `baseline`, `level`, `power` or `mpp`. Please see
+            :class:`WSIReader` for details.
+            When `patch_mode` is True, the input image patches are expected to be at
+            the correct resolution and units. When `patch_mode` is False, the patches
+            are extracted at the requested resolution and units.
+            Default value is `baseline`.
+        verbose (bool):
+            Whether to output logging information. Default value is False.
+
+    Examples:
+        >>> # list of 2 WSIs as input
+        >>> wsis = ['path/img.svs', 'path/img.svs']
+        >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask")
+        >>> output = segmentor.run(wsis, patch_mode=False)
+
+        >>> # list of 2 image patches as numpy arrays
+        >>> image_patches = [np.ndarray, np.ndarray]
+        >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask")
+        >>> output = segmentor.run(image_patches, patch_mode=True)
+
+        >>> # list of 2 image patch files as input
+        >>> data = ['path/img.png', 'path/img.png']
+        >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask")
+        >>> output = segmentor.run(data, patch_mode=True)
+
+        >>> # list of 2 image tile files as input
+        >>> tile_file = ['path/tile1.png', 'path/tile2.png']
+        >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask")
+        >>> output = segmentor.run(tile_file, patch_mode=False)
+
+        >>> # list of 2 wsi files as input
+        >>> wsis = ['path/wsi1.svs', 'path/wsi2.svs']
+        >>> segmentor = SemanticSegmentor(model="fcn_resnet50_unet-bcss")
+        >>> output = segmentor.run(wsis, patch_mode=False)
+
+    References:
+        [1] Amgad M, Elfandy H, ..., Gutman DA, Cooper LAD. Structured crowdsourcing
+        enables convolutional segmentation of histology images. Bioinformatics 2019.
+        doi: 10.1093/bioinformatics/btz083
+
+    """
+
+    def __init__(
+        self: SemanticSegmentor,
+        model: str | ModelABC,
+        batch_size: int = 8,
+        num_loader_workers: int = 0,
+        num_post_proc_workers: int = 0,
+        weights: str | Path | None = None,
+        *,
+        device: str = "cpu",
+        verbose: bool = True,
+    ) -> None:
+        """Initialize :class:`SemanticSegmentor`."""
+        super().__init__(
+            model=model,
+            batch_size=batch_size,
+            num_loader_workers=num_loader_workers,
+            num_post_proc_workers=num_post_proc_workers,
+            weights=weights,
+            device=device,
+            verbose=verbose,
+        )
+
+    def save_predictions(
+        self: PatchPredictor,
+        processed_predictions: dict | Path,
+        output_type: str,
+        save_dir: Path | None = None,
+        **kwargs: SemanticSegmentorRunParams,
+    ) -> dict | AnnotationStore | Path | list[Path]:
+        """Save semantic segmentation predictions to disk.
+
+        Args:
+            processed_predictions (dict | Path):
+                A dictionary or path to zarr with model prediction information.
+            output_type (str):
+                The desired output type for the resulting patch dataset.
+            save_dir (Path):
+                Optional path to a directory to save the patch dataset output to a
+                `.zarr` or `.db` file, provided `patch_mode` is True. If
+                `patch_mode` is False then `save_dir` is required.
+            **kwargs (SemanticSegmentorRunParams):
+                Keyword Args required to save the output.
+
+        Returns:
+            dict or Path or :class:`AnnotationStore` or list[Path]:
+                If the `output_type` is "AnnotationStore", the function returns
+                a list of Paths to the `.db` file(s), plus the `.tif` probability
+                maps when `return_probabilities` is True. Otherwise, the function
+                defaults to returning the patch predictor output, either as a dict
+                or the Path to a `.zarr` file depending on whether a
+                save_dir Path is provided.
+
+        """
+        if (
+            self.cache_mode or not save_dir
+        ) and output_type.lower() != "annotationstore":
+            return processed_predictions
+
+        if output_type.lower() == "zarr":
+            return super().save_predictions(
+                processed_predictions, output_type, save_dir, **kwargs
+            )
+
+        save_path = Path(kwargs.get("output_file", save_dir))
+        return_probabilities = kwargs.get("return_probabilities", False)
+
+        # scale_factor set from kwargs
+        scale_factor = kwargs.get("scale_factor", (1.0, 1.0))
+        # class_dict set from kwargs
+        class_dict = kwargs.get("class_dict")
+
+        processed_predictions_path: str | Path | None = None
+
+        # Need to add support for zarr conversion.
+        if self.cache_mode:
+            processed_predictions_path = processed_predictions
+            processed_predictions = zarr.open(processed_predictions, mode="r")
+
+        save_paths = []
+
+        for i, predictions in enumerate(processed_predictions["predictions"]):
+            if isinstance(self.images[i], Path):
+                output_path = save_path / (self.images[i].stem + ".db")
+            else:
+                output_path = save_path / (str(i) + ".db")
+
+            out_file = dict_to_store_semantic_segmentor(
+                patch_output={"predictions": predictions},
+                scale_factor=scale_factor,
+                class_dict=class_dict,
+                save_path=output_path,
+            )
+
+            save_paths.append(out_file)
+
+            if return_probabilities:
+                imwrite_large_tif(
+                    image_path=output_path.with_suffix(".tif"),
+                    img=processed_predictions["probabilities"],
+                    tile_size=(
+                        self._ioconfig.patch_input_shape[0],
+                        self._ioconfig.patch_input_shape[1],
+                    ),
+                    index=i,
+                )
+                save_paths.append(output_path.with_suffix(".tif"))
+
+        if processed_predictions_path is not None:
+            shutil.rmtree(processed_predictions_path)
+
+        return save_paths
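An aside on consuming what `save_predictions` writes: for "annotationstore" output it returns one `.db` (plus, with `return_probabilities=True`, one `.tif`) per input, while zarr output is a group keyed by "predictions"/"probabilities". A minimal read-back sketch under those assumptions (paths are placeholders):

```python
from pathlib import Path

import zarr

from tiatoolbox.annotation import SQLiteStore

save_dir = Path("output1")  # placeholder save_dir passed to run()

# Zarr output: a group holding "predictions" (and optionally "probabilities").
group = zarr.open(str(save_dir / "output.zarr"), mode="r")
predictions = group["predictions"][:]  # per-pixel class indices

# AnnotationStore output: one .db per input image holding the mask contours.
store = SQLiteStore.open(save_dir / "0.db")
for annotation in store.values():
    print(annotation.geometry_type, annotation.properties)
```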
+
+    def run(
+        self: SemanticSegmentor,
+        images: list[os | Path | WSIReader] | np.ndarray,
+        masks: list[os | Path] | np.ndarray | None = None,
+        labels: list | None = None,
+        ioconfig: IOSegmentorConfig | None = None,
+        *,
+        patch_mode: bool = True,
+        save_dir: os | Path | None = None,  # None will not save output
+        overwrite: bool = False,
+        output_type: str = "dict",
+        **kwargs: Unpack[SemanticSegmentorRunParams],
+    ) -> AnnotationStore | Path | str | dict | list[Path]:
+        """Run the engine on input images.
+
+        Args:
+            images (list, ndarray):
+                List of inputs to process. When using `patch` mode, the
+                input must be either a list of images, a list of image
+                file paths or a numpy array of an image list.
+            masks (list | None):
+                List of masks. Only utilised when patch_mode is False.
+                Patches are only generated within a masked area.
+                If not provided, then a tissue mask will be automatically
+                generated for whole slide images.
+            labels (list | None):
+                List of labels. Only a single label per image is supported.
+            patch_mode (bool):
+                Whether to treat input image as a patch or WSI.
+                Default value is True.
+            ioconfig (IOSegmentorConfig):
+                IO configuration.
+            save_dir (str or pathlib.Path):
+                Output directory to save the results.
+                If save_dir is not provided when patch_mode is False,
+                then for a single image the output is created in the current
+                directory. If there are multiple WSIs as input then the user
+                must provide a path to the save directory, otherwise an OSError
+                will be raised.
+            overwrite (bool):
+                Whether to overwrite the results. Default value is False.
+            output_type (str):
+                The desired output format. "output_type" can be "dict",
+                "zarr" or "annotationstore". Default value is "dict".
+                When saving in the zarr format the output is saved using the
+                `python zarr library `__
+                as a zarr group. If the required output type is an "annotationstore"
+                then the output will be intermediately saved as zarr but converted
+                to :class:`AnnotationStore` and saved as a `.db` file
+                at the end of the loop.
+            **kwargs (SemanticSegmentorRunParams):
+                Keyword Args to update :class:`EngineABC` attributes during runtime.
+
+        Returns:
+            (:class:`numpy.ndarray`, dict, list):
+                Model predictions of the input dataset. If multiple
+                whole slide images are provided as input,
+                or `save_dir` is provided, then results are saved to
+                `save_dir` and a dictionary indicating the save location for
+                each input is returned.
+
+                The dict has the following format:
+
+                - img_path: path of the input image.
+                - raw: path to save location for raw prediction,
+                  saved in .json.
+                - list: List of image paths to the output files.
+
+        Examples:
+            >>> wsis = ['wsi1.svs', 'wsi2.svs']
+            >>> image_patches = [np.ndarray, np.ndarray]
+            >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask")
+            >>> output = segmentor.run(image_patches, patch_mode=True)
+            >>> output
+            ... "/path/to/Output.db"
+            >>> output = segmentor.run(
+            >>>     image_patches,
+            >>>     patch_mode=True,
+            >>>     output_type="zarr")
+            >>> output
+            ... "/path/to/Output.zarr"
+            >>> output = segmentor.run(wsis, patch_mode=False)
+            >>> output.keys()
+            ... ['wsi1.svs', 'wsi2.svs']
+            >>> output['wsi1.svs']
+            ... {'/path/to/wsi1.db'}
+
+        """
+        return super().run(
+            images=images,
+            masks=masks,
+            labels=labels,
+            ioconfig=ioconfig,
+            patch_mode=patch_mode,
+            save_dir=save_dir,
+            overwrite=overwrite,
+            output_type=output_type,
+            **kwargs,
+        )
diff --git a/tiatoolbox/utils/exceptions.py b/tiatoolbox/utils/exceptions.py
index db74af710..2f9f2a126 100644
--- a/tiatoolbox/utils/exceptions.py
+++ b/tiatoolbox/utils/exceptions.py
@@ -33,3 +33,23 @@ def __init__(
    ) -> None:
        """Initialize :class:`MethodNotSupportedError`."""
        super().__init__(message)
+
+
+class DimensionMismatchError(Exception):
+    """Raise dimension mismatch error.
+
+    Args:
+        expected_dims (list or tuple) : Expected dimensions.
+        actual_dims (list or tuple) : Actual dimensions.
+
+    """
+
+    def __init__(
+        self: DimensionMismatchError,
+        expected_dims: list | tuple,
+        actual_dims: list | tuple,
+    ) -> None:
+        """Initialize :class:`DimensionMismatchError`."""
+        self.expected_dims = expected_dims
+        self.actual_dims = actual_dims
+        super().__init__(f"Expected dimensions {expected_dims}, but got {actual_dims}.")
diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py
index 439760eeb..5acfd33a9 100644
--- a/tiatoolbox/utils/misc.py
+++ b/tiatoolbox/utils/misc.py
@@ -163,7 +163,7 @@ def imwrite(image_path: PathLike, img: np.ndarray) -> None:


 def imread(image_path: PathLike, as_uint8: bool | None = None) -> np.ndarray:
-    """Read an image as a NumPy array.
+    """Read an image as :class:`numpy.ndarray`.

    Args:
        image_path (PathLike):
@@ -1231,6 +1231,115 @@
    )

    return annotations


+def dict_to_store_semantic_segmentor(
+    patch_output: dict | zarr.group,
+    scale_factor: tuple[float, float],
+    class_dict: dict | None = None,
+    save_path: Path | None = None,
+) -> AnnotationStore | Path:
+    """Converts output of TIAToolbox SemanticSegmentor engine to AnnotationStore.
+
+    Args:
+        patch_output (dict | zarr.Group):
+            A dictionary with "probabilities", "predictions", and "labels" keys.
+        scale_factor (tuple[float, float]):
+            The scale factor to use when loading the
+            annotations. All coordinates will be multiplied by this factor to allow
+            conversion of annotations saved at non-baseline resolution to baseline.
+            Should be model_mpp/slide_mpp.
+        class_dict (dict):
+            Optional dictionary mapping class indices to class names.
+        save_path (str or Path):
+            Optional output path to save the Annotation
+            Store results.
+
+    Returns:
+        (SQLiteStore or Path):
+            An SQLiteStore containing Annotations for each patch
+            or Path to file storing SQLiteStore containing Annotations
+            for each patch.
+
+    """
+    preds = patch_output["predictions"]
+
+    # Get the unique classes present in the predictions
+    layer_list = np.unique(preds)
+
+    layer_list = np.delete(layer_list, np.where(layer_list == 0))
+
+    count = 1
+
+    store = SQLiteStore()
+
+    _ = class_dict  # use it once overlay is working
+
+    annotations_list = []
+
+    for type_class in layer_list:
+        layer = np.where(preds == type_class, 1, 0)
+        contours, _ = cv2.findContours(
+            layer.astype("uint8"),
+            cv2.RETR_TREE,
+            cv2.CHAIN_APPROX_NONE,
+        )
+        for layer_ in contours:
+            coords = layer_.squeeze()
+            count += 1
+
+            scaled_coords = np.array([scale_factor * coords])
+
+            # if more than two points, save the contour as a Polygon
+            if len(layer_) > 2:  # noqa: PLR2004
+                feature_geom = feature2geometry(
+                    {
+                        "type": "Polygon",
+                        "coordinates": scaled_coords,
+                    },
+                )
+                feature_geom = make_valid_poly(feature_geom)
+            # if two points, save as a line string
+            elif len(layer_) == 2:  # noqa: PLR2004
+                feature_geom = feature2geometry(
+                    {
+                        "type": "linestring",
+                        "coordinates": scaled_coords[0],
+                    },
+                )
+            # if a single point, save it as a point
+            else:
+                feature_geom = feature2geometry(
+                    {
+                        "type": "point",
+                        "coordinates": scaled_coords,
+                    },
+                )
+
+            annotations_list.extend(
+                [
+                    Annotation(
+                        geometry=feature_geom,
+                        properties={"type": "mask"},
+                    )
+                ]
+            )
+
+    _ = store.append_many(
+        annotations_list, [str(i) for i in range(len(annotations_list))]
+    )
+
+    # if a save directory is provided, then dump the store into a file
+    if save_path:
+        # ensure parent directory exists
+        save_path.parent.absolute().mkdir(parents=True, exist_ok=True)
+        # ensure proper db extension
+        save_path = save_path.parent.absolute() / (save_path.stem + ".db")
+        store.commit()
+        store.dump(save_path)
+        return save_path
+
+    return store
+
+
 def get_zarr_array(zarr_array: zarr.core.Array | np.ndarray | list) -> np.ndarray:
    """Converts a zarr array into a numpy array."""
    if isinstance(zarr_array, zarr.core.Array):
@@ -1239,13 +1348,13 @@
    return np.array(zarr_array).astype(float)


-def dict_to_store(
+def dict_to_store_patch_predictions(
    patch_output: dict | zarr.group,
    scale_factor: tuple[float, float],
    class_dict: dict | None = None,
    save_path: Path | None = None,
) -> AnnotationStore | Path:
-    """Converts (and optionally saves) output of TIAToolbox engines as AnnotationStore.
+    """Converts output of TIAToolbox PatchPredictor engine to AnnotationStore.

    Args:
        patch_output (dict | zarr.Group):
@@ -1461,21 +1570,21 @@ def dict_to_zarr(
    compressor = (
        kwargs["compressor"] if "compressor" in kwargs else numcodecs.Zstd(level=1)
    )
-    chunks = kwargs.get("chunks", 10000)

    # ensure proper zarr extension
    save_path = save_path.parent.absolute() / (save_path.stem + ".zarr")

-    # save to zarr
-    probabilities_array = np.array(raw_predictions["probabilities"])
-    z = zarr.open(
-        str(save_path),
-        mode="w",
-        shape=probabilities_array.shape,
-        chunks=chunks,
-        compressor=compressor,
-    )
-    z[:] = probabilities_array
+    z = zarr.open(str(save_path), mode="w")
+
+    for key, value in raw_predictions.items():
+        # save each output array to the zarr group
+        array = np.array(value)
+        z.create_dataset(
+            name=key,
+            data=array,
+            compressor=compressor,
+            shape=array.shape,
+        )

    return save_path
diff --git a/tiatoolbox/utils/transforms.py b/tiatoolbox/utils/transforms.py
index 8c2817b75..9f1f901c6 100644
--- a/tiatoolbox/utils/transforms.py
+++ b/tiatoolbox/utils/transforms.py
@@ -95,7 +95,7 @@ def imresize(
    img: np.ndarray,
    scale_factor: float | tuple[float, float] | None = None,
    output_size: int | tuple[int, int] | None = None,
-    interpolation: str = "optimise",
+    interpolation: str | int = "optimise",
) -> np.ndarray:
    """Resize input image.
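Putting the pieces together, a minimal end-to-end sketch of the new engine as exercised by `test_save_annotation_store_nparray` above (the image path and save directory are placeholders):

```python
from pathlib import Path

from tiatoolbox.models.engine.semantic_segmentor_new import SemanticSegmentor
from tiatoolbox.utils.misc import imread

segmentor = SemanticSegmentor(model="fcn-tissue_mask", batch_size=32)

image = imread("thumbnail.png")  # placeholder input image path

# "annotationstore" output writes one .db per input; with
# return_probabilities=True a tiled .tif of probabilities is written too.
output = segmentor.run(
    images=[image],
    patch_mode=True,
    cache_mode=True,
    save_dir=Path("segmentor_output"),  # placeholder directory
    output_type="annotationstore",
    return_probabilities=True,
)
print(output)  # e.g. [.../0.db, .../0.tif]
```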