-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ENH: added MONAILabel app for tricuspid valve segmentation (issue #1)
- currently no training is supported ref: Project-MONAI/MONAILabel#154
- Loading branch information
Showing
10 changed files
with
504 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Segmentation - Tricuspid Valve from 3DE | ||
|
||
## Overview |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
--- | ||
version: 1 | ||
name: Segmentation Tricuspid Valve | ||
description: MONAI Label App for segmentation of the tricuspid valve from 3DE images | ||
dimension: 3 | ||
labels: | ||
- anterior | ||
- posterior | ||
- septal | ||
config: | ||
infer: | ||
device: cuda | ||
train: | ||
name: model_01 | ||
pretrained: True | ||
device: cuda | ||
amp: true | ||
lr: 0.02 | ||
epochs: 200 | ||
val_split: 0.1 | ||
train_batch_size: 8 | ||
val_batch_size: 8 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .activelearning import MyStrategy | ||
from .infer import MyInfer | ||
# from .train import MyTrain | ||
from .vnet import VNet |
26 changes: 26 additions & 0 deletions
26
MONAILabel-app/segmentation_tricuspid_valve/lib/activelearning.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import logging | ||
|
||
from monailabel.interfaces import Datastore | ||
from monailabel.interfaces.tasks import Strategy | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class MyStrategy(Strategy): | ||
""" | ||
Consider implementing a first strategy for active learning | ||
""" | ||
|
||
def __init__(self): | ||
super().__init__("Get First Sample") | ||
|
||
def __call__(self, request, datastore: Datastore): | ||
images = datastore.get_unlabeled_images() | ||
if not len(images): | ||
return None | ||
|
||
images.sort() | ||
image = images[0] | ||
|
||
logger.info(f"First: Selected Image: {image}") | ||
return image |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from monai.inferers import SimpleInferer | ||
from monai.engines.utils import CommonKeys as Keys | ||
|
||
from monai.transforms import ( | ||
AddChanneld, | ||
LoadImaged, | ||
ToTensord, | ||
ScaleIntensityd, | ||
AsDiscreted, | ||
ConcatItemsd, | ||
ToNumpyd, | ||
SqueezeDimd | ||
) | ||
|
||
from monailabel.utils.others.post import Restored | ||
from monailabel.interfaces.tasks import InferTask, InferType | ||
|
||
from .transforms import DistanceTransformd | ||
|
||
|
||
class MyInfer(InferTask): | ||
""" | ||
This provides Inference Engine for pre-trained tricuspid valve segmentation (VNet) model. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
path, | ||
network=None, | ||
type=InferType.SEGMENTATION, | ||
labels=("anterior", "posterior", "septal"), | ||
dimension=3, | ||
description="A pre-trained model for volumetric (3D) segmentation of tricuspid valve from 3DE image", | ||
): | ||
super().__init__( | ||
path=path, | ||
network=network, | ||
type=type, | ||
labels=labels, | ||
dimension=dimension, | ||
description=description, | ||
) | ||
|
||
def pre_transforms(self): | ||
all_keys = [Keys.IMAGE, Keys.LABEL] | ||
return [ | ||
LoadImaged(keys=all_keys, reader="NibabelReader"), | ||
AddChanneld(keys=all_keys), | ||
DistanceTransformd(keys=[Keys.LABEL]), | ||
ScaleIntensityd( | ||
keys=[Keys.IMAGE], | ||
minv=0.0, | ||
maxv=1.0 | ||
), | ||
ToTensord(keys=all_keys), | ||
ConcatItemsd(keys=all_keys, name=Keys.IMAGE, dim=0) | ||
] | ||
|
||
def inferer(self): | ||
return SimpleInferer() | ||
|
||
def post_transforms(self): | ||
return [ | ||
AddChanneld(keys="pred"), | ||
AsDiscreted(keys="pred", argmax=True), | ||
SqueezeDimd(keys="pred", dim=0), | ||
ToNumpyd(keys="pred"), | ||
Restored(keys="pred", ref_image="image"), | ||
] |
131 changes: 131 additions & 0 deletions
131
MONAILabel-app/segmentation_tricuspid_valve/lib/transforms.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import logging | ||
from monai.transforms import MapTransform | ||
import SimpleITK as sitk | ||
import numpy as np | ||
import torch | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def simplex(t, axis: int = 1) -> bool: | ||
import torch | ||
_sum = t.sum(axis).type(torch.float32) | ||
_ones = torch.ones_like(_sum, dtype=torch.float32) | ||
return torch.allclose(_sum, _ones) | ||
|
||
|
||
def is_one_hot(t, axis=1) -> bool: | ||
return simplex(t, axis) and sset(t, [0, 1]) | ||
|
||
|
||
def sset(a, sub) -> bool: | ||
return uniq(a).issubset(sub) | ||
|
||
|
||
def uniq(a) -> set: | ||
import torch | ||
return set(torch.unique(a.cpu()).numpy()) | ||
|
||
|
||
class OneHotTransform(object): | ||
|
||
@classmethod | ||
def run(cls, data): | ||
if len(data.shape) == 4: | ||
assert data.shape[0] == 1 | ||
data = data[0] | ||
|
||
n_classes = (len(np.unique(data))) | ||
assert n_classes > 1, f"{cls.__name__}: Not enough unique pixel values found in data." | ||
assert n_classes < 10, f"{cls.__name__}: Too many unique pixel values found in data." | ||
|
||
w, h, d = data.shape | ||
res = np.stack([data == c for c in range(n_classes)], axis=0).astype(np.int32) | ||
assert res.shape == (n_classes, w, h, d) | ||
assert np.all(res.sum(axis=0) == 1) | ||
return res | ||
|
||
def __init__(self, fields): | ||
self.fields = fields | ||
|
||
def __call__(self, data): | ||
for field in self.fields: | ||
data[field] = self.run(data[field]) | ||
assert np.isfinite(data[field]).all() | ||
return data | ||
|
||
|
||
class OneHotTransformd(MapTransform): | ||
|
||
def __init__(self, keys): | ||
super(OneHotTransformd, self).__init__(keys) | ||
|
||
def __call__(self, data): | ||
for key in self.keys: | ||
one_hot = OneHotTransform.run(data[key]) | ||
assert np.isfinite(one_hot).all() | ||
assert np.any(one_hot) | ||
|
||
data[key] = one_hot.astype(np.float32) | ||
return data | ||
|
||
|
||
class DistanceTransform(object): | ||
""" Create distance map on the fly for labels | ||
""" | ||
|
||
METHODS = { | ||
"SDM": sitk.SignedMaurerDistanceMapImageFilter, | ||
"EDM": sitk.DanielssonDistanceMapImageFilter | ||
} | ||
DEFAULT_METHOD = "SDM" | ||
|
||
@classmethod | ||
def get_distance_map(cls, data, method=DEFAULT_METHOD): | ||
image = sitk.GetImageFromArray(data.astype(np.int16)) | ||
distanceMapFilter = cls.METHODS[method]() | ||
distanceMapFilter.SetUseImageSpacing(True) | ||
distanceMapFilter.SetSquaredDistance(False) | ||
out = distanceMapFilter.Execute(image) | ||
return sitk.GetArrayFromImage(out) | ||
|
||
def __init__(self, fields, method=DEFAULT_METHOD): | ||
self.fields = fields | ||
self.computationMethod = method | ||
|
||
def __call__(self, data): | ||
for field in self.fields: | ||
d = data[field] | ||
assert is_one_hot(torch.Tensor(d), axis=0) | ||
# NB: skipping computation of background distance map | ||
d = d[1:, ...] | ||
assert d.shape[0] > 0 | ||
data[field] = np.stack([ | ||
self.get_distance_map(d[ch].astype(np.float32), self.computationMethod) for ch in range(d.shape[0])], | ||
axis=0) | ||
assert np.isfinite(data[field]).all() | ||
return data | ||
|
||
|
||
class DistanceTransformd(MapTransform): | ||
|
||
def one_hot_to_dist(self, input_array): | ||
assert is_one_hot(torch.Tensor(input_array), axis=0) | ||
out = np.stack( | ||
[DistanceTransform.get_distance_map(input_array[ch].astype(np.float32), | ||
method=self.method) for ch in range(input_array.shape[0])], axis=0) | ||
return out | ||
|
||
def __init__(self, keys, method=DistanceTransform.DEFAULT_METHOD): | ||
super(DistanceTransformd, self).__init__(keys) | ||
self.method = method | ||
|
||
def __call__(self, data): | ||
for key in self.keys: | ||
one_hot = OneHotTransform.run(data[key]) | ||
assert np.isfinite(one_hot).all() | ||
assert np.any(one_hot) | ||
|
||
result_np = self.one_hot_to_dist(one_hot).astype(np.float32) | ||
data[key] = result_np[1:, ...] | ||
return data |
Oops, something went wrong.