Skip to content

Commit

Permalink
ENH: added MONAILabel app for tricuspid valve segmentation (issue #1)
Browse files Browse the repository at this point in the history
- currently no training is supported

ref: Project-MONAI/MONAILabel#154
  • Loading branch information
che85 committed Jun 30, 2021
1 parent a1d5b47 commit 2aa83bc
Show file tree
Hide file tree
Showing 10 changed files with 504 additions and 0 deletions.
3 changes: 3 additions & 0 deletions MONAILabel-app/segmentation_tricuspid_valve/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Segmentation - Tricuspid Valve from 3DE

## Overview
22 changes: 22 additions & 0 deletions MONAILabel-app/segmentation_tricuspid_valve/info.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
---
version: 1
name: Segmentation Tricuspid Valve
description: MONAI Label App for segmentation of the tricuspid valve from 3DE images
dimension: 3
labels:
- anterior
- posterior
- septal
config:
infer:
device: cuda
train:
name: model_01
pretrained: True
device: cuda
amp: true
lr: 0.02
epochs: 200
val_split: 0.1
train_batch_size: 8
val_batch_size: 8
4 changes: 4 additions & 0 deletions MONAILabel-app/segmentation_tricuspid_valve/lib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .activelearning import MyStrategy
from .infer import MyInfer
# from .train import MyTrain
from .vnet import VNet
26 changes: 26 additions & 0 deletions MONAILabel-app/segmentation_tricuspid_valve/lib/activelearning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import logging

from monailabel.interfaces import Datastore
from monailabel.interfaces.tasks import Strategy

logger = logging.getLogger(__name__)


class MyStrategy(Strategy):
"""
Consider implementing a first strategy for active learning
"""

def __init__(self):
super().__init__("Get First Sample")

def __call__(self, request, datastore: Datastore):
images = datastore.get_unlabeled_images()
if not len(images):
return None

images.sort()
image = images[0]

logger.info(f"First: Selected Image: {image}")
return image
69 changes: 69 additions & 0 deletions MONAILabel-app/segmentation_tricuspid_valve/lib/infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from monai.inferers import SimpleInferer
from monai.engines.utils import CommonKeys as Keys

from monai.transforms import (
AddChanneld,
LoadImaged,
ToTensord,
ScaleIntensityd,
AsDiscreted,
ConcatItemsd,
ToNumpyd,
SqueezeDimd
)

from monailabel.utils.others.post import Restored
from monailabel.interfaces.tasks import InferTask, InferType

from .transforms import DistanceTransformd


class MyInfer(InferTask):
"""
This provides Inference Engine for pre-trained tricuspid valve segmentation (VNet) model.
"""

def __init__(
self,
path,
network=None,
type=InferType.SEGMENTATION,
labels=("anterior", "posterior", "septal"),
dimension=3,
description="A pre-trained model for volumetric (3D) segmentation of tricuspid valve from 3DE image",
):
super().__init__(
path=path,
network=network,
type=type,
labels=labels,
dimension=dimension,
description=description,
)

def pre_transforms(self):
all_keys = [Keys.IMAGE, Keys.LABEL]
return [
LoadImaged(keys=all_keys, reader="NibabelReader"),
AddChanneld(keys=all_keys),
DistanceTransformd(keys=[Keys.LABEL]),
ScaleIntensityd(
keys=[Keys.IMAGE],
minv=0.0,
maxv=1.0
),
ToTensord(keys=all_keys),
ConcatItemsd(keys=all_keys, name=Keys.IMAGE, dim=0)
]

def inferer(self):
return SimpleInferer()

def post_transforms(self):
return [
AddChanneld(keys="pred"),
AsDiscreted(keys="pred", argmax=True),
SqueezeDimd(keys="pred", dim=0),
ToNumpyd(keys="pred"),
Restored(keys="pred", ref_image="image"),
]
131 changes: 131 additions & 0 deletions MONAILabel-app/segmentation_tricuspid_valve/lib/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import logging
from monai.transforms import MapTransform
import SimpleITK as sitk
import numpy as np
import torch

logger = logging.getLogger(__name__)


def simplex(t, axis: int = 1) -> bool:
import torch
_sum = t.sum(axis).type(torch.float32)
_ones = torch.ones_like(_sum, dtype=torch.float32)
return torch.allclose(_sum, _ones)


def is_one_hot(t, axis=1) -> bool:
return simplex(t, axis) and sset(t, [0, 1])


def sset(a, sub) -> bool:
return uniq(a).issubset(sub)


def uniq(a) -> set:
import torch
return set(torch.unique(a.cpu()).numpy())


class OneHotTransform(object):

@classmethod
def run(cls, data):
if len(data.shape) == 4:
assert data.shape[0] == 1
data = data[0]

n_classes = (len(np.unique(data)))
assert n_classes > 1, f"{cls.__name__}: Not enough unique pixel values found in data."
assert n_classes < 10, f"{cls.__name__}: Too many unique pixel values found in data."

w, h, d = data.shape
res = np.stack([data == c for c in range(n_classes)], axis=0).astype(np.int32)
assert res.shape == (n_classes, w, h, d)
assert np.all(res.sum(axis=0) == 1)
return res

def __init__(self, fields):
self.fields = fields

def __call__(self, data):
for field in self.fields:
data[field] = self.run(data[field])
assert np.isfinite(data[field]).all()
return data


class OneHotTransformd(MapTransform):

def __init__(self, keys):
super(OneHotTransformd, self).__init__(keys)

def __call__(self, data):
for key in self.keys:
one_hot = OneHotTransform.run(data[key])
assert np.isfinite(one_hot).all()
assert np.any(one_hot)

data[key] = one_hot.astype(np.float32)
return data


class DistanceTransform(object):
""" Create distance map on the fly for labels
"""

METHODS = {
"SDM": sitk.SignedMaurerDistanceMapImageFilter,
"EDM": sitk.DanielssonDistanceMapImageFilter
}
DEFAULT_METHOD = "SDM"

@classmethod
def get_distance_map(cls, data, method=DEFAULT_METHOD):
image = sitk.GetImageFromArray(data.astype(np.int16))
distanceMapFilter = cls.METHODS[method]()
distanceMapFilter.SetUseImageSpacing(True)
distanceMapFilter.SetSquaredDistance(False)
out = distanceMapFilter.Execute(image)
return sitk.GetArrayFromImage(out)

def __init__(self, fields, method=DEFAULT_METHOD):
self.fields = fields
self.computationMethod = method

def __call__(self, data):
for field in self.fields:
d = data[field]
assert is_one_hot(torch.Tensor(d), axis=0)
# NB: skipping computation of background distance map
d = d[1:, ...]
assert d.shape[0] > 0
data[field] = np.stack([
self.get_distance_map(d[ch].astype(np.float32), self.computationMethod) for ch in range(d.shape[0])],
axis=0)
assert np.isfinite(data[field]).all()
return data


class DistanceTransformd(MapTransform):

def one_hot_to_dist(self, input_array):
assert is_one_hot(torch.Tensor(input_array), axis=0)
out = np.stack(
[DistanceTransform.get_distance_map(input_array[ch].astype(np.float32),
method=self.method) for ch in range(input_array.shape[0])], axis=0)
return out

def __init__(self, keys, method=DistanceTransform.DEFAULT_METHOD):
super(DistanceTransformd, self).__init__(keys)
self.method = method

def __call__(self, data):
for key in self.keys:
one_hot = OneHotTransform.run(data[key])
assert np.isfinite(one_hot).all()
assert np.any(one_hot)

result_np = self.one_hot_to_dist(one_hot).astype(np.float32)
data[key] = result_np[1:, ...]
return data
Loading

0 comments on commit 2aa83bc

Please sign in to comment.