Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Cellpose segmentation with masking ROI tables & relabeling #786

Merged
merged 5 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
**Note**: Numbers like (\#123) point to closed Pull Requests on the fractal-tasks-core repository.

# Unreleased

* Tasks:
* Fix issue with masked ROI & relabeling in Cellpose task (\#785).
* Fix issue with masking ROI label types in masked_loading_wrapper for Cellpose task (\#785).

# 1.1.0

> NOTE: Starting from this release, `fractal-tasks-core` can coexist
Expand Down
2 changes: 1 addition & 1 deletion fractal_tasks_core/masked_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _preprocess_input(
'In _preprocess_input, "{column_name}" '
f" missing in {ROI_table.obs.columns=}"
)
label_value = int(ROI_table.obs[column_name][ROI_positional_index])
label_value = int(float(ROI_table.obs[column_name][ROI_positional_index]))

# Load masking-label array (lazily)
masking_label_path = str(
Expand Down
46 changes: 27 additions & 19 deletions fractal_tasks_core/tasks/cellpose_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,20 +73,26 @@

def segment_ROI(
x: np.ndarray,
num_labels_tot: dict[str, int],
model: models.CellposeModel = None,
do_3D: bool = True,
channels: list[int] = [0, 0],
diameter: float = 30.0,
normalize: CellposeCustomNormalizer = CellposeCustomNormalizer(),
normalize2: Optional[CellposeCustomNormalizer] = None,
label_dtype: Optional[np.dtype] = None,
relabeling: bool = True,
advanced_cellpose_model_params: CellposeModelParams = CellposeModelParams(), # noqa: E501
) -> np.ndarray:
"""
Internal function that runs Cellpose segmentation for a single ROI.

Args:
x: 4D numpy array.
num_labels_tot: Number of labels already in total image. Used for
relabeling purposes. Using a dict to have a mutable object that
can be edited from within the function without having to be passed
back through the masked_loading_wrapper.
model: An instance of `models.CellposeModel`.
do_3D: If `True`, cellpose runs in 3D mode: runs on xy, xz & yz planes,
then averages the flows.
Expand All @@ -107,6 +113,7 @@
normalized with default settings, both channels need to be
normalized with default settings.
label_dtype: Label images are cast into this `np.dtype`.
relabeling: Whether relabeling based on num_labels_tot is performed.
advanced_cellpose_model_params: Advanced Cellpose model parameters
that are passed to the Cellpose `model.eval` method.
"""
Expand Down Expand Up @@ -163,6 +170,23 @@
f" {advanced_cellpose_model_params.flow_threshold=}"
)

# Shift labels and update relabeling counters
if relabeling:
num_labels_roi = np.max(mask)
mask[mask > 0] += num_labels_tot["num_labels_tot"]
num_labels_tot["num_labels_tot"] += num_labels_roi

# Write some logs
logger.info(f"ROI had {num_labels_roi=}, {num_labels_tot=}")

# Check that total number of labels is under control
if num_labels_tot["num_labels_tot"] > np.iinfo(label_dtype).max:
raise ValueError(

Check notice on line 184 in fractal_tasks_core/tasks/cellpose_segmentation.py

View workflow job for this annotation

GitHub Actions / Coverage

Missing coverage

Missing coverage on line 184
"ERROR in re-labeling:"
f"Reached {num_labels_tot} labels, "
f"but dtype={label_dtype}"
)

return mask.astype(label_dtype)


Expand Down Expand Up @@ -438,8 +462,7 @@
logger.info(f"{data_zyx_c2.chunks}")

# Counters for relabeling
if relabeling:
num_labels_tot = 0
num_labels_tot = {"num_labels_tot": 0}

# Iterate over ROIs
num_ROIs = len(list_indices)
Expand Down Expand Up @@ -485,13 +508,15 @@

# Prepare keyword arguments for segment_ROI function
kwargs_segment_ROI = dict(
num_labels_tot=num_labels_tot,
model=model,
channels=channels,
do_3D=do_3D,
label_dtype=label_dtype,
diameter=diameter_level0 / coarsening_xy**level,
normalize=channel.normalize,
normalize2=channel2.normalize,
relabeling=relabeling,
advanced_cellpose_model_params=advanced_cellpose_model_params,
)

Expand All @@ -515,23 +540,6 @@
preprocessing_kwargs=preprocessing_kwargs,
)

# Shift labels and update relabeling counters
if relabeling:
num_labels_roi = np.max(new_label_img)
new_label_img[new_label_img > 0] += num_labels_tot
num_labels_tot += num_labels_roi

# Write some logs
logger.info(f"ROI {indices}, {num_labels_roi=}, {num_labels_tot=}")

# Check that total number of labels is under control
if num_labels_tot > np.iinfo(label_dtype).max:
raise ValueError(
"ERROR in re-labeling:"
f"Reached {num_labels_tot} labels, "
f"but dtype={label_dtype}"
)

if output_ROI_table:
bbox_df = array_to_bounding_box_table(
new_label_img,
Expand Down
108 changes: 108 additions & 0 deletions tests/tasks/test_workflows_cellpose_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,19 @@ def patched_segment_ROI_overlapping_organoids(
return mask.astype(label_dtype)


def patched_cellpose_eval(self, x, **kwargs):
assert x.ndim == 4
# Actual labeling: segment_ROI returns a 3D mask with the same shape as x,
# except for the first dimension
mask = np.zeros_like(x[0, :, :, :])
nz, ny, nx = mask.shape
indices = np.arange(0, nx // 2)
mask[:, indices, indices] = 1 # noqa
mask[:, indices + 10, indices + 20] = 2 # noqa

return mask, 0, 0


def patched_cellpose_core_use_gpu(*args, **kwargs):
debug("WARNING: using patched_cellpose_core_use_gpu")
return False
Expand Down Expand Up @@ -616,6 +629,101 @@ def test_workflow_bounding_box_with_overlap(
assert "bounding-box pairs overlap" in caplog.text


def test_cellpose_within_masked_bb_with_overlap(
tmp_path: Path,
zenodo_zarr: list[str],
caplog: pytest.LogCaptureFixture,
monkeypatch: MonkeyPatch,
):
"""
Test to address #785: Segmenting objects within a masking ROI table and
ensuring that the relabeling works well with the masking.
"""

monkeypatch.setattr(
"fractal_tasks_core.tasks.cellpose_segmentation.cellpose.core.use_gpu",
patched_cellpose_core_use_gpu,
)

from cellpose import models

monkeypatch.setattr(
models.CellposeModel,
"eval",
patched_cellpose_eval,
)

# Use pre-made 3D zarr
zarr_dir = tmp_path / "tmp_out/"
zarr_urls = prepare_3D_zarr(str(zarr_dir), zenodo_zarr)
debug(zarr_dir)
debug(zarr_urls)

# Per-FOV labeling
channel = CellposeChannel1InputModel(
wavelength_id="A01_C01", normalize=CellposeCustomNormalizer()
)
for zarr_url in zarr_urls:
cellpose_segmentation(
zarr_url=zarr_url,
channel=channel,
level=3,
relabeling=True,
diameter_level0=80.0,
output_label_name="initial_segmentation",
output_ROI_table="bbox_table",
)

# Assert that 4 unique labels + background are present in the
# initial_segmentation
import dask.array as da

initial_segmentation = da.from_zarr(
f"{zarr_urls[0]}/labels/initial_segmentation/0"
).compute()
assert len(np.unique(initial_segmentation)) == 5
assert np.max(initial_segmentation) == 4

# Segment objects within the bbox_table masked, ensure the relabeling
# happens correctly
for zarr_url in zarr_urls:
cellpose_segmentation(
zarr_url=zarr_url,
channel=channel,
level=3,
relabeling=True,
diameter_level0=80.0,
input_ROI_table="bbox_table",
output_label_name="secondary_segmentation",
output_ROI_table="secondary_ROI_table",
)
# Check labels in secondary_segmentation: Verify correctness
# Our monkeypatched segmentation returns 2 labels, only 1 of which is
# within the mask => should be 1 segmentation output per initial object
# If relabeling works correctly, there will be 4 objects in
# secondary_segmentation and they will be 1, 2, 3, 4
secondary_segmentation = da.from_zarr(
f"{zarr_urls[0]}/labels/secondary_segmentation/0"
).compute()
assert len(np.unique(secondary_segmentation)) == 5
# Current approach doesn't 100% guarantee consecutive labels. Within the
# cellpose task, there could be labels that are taken into account for
# relabeling but are masked away afterwards. That's very unlikely to be an
# issue, because we set the background to 0 in the input image. But the
# mock testing ignores the input
# assert np.max(secondary_segmentation) == 4

# Ensure that labels stay in correct proportions => relabeling doesn't do
# reassignment
label1 = np.sum(secondary_segmentation == 1)
label3 = np.sum(secondary_segmentation == 3)
label5 = np.sum(secondary_segmentation == 5)
label7 = np.sum(secondary_segmentation == 7)
assert label1 == label3
assert label1 == label5
assert label1 == label7


def test_workflow_with_per_FOV_labeling_via_script(
tmp_path: Path,
zenodo_zarr: list[str],
Expand Down
Loading