Skip to content

✨ Define SemanticSegmentor with the New EngineABC #866

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 66 commits into
base: dev-define-engines-abc
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
193c587
:sparkles: Define `SemanticSegmentor` with the New `EngineABC`
shaneahmed Sep 20, 2024
113fb6b
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Nov 20, 2024
44e4b8c
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Nov 21, 2024
7e8d78b
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Nov 22, 2024
1ea6c80
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Dec 3, 2024
d0b86a5
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Jan 3, 2025
399827a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
333264a
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Jan 24, 2025
c84099b
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Jan 24, 2025
16a632f
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Feb 5, 2025
e03db92
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Feb 21, 2025
13cc94a
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Mar 3, 2025
dd83117
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Mar 4, 2025
967dba1
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Mar 7, 2025
40e79a1
:hammer: Add `run` method to SemanticSegmentor
shaneahmed Mar 7, 2025
85922ff
:memo: Update docstring
shaneahmed Mar 7, 2025
bda0581
:recycle: Refactor resolution to input_resolutions.
shaneahmed Mar 7, 2025
a643ea6
:recycle: Use `input_resolutions` instead of resolution
shaneahmed Mar 7, 2025
7eed649
:recycle: Use `input_resolutions` instead of resolution
shaneahmed Mar 7, 2025
3dff881
:white_check_mark: Add test to cli.
shaneahmed Mar 7, 2025
52ed249
Merge branch 'dev-use-input-resolutions' into dev-define-semantic-seg…
shaneahmed Mar 7, 2025
03e07e6
:white_check_mark: Add SemanticSegmentor patch_mode test.
shaneahmed Mar 7, 2025
a37e71d
:bug: Fix `unet` architecture
shaneahmed Mar 8, 2025
94a747f
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Mar 8, 2025
987927d
Merge branch 'dev-define-engines-abc' into dev-use-input-resolutions
shaneahmed Mar 8, 2025
92bc813
:bug: Fix `test_datset` architecture
shaneahmed Mar 8, 2025
bddd956
Merge branch 'dev-use-input-resolutions' into dev-define-semantic-seg…
shaneahmed Mar 8, 2025
facf461
:white_check_mark: Add postproc to segmentation.
shaneahmed Mar 12, 2025
2b342f4
Merge remote-tracking branch 'origin/dev-define-semantic-segmentor' i…
shaneahmed Mar 12, 2025
6f9d412
:recycle: Move argmax postprocessing to utils.
shaneahmed Mar 12, 2025
91a45b3
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Mar 14, 2025
8a95948
:white_check_mark: Check for cache_mode with zarr output.
shaneahmed Mar 17, 2025
cf5b50e
:white_check_mark: Update script for annotation store
shaneahmed Mar 18, 2025
816a568
:construction: Update script for annotation store
shaneahmed Mar 18, 2025
2f88cb7
:construction: Update script for annotation store
shaneahmed Mar 19, 2025
91d134d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2025
9204604
:bug: Fix saving annotation properties
shaneahmed Mar 19, 2025
84fdd2f
:bug: Fix TID252 Prefer absolute imports over relative imports
shaneahmed Mar 19, 2025
16e3f91
:bug: Test mask to store from #918
shaneahmed Mar 21, 2025
b7862c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 21, 2025
a495daf
:bug: Fix AnnotationStore display
shaneahmed Mar 28, 2025
355af0d
:fire: Remove unnecessary test
shaneahmed Mar 31, 2025
9e30cb9
:wastebasket: Clean up code for saving to AnnotationStore
shaneahmed Mar 31, 2025
d9e8f92
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Mar 31, 2025
9011a8f
:bug: Fix mypy checks
shaneahmed Apr 3, 2025
9e1c218
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Apr 3, 2025
136f9f2
:memo: Add comments
shaneahmed Apr 9, 2025
854e2b8
:white_check_mark: Update single and multipoint behaviour
shaneahmed Apr 10, 2025
fb56920
:white_check_mark: Add tests for correct annotation type
shaneahmed Apr 10, 2025
221e53b
:stethoscope: Add checks for correct patch size.
shaneahmed Apr 10, 2025
f69ac89
:white_check_mark: Add tests for incorrect image patch input.
shaneahmed Apr 10, 2025
e9ed0a9
:bug: Fix tests for io config delegation.
shaneahmed Apr 10, 2025
c1b06d5
:package: Add sample image for semantic_segmentor tests.
shaneahmed Apr 10, 2025
211d39d
:bug: Fix conversion for single and two points annotations.
shaneahmed Apr 10, 2025
f2b4678
:bug: Fix semantic segmentor test
shaneahmed Apr 10, 2025
388bbc6
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Apr 10, 2025
56e195c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2025
c904626
:bug: Fix `ruff` errors.
shaneahmed Apr 10, 2025
fae76f8
Merge remote-tracking branch 'origin/dev-define-semantic-segmentor' i…
shaneahmed Apr 10, 2025
7ce679c
:bug: Fix offset values
shaneahmed Apr 10, 2025
f613cef
:bug: Fix tests
shaneahmed Apr 10, 2025
45a5737
:white_check_mark: Add tests for input numpy array
shaneahmed Apr 11, 2025
ab4e93b
:white_check_mark: Add tests to improve coverage.
shaneahmed Apr 11, 2025
a3174dc
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Apr 25, 2025
322d32e
:memo: Fix typos
shaneahmed Apr 25, 2025
fc89d81
:sparkles: Save `probabilities` to `tiff`
shaneahmed Apr 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ def sample_wsi_dict(remote_sample: Callable) -> dict:
"wsi4_4k_4k_svs",
"wsi3_20k_20k_pred",
"wsi4_4k_4k_pred",
"wsi4_1k_1k_svs",
]
return {name: remote_sample(name) for name in file_names}

Expand Down
14 changes: 7 additions & 7 deletions tests/engines/test_engine_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
eng = TestEngineABC(model=model)

kwargs = {
"patch_input_shape": [512, 512],
"patch_input_shape": [224, 224],
"input_resolutions": [{"units": "mpp", "resolution": 1.75}],
}
with caplog.at_level(logging.WARNING):
Expand All @@ -536,7 +536,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)

# test providing config / full input info for non pretrained models
ioconfig = ModelIOConfigABC(
patch_input_shape=(512, 512),
patch_input_shape=(224, 224),
stride_shape=(256, 256),
input_resolutions=[{"resolution": 1.35, "units": "mpp"}],
)
Expand All @@ -546,7 +546,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
save_dir=f"{tmp_path}/dump",
ioconfig=ioconfig,
)
assert eng._ioconfig.patch_input_shape == (512, 512)
assert eng._ioconfig.patch_input_shape == (224, 224)
assert eng._ioconfig.stride_shape == (256, 256)
assert eng._ioconfig.input_resolutions == [{"resolution": 1.35, "units": "mpp"}]
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
Expand All @@ -557,15 +557,15 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
save_dir=f"{tmp_path}/dump",
**kwargs,
)
assert eng._ioconfig.patch_input_shape == [512, 512]
assert eng._ioconfig.stride_shape == [512, 512]
assert eng._ioconfig.patch_input_shape == [224, 224]
assert eng._ioconfig.stride_shape == [224, 224]
assert eng._ioconfig.input_resolutions == [{"resolution": 1.75, "units": "mpp"}]
shutil.rmtree(tmp_path / "dump", ignore_errors=True)

# test overwriting pretrained ioconfig
eng = TestEngineABC(model="alexnet-kather100k")
eng.run(
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
images=np.zeros((10, 300, 300, 3), dtype=np.uint8),
patch_input_shape=(300, 300),
stride_shape=(300, 300),
input_resolutions=[{"units": "baseline", "resolution": 1.99}],
Expand All @@ -579,7 +579,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
shutil.rmtree(tmp_path / "dump", ignore_errors=True)

eng.run(
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
images=np.zeros((10, 300, 300, 3), dtype=np.uint8),
patch_input_shape=(300, 300),
stride_shape=(300, 300),
input_resolutions=None,
Expand Down
197 changes: 197 additions & 0 deletions tests/engines/test_semantic_segmentor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""Test SemanticSegmentor."""

from __future__ import annotations

import json
import sqlite3
from typing import TYPE_CHECKING, Callable

import numpy as np
import torch
import zarr

from tiatoolbox.annotation import SQLiteStore
from tiatoolbox.models.engine.semantic_segmentor_new import SemanticSegmentor
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.utils.misc import imread

if TYPE_CHECKING:
from pathlib import Path

device = "cuda" if toolbox_env.has_gpu() else "cpu"


def test_semantic_segmentor_init() -> None:
"""Tests SemanticSegmentor initialization."""
segmentor = SemanticSegmentor(model="fcn-tissue_mask", device=device)

assert isinstance(segmentor, SemanticSegmentor)
assert isinstance(segmentor.model, torch.nn.Module)


def test_semantic_segmentor_patches(remote_sample: Callable, tmp_path: Path) -> None:
"""Tests SemanticSegmentor on image patches."""
segmentor = SemanticSegmentor(
model="fcn-tissue_mask", batch_size=32, verbose=False, device=device
)

sample_image = remote_sample("thumbnail-1k-1k")

inputs = [sample_image, sample_image]

assert segmentor.cache_mode is False

output = segmentor.run(
images=inputs,
return_probabilities=True,
return_labels=False,
device=device,
patch_mode=True,
)

assert 0.15 < np.mean(output["predictions"][:]) < 0.18
assert 0.120 < np.mean(output["probabilities"][:]) < 0.130

assert (
tuple(segmentor._ioconfig.patch_input_shape)
== output["probabilities"][0].shape[:-1]
)

assert (
tuple(segmentor._ioconfig.patch_input_shape) == output["predictions"][0].shape
)

output = segmentor.run(
images=inputs,
return_probabilities=True,
return_labels=False,
device=device,
patch_mode=True,
cache_mode=True,
save_dir=tmp_path / "output0",
)

assert output == tmp_path / "output0" / "output.zarr"

output = zarr.open(output, mode="r")
assert 0.15 < np.mean(output["predictions"][:]) < 0.18
assert 0.120 < np.mean(output["probabilities"][:]) < 0.130

output = segmentor.run(
images=inputs,
return_probabilities=False,
return_labels=False,
device=device,
patch_mode=True,
cache_mode=True,
output_type="zarr",
save_dir=tmp_path / "output1",
)

assert output == tmp_path / "output1" / "output.zarr"

output = zarr.open(output, mode="r")
assert 0.15 < np.mean(output["predictions"][:]) < 0.18
assert "probabilities" not in output.keys() # noqa: SIM118

output = segmentor.run(
images=inputs,
return_probabilities=False,
return_labels=False,
device=device,
patch_mode=True,
cache_mode=False,
save_dir=tmp_path / "output2",
output_type="zarr",
)

assert output == tmp_path / "output2" / "output.zarr"

output = zarr.open(output, mode="r")
assert 0.15 < np.mean(output["predictions"][:]) < 0.18
assert "probabilities" not in output
assert "predictions" in output


def _test_store_output_patch(output: Path) -> None:
"""Helper method to test annotation store output for a patch."""
store_ = SQLiteStore.open(output)
annotations_ = store_.values()
annotations_geometry_type = [
str(annotation_.geometry_type) for annotation_ in annotations_
]
assert "Polygon" in annotations_geometry_type

con = sqlite3.connect(output)
cur = con.cursor()
annotations_properties = list(cur.execute("SELECT properties FROM annotations"))

out = []

for item in annotations_properties:
for json_str in item:
probs = json.loads(json_str)
if "type" in probs:
out.append(probs.pop("type"))

assert "mask" in out

assert annotations_properties is not None


def test_save_annotation_store(remote_sample: Callable, tmp_path: Path) -> None:
"""Test for saving output as annotation store."""
segmentor = SemanticSegmentor(
model="fcn-tissue_mask", batch_size=32, verbose=False, device=device
)

sample_image = remote_sample("thumbnail-1k-1k")

inputs = [sample_image]

output = segmentor.run(
images=inputs,
return_probabilities=False,
return_labels=False,
device=device,
patch_mode=True,
cache_mode=False,
save_dir=tmp_path / "output1",
output_type="annotationstore",
)

assert output[0] == tmp_path / "output1" / (sample_image.stem + ".db")
assert len(output) == 1
_test_store_output_patch(output[0])


def test_save_annotation_store_nparray(remote_sample: Callable, tmp_path: Path) -> None:
"""Test for saving output as annotation store using a numpy array."""
segmentor = SemanticSegmentor(
model="fcn-tissue_mask", batch_size=32, verbose=False, device=device
)

sample_image = remote_sample("thumbnail-1k-1k")

input_image = imread(sample_image)
inputs_list = [input_image, input_image]

output = segmentor.run(
images=inputs_list,
return_probabilities=True,
return_labels=False,
device=device,
patch_mode=True,
cache_mode=True,
save_dir=tmp_path / "output1",
output_type="annotationstore",
)

assert output[0] == tmp_path / "output1" / "0.db"
assert output[2] == tmp_path / "output1" / "1.db"

assert output[1] == tmp_path / "output1" / "0.tif"
assert output[3] == tmp_path / "output1" / "1.tif"

_test_store_output_patch(output[0])
_test_store_output_patch(output[2])
2 changes: 1 addition & 1 deletion tests/models/test_arch_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_functional_unet(remote_sample: Callable) -> None:
pretrained = torch.load(pretrained_weights, map_location="cpu")
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
_ = output[0]
_ = output["probabilities"][0]

# run untrained network to test for architecture
model = UNetModel(
Expand Down
29 changes: 23 additions & 6 deletions tests/models/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from tiatoolbox.utils import download_data, imread, imwrite, unzip_data
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.utils.exceptions import DimensionMismatchError
from tiatoolbox.wsicore import WSIReader

RNG = np.random.default_rng() # Numpy Random Generator
Expand Down Expand Up @@ -120,7 +121,9 @@ def test_kather_dataset(tmp_path: Path) -> None:
assert len(dataset.inputs) == len(dataset.labels)

# to actually get the image, we feed it to PatchDataset
actual_ds = PatchDataset(dataset.inputs, dataset.labels)
actual_ds = PatchDataset(
dataset.inputs, dataset.labels, patch_input_shape=(224, 224)
)
sample_patch = actual_ds[89]
assert isinstance(sample_patch["image"], np.ndarray)
assert sample_patch["label"] is not None
Expand All @@ -129,14 +132,28 @@ def test_kather_dataset(tmp_path: Path) -> None:
shutil.rmtree(save_dir_path, ignore_errors=True)


def test_incorrect_input_shape() -> None:
"""Incorrect input patch dimensions should raise DimensionMismatchError."""
size = (5, 5, 3)
img = RNG.integers(low=0, high=255, size=size)
list_imgs = [img, img, img]
dataset = PatchDataset(list_imgs, patch_input_shape=(100, 100))
with pytest.raises(
DimensionMismatchError, match=r".*\(100, 100\), but got \(5, 5\).*"
):
_ = dataset[0]


def test_patch_dataset_path_imgs(
sample_patch1: str | Path,
sample_patch2: str | Path,
) -> None:
"""Test for patch dataset with a list of file paths as input."""
size = (224, 224, 3)

dataset = PatchDataset([Path(sample_patch1), Path(sample_patch2)])
dataset = PatchDataset(
[Path(sample_patch1), Path(sample_patch2)], patch_input_shape=size[:-1]
)

for _, sample_data in enumerate(dataset):
sampled_img_shape = sample_data["image"].shape
Expand All @@ -152,7 +169,7 @@ def test_patch_dataset_list_imgs(tmp_path: Path) -> None:
size = (5, 5, 3)
img = RNG.integers(low=0, high=255, size=size)
list_imgs = [img, img, img]
dataset = PatchDataset(list_imgs)
dataset = PatchDataset(list_imgs, patch_input_shape=size[:-1])

dataset.preproc_func = lambda x: x

Expand Down Expand Up @@ -197,14 +214,14 @@ def test_patch_datasetarray_imgs() -> None:
array_imgs = np.array(list_imgs)

# test different setter for label
dataset = PatchDataset(array_imgs, labels=labels)
dataset = PatchDataset(array_imgs, labels=labels, patch_input_shape=(5, 5))
an_item = dataset[2]
assert an_item["label"] == 3
dataset = PatchDataset(array_imgs, labels=None)
dataset = PatchDataset(array_imgs, labels=None, patch_input_shape=(5, 5))
an_item = dataset[2]
assert "label" not in an_item

dataset = PatchDataset(array_imgs)
dataset = PatchDataset(array_imgs, patch_input_shape=size[:-1])
for _, sample_data in enumerate(dataset):
sampled_img_shape = sample_data["image"].shape
assert sampled_img_shape[0] == size[0]
Expand Down
Loading