Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Refactor image utility functions and enhance documentation #154

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion config/mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
files =
src/imgtools/logging/**/*.py,
src/imgtools/dicom/**/*.py,
src/imgtools/cli/**/*.py
src/imgtools/cli/**/*.py,
src/imgtools/modules/**/*.py,

# Exclude files from analysis
exclude = tests,
Expand Down
12 changes: 5 additions & 7 deletions config/ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
# slowly fix everything

include = [
"src/imgtools/logging/**/*.py",
# "src/imgtools/cli/**/*.py",
"src/imgtools/logging/**/*.py",
"src/imgtools/modules/segmentation.py",
"src/imgtools/dicom/**/*.py",
# "src/imgtools/utils/crawl.py",
]
Expand All @@ -15,16 +15,13 @@ extend-exclude = [
"tests/**/*.py",
"src/imgtools/ops/ops.py",
"src/imgtools/io/**/*.py",
"src/imgtools/modules/**/*.py",
"src/imgtools/transforms/**/*.py",
"src/imgtools/autopipeline.py",
"src/imgtools/pipeline.py",
"src/imgtools/image.py",
]

extend-include = [
"src/imgtools/ops/functional.py",
]
extend-include = ["src/imgtools/ops/functional.py"]


line-length = 100
Expand Down Expand Up @@ -105,7 +102,8 @@ ignore = [
# Ignored because https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules
"COM812", # https://docs.astral.sh/ruff/rules/missing-trailing-comma/#missing-trailing-comma-com812
"D206",
"N813",
"N813",
"EM101",
]
[lint.pydocstyle]
convention = "numpy"
Expand Down
188 changes: 163 additions & 25 deletions src/imgtools/modules/segmentation.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,175 @@
from functools import wraps
"""Manage and manipulate segmentation masks with multi-label support.

This module provides the `Segmentation` class and associated utilities for working
with medical image segmentation masks.
It extends the functionality of `SimpleITK.Image` to include ROI-specific operations,
label management, and metadata tracking.

Classes
-------
Segmentation
A specialized class for handling multi-label segmentation masks. Includes
functionality for extracting individual labels, resolving overlaps, and
integrating with DICOM SEG metadata.

Functions
---------
accepts_segmentations(f)
A decorator to ensure functions working on images handle `Segmentation` objects
correctly by preserving metadata and ROI labels.

map_over_labels(segmentation, f, include_background=False, return_segmentation=True, **kwargs)
Applies a function to each label in a segmentation mask and combines the results,
optionally returning a new `Segmentation` object.

Notes
-----
- The `Segmentation` class tracks metadata and ROI names, enabling easier management
of multi-label segmentation masks.
- The `generate_sparse_mask` method resolves overlapping contours by taking the
maximum label value for each voxel, ensuring a consistent sparse representation.
- Integration with DICOM SEG metadata is supported through the `from_dicom_seg`
class method, which creates `Segmentation` objects from DICOM SEG files.

Examples
--------
# Creating a Segmentation object from a SimpleITK.Image
>>> seg = Segmentation(image, roi_indices={'GTV': 1, 'PTV': 2})

# Extracting an individual label
>>> gtv_mask = seg.get_label(name='GTV')

# Generating a sparse mask
>>> sparse_mask = seg.generate_sparse_mask(verbose=True)

# Applying a function to each label in the segmentation
>>> def compute_statistics(label_image):
>>> return sitk.LabelStatisticsImageFilter().Execute(label_image)

>>> stats = map_over_labels(segmentation=seg, f=compute_statistics)
"""

from __future__ import annotations

import warnings
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
Comment on lines +52 to +56
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Remove unused Dict import.

Static analysis indicates that typing.Dict is never referenced. Removing unused imports helps keep the code clean and reduces confusion.

- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+ from typing import Any, Callable, List, Optional, Tuple, Union
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
from __future__ import annotations
import warnings
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from __future__ import annotations
import warnings
from functools import wraps
from typing import Any, Callable, List, Optional, Tuple, Union
🧰 Tools
🪛 Ruff (0.8.2)

56-56: typing.Dict imported but unused

Remove unused import: typing.Dict

(F401)


import numpy as np
import SimpleITK as sitk

from imgtools.utils import array_to_image, image_to_array

from .sparsemask import SparseMask

from ..utils import array_to_image, image_to_array
from typing import Optional, Tuple, Set

def accepts_segmentations(f: Callable) -> Callable:
"""A decorator that ensures functions can handle `Segmentation` objects correctly.

def accepts_segmentations(f):
@wraps(f)
def wrapper(img, *args, **kwargs):
result = f(img, *args, **kwargs)
if isinstance(img, Segmentation):
result = sitk.Cast(result, sitk.sitkVectorUInt8)
return Segmentation(result, roi_indices=img.roi_indices, raw_roi_names=img.raw_roi_names)
else:
return result
return wrapper


def map_over_labels(segmentation, f, include_background=False, return_segmentation=True, **kwargs):
if include_background:
labels = range(segmentation.num_labels + 1)
else:
labels = range(1, segmentation.num_labels + 1)
res = [f(segmentation.get_label(label=label), **kwargs) for label in labels]
if return_segmentation and isinstance(res[0], sitk.Image):
res = [sitk.Cast(r, sitk.sitkUInt8) for r in res]
res = Segmentation(sitk.Compose(*res), roi_indices=segmentation.roi_indices, raw_roi_names=segmentation.raw_roi_names)
return res
If the input image is an instance of `Segmentation`, the decorator preserves
the ROI indices and raw ROI names in the output.

This is useful when using functions that process images without losing metadata
for the Segmentation class.

Parameters
----------
f : Callable
The function to wrap, which processes an image.

Returns
-------
Callable
A wrapped function that preserves `Segmentation` metadata if the input
is a `Segmentation` object.

Examples
--------
>>> @accepts_segmentations
... def some_processing_function(img, *args, **kwargs):
... return img # Perform some operation on the image
>>> segmentation = Segmentation(image, roi_indices={'ROI1': 1, 'ROI2': 2})
>>> result = some_processing_function(segmentation)
>>> isinstance(result, Segmentation)
True
"""

@wraps(f)
def wrapper(

Check warning on line 98 in src/imgtools/modules/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/imgtools/modules/segmentation.py#L97-L98

Added lines #L97 - L98 were not covered by tests
img: Union[sitk.Image, Segmentation],
*args: Any, # noqa
**kwargs: Any, # noqa
) -> Union[sitk.Image, Segmentation]:
result = f(img, *args, **kwargs)
if isinstance(img, Segmentation):
result = sitk.Cast(result, sitk.sitkVectorUInt8)
return Segmentation(

Check warning on line 106 in src/imgtools/modules/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/imgtools/modules/segmentation.py#L103-L106

Added lines #L103 - L106 were not covered by tests
result, roi_indices=img.roi_indices, raw_roi_names=img.raw_roi_names
)
return result

Check warning on line 109 in src/imgtools/modules/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/imgtools/modules/segmentation.py#L109

Added line #L109 was not covered by tests

return wrapper

Check warning on line 111 in src/imgtools/modules/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/imgtools/modules/segmentation.py#L111

Added line #L111 was not covered by tests


def map_over_labels(
segmentation: Segmentation,
f: Callable[[sitk.Image], sitk.Image],
include_background: bool = False,
return_segmentation: bool = True,
**kwargs: Any, # noqa
) -> Union[List[sitk.Image], Segmentation]:
"""
Applies a function to each label in a segmentation mask.

This function iterates over all labels in the segmentation mask, applies
the provided function to each label individually, and optionally combines
the results into a new `Segmentation` object.

Parameters
----------
segmentation : Segmentation
The segmentation object containing multiple ROI labels.
f : Callable[[sitk.Image], sitk.Image]
A function to apply to each label in the segmentation.
include_background : bool, optional
If True, includes the background label (label 0) in the operation.
Default is False.
return_segmentation : bool, optional
If True, combines the results into a new `Segmentation` object.
If False, returns a list of processed labels as `sitk.Image`. Default is True.
**kwargs : Any
Additional keyword arguments passed to the function `f`.

Returns
-------
Union[List[sitk.Image], Segmentation]
A new `Segmentation` object if `return_segmentation` is True,
otherwise a list of `sitk.Image` objects for each label.

Examples
--------
>>> def threshold(label_img, threshold=0.5):
... return sitk.BinaryThreshold(label_img, lowerThreshold=threshold)
>>> segmentation = Segmentation(image, roi_indices={'ROI1': 1, 'ROI2': 2})
>>> result = map_over_labels(segmentation, threshold, threshold=0.5)
>>> isinstance(result, Segmentation)
True
"""
if include_background:
labels = range(segmentation.num_labels + 1)

Check warning on line 159 in src/imgtools/modules/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/imgtools/modules/segmentation.py#L158-L159

Added lines #L158 - L159 were not covered by tests
else:
labels = range(1, segmentation.num_labels + 1)

Check warning on line 161 in src/imgtools/modules/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/imgtools/modules/segmentation.py#L161

Added line #L161 was not covered by tests

res = [f(segmentation.get_label(label=label), **kwargs) for label in labels]

Check warning on line 163 in src/imgtools/modules/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/imgtools/modules/segmentation.py#L163

Added line #L163 was not covered by tests

if return_segmentation and isinstance(res[0], sitk.Image):
res = [sitk.Cast(r, sitk.sitkUInt8) for r in res]
return Segmentation(

Check warning on line 167 in src/imgtools/modules/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/imgtools/modules/segmentation.py#L165-L167

Added lines #L165 - L167 were not covered by tests
sitk.Compose(*res),
roi_indices=segmentation.roi_indices,
raw_roi_names=segmentation.raw_roi_names,
)
return res

Check warning on line 172 in src/imgtools/modules/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/imgtools/modules/segmentation.py#L172

Added line #L172 was not covered by tests


class Segmentation(sitk.Image):
Expand Down
Loading
Loading