Skip to content

Commit

Permalink
Fix and Improve performance for pathology models (#1158)
Browse files Browse the repository at this point in the history
* Fix and Improve performance for pathology models

Signed-off-by: Sachidanand Alle <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix dependency

Signed-off-by: Sachidanand Alle <[email protected]>

* sync up nuclick changes

Signed-off-by: Sachidanand Alle <[email protected]>

Signed-off-by: Sachidanand Alle <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
SachidanandAlle and pre-commit-ci[bot] authored Nov 24, 2022
1 parent f8ebb38 commit 6007e0c
Show file tree
Hide file tree
Showing 12 changed files with 883 additions and 417 deletions.
50 changes: 29 additions & 21 deletions monailabel/transform/post.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from monai.config import KeysCollection, NdarrayOrTensor
from monai.data import MetaTensor
from monai.transforms import MapTransform, Resize, Transform, generate_spatial_bounding_box, get_extreme_points
from monai.utils import InterpolateMode, ensure_tuple_rep
from monai.utils import InterpolateMode, convert_to_numpy, ensure_tuple_rep
from shapely.geometry import Point, Polygon
from torchvision.utils import make_grid, save_image

Expand Down Expand Up @@ -176,7 +176,7 @@ def __call__(self, data):
color_map = d.get(self.key_label_colors) if self.colormap is None else self.colormap

foreground_points = d.get(self.key_foreground_points, []) if self.key_foreground_points else []
foreground_points = [Point(pt[1], pt[0]) for pt in foreground_points] # polygons in (y, x) format
foreground_points = [Point(pt[0], pt[1]) for pt in foreground_points] # polygons in (x, y) format

elements = []
label_names = set()
Expand All @@ -188,8 +188,9 @@ def __call__(self, data):
labels = [label for label in np.unique(p).tolist() if label > 0]
logger.debug(f"Total Unique Masks (excluding background): {labels}")
for label_idx in labels:
p = d[key].array if isinstance(d[key], MetaTensor) else d[key]
p = convert_to_numpy(d[key]) if isinstance(d[key], torch.Tensor) else d[key]
p = np.where(p == label_idx, 1, 0).astype(np.uint8)
p = np.moveaxis(p, 0, 1) # for cv2

label_name = self.labels.get(label_idx, label_idx)
label_names.add(label_name)
Expand Down Expand Up @@ -237,29 +238,36 @@ def __call__(self, data):


class DumpImagePrediction2Dd(Transform):
def __init__(self, image_path, pred_path):
def __init__(self, image_path, pred_path, pred_only=True):
self.image_path = image_path
self.pred_path = pred_path
self.pred_only = pred_only

def __call__(self, data):
d = dict(data)
image = d["image"].array
pred = d["pred"].array

img_tensor = make_grid(torch.from_numpy(image[:3] * 128 + 128), normalize=True)
save_image(img_tensor, self.image_path)

image_pred = [pred, image[3][None], image[4][None]] if image.shape[0] == 5 else [pred]
image_pred_np = np.array(image_pred)
image_pred_t = torch.from_numpy(image_pred_np)

tensor = make_grid(
tensor=image_pred_t,
nrow=len(image_pred),
normalize=True,
pad_value=10,
)
save_image(tensor, self.pred_path)
for bidx in range(d["image"].shape[0]):
image = np.moveaxis(d["image"][bidx], 1, 2)
pred = np.moveaxis(d["pred"][bidx], 0, 1)

img_tensor = make_grid(torch.from_numpy(image[:3] * 128 + 128), normalize=True)
save_image(img_tensor, self.image_path)

if self.pred_only:
pred_tensor = make_grid(torch.from_numpy(pred), normalize=True)
save_image(pred_tensor[0], self.pred_path)
return d

image_pred = [pred[None], image[3][None], image[4][None]] if image.shape[0] == 5 else [pred[None]]
image_pred_np = np.array(image_pred)
image_pred_t = torch.from_numpy(image_pred_np)

tensor = make_grid(
tensor=image_pred_t,
nrow=len(image_pred),
normalize=True,
pad_value=10,
)
save_image(tensor, self.pred_path)
return d


Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,6 @@ bcrypt==3.2.2
shapely==1.8.2
requests==2.28.1
scikit-learn
scipy

#sudo apt-get install openslide-tools -y
90 changes: 58 additions & 32 deletions sample-apps/pathology/lib/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(
self,
summary_writer: Optional[SummaryWriter] = None,
log_dir: str = "./runs",
tag_name="val_acc",
tag_name="val",
interval: int = 1,
batch_transform: Callable = lambda x: x,
output_transform: Callable = lambda x: x,
Expand All @@ -88,7 +88,8 @@ def __init__(
self.class_y_pred: List[Any] = []

def attach(self, engine: Engine) -> None:
engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.interval), self, "iteration")
if self.interval == 1:
engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.interval), self, "iteration")
engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.interval), self, "epoch")

def __call__(self, engine: Engine, action) -> None:
Expand Down Expand Up @@ -130,39 +131,37 @@ def write_images(self, batch_data, output_data, epoch):
image = batch_data[bidx]["image"].detach().cpu().numpy()
y = output_data[bidx]["label"].detach().cpu().numpy()

tag_prefix = f"b{bidx} - " if self.batch_limit != 1 else ""
img_tensor = make_grid(torch.from_numpy(image[:3] * 128 + 128), normalize=True)
self.writer.add_image(tag=f"{tag_prefix}Image", img_tensor=img_tensor, global_step=epoch)

if self.class_names:
sig_np = image[:3] * 128 + 128
sig_np[0, :, :] = np.where(image[3] > 0, 1, sig_np[0, :, :])
sig_tensor = make_grid(torch.from_numpy(sig_np), normalize=True)
self.writer.add_image(tag=f"{tag_prefix}Signal", img_tensor=sig_tensor, global_step=epoch)
if np.count_nonzero(image[3]) == 0:
self.logger.info("+++++++++ BUG (Signal is ZERO)")
self.logger.info(f"{self.tag_name} => +++++++++ BUG (Signal is ZERO)")

y_pred = output_data[bidx]["pred"].detach().cpu().numpy()

y_c = np.argmax(y)
y_pred_c = np.argmax(y_pred)

tag_prefix = f"b{bidx} - " if self.batch_limit != 1 else ""
label_pred_tag = f"{tag_prefix}Label vs Pred:"
tag_prefix = f"{self.tag_name} - b{bidx} - " if self.batch_limit != 1 else f"{self.tag_name} - "
label_pred_tag = f"{tag_prefix}Image/Signal/Label/Pred:"

y_img = Image.new("RGB", (200, 100))
y_img = Image.new("RGB", image.shape[-2:])
draw = ImageDraw.Draw(y_img)
draw.text((10, 50), self.class_names.get(f"{y_c}", f"{y_c}"))

y_pred_img = Image.new("RGB", (200, 100), "green" if y_c == y_pred_c else "red")
y_pred_img = Image.new("RGB", image.shape[-2:], "green" if y_c == y_pred_c else "red")
draw = ImageDraw.Draw(y_pred_img)
draw.text((10, 50), self.class_names.get(f"{y_pred_c}", f"{y_pred_c}"))

label_pred = [np.moveaxis(np.array(y_img), -1, 0), np.moveaxis(np.array(y_pred_img), -1, 0)]
img_tensor = make_grid(
tensor=torch.from_numpy(np.array(label_pred)),
nrow=3,
normalize=False,
tensor=[
torch.from_numpy(sig_np),
torch.from_numpy(np.stack((np.where(image[3] > 0, 255, 0),) * 3)),
torch.from_numpy(np.moveaxis(np.array(y_img), -1, 0)),
torch.from_numpy(np.moveaxis(np.array(y_pred_img), -1, 0)),
],
nrow=4,
normalize=True,
pad_value=10,
)
self.writer.add_image(tag=label_pred_tag, img_tensor=img_tensor, global_step=epoch)
Expand All @@ -171,35 +170,60 @@ def write_images(self, batch_data, output_data, epoch):
if self.batch_limit == 1 and bidx < (len(batch_data) - 1) and np.sum(y) == 0:
continue

tag_prefix = f"{self.tag_name} - b{bidx} - " if self.batch_limit != 1 else ""
img_np = image[:3] * 128 + 128
if image.shape[0] > 3:
img_np[0, :, :] = np.where(image[3] > 0, 1, img_np[0, :, :])
img_tensor = make_grid(torch.from_numpy(img_np), normalize=True)
self.writer.add_image(tag=f"{tag_prefix}Image", img_tensor=img_tensor, global_step=epoch)

y_pred = output_data[bidx]["pred"].detach().cpu().numpy()

for region in range(y_pred.shape[0]):
if region == 0 and y_pred.shape[0] > 1: # one-hot; background
continue

cl = np.count_nonzero(y[region])
cp = np.count_nonzero(y_pred[region])
self.logger.info(
"{} - {} - Image: {};"
"{} => {} - {} - Image: {};"
" Label: {} (nz: {});"
" Pred: {} (nz: {});"
" Sig: (pos-nz: {}, neg-nz: {})".format(
" Diff: {:.2f}%; "
"{}".format(
self.tag_name,
bidx,
region,
image.shape,
y.shape,
np.count_nonzero(y[region]),
cl,
y_pred.shape,
np.count_nonzero(y_pred[region]),
np.count_nonzero(image[3]) if image.shape[0] == 5 else 0,
np.count_nonzero(image[4]) if image.shape[0] == 5 else 0,
cp,
100 * (cp - cl) / (cl + 1),
" Sig: (pos-nz: {}, neg-nz: {})".format(
np.count_nonzero(image[3]) if image.shape[0] == 5 else 0,
np.count_nonzero(image[4]) if image.shape[0] == 5 else 0,
)
if image.shape[0] == 5
else "",
)
)

tag_prefix = f"b{bidx}:l{region} - " if self.batch_limit != 1 else f"l{region} - "
tag_prefix = (
f"{self.tag_name} - b{bidx}:l{region} - "
if self.batch_limit != 1
else f"{self.tag_name} - l{region} - "
)

label_pred = [y[region][None], y_pred[region][None]]
label_pred_tag = f"{tag_prefix}Label vs Pred:"
if image.shape[0] == 5:
label_pred = [y[region][None], y_pred[region][None], image[3][None], image[4][None]]
label_pred = [
y[region][None],
y_pred[region][None],
image[3][None] > 0,
image[4][None] > 0,
]
label_pred_tag = f"{tag_prefix}Label vs Pred vs Pos vs Neg"

img_tensor = make_grid(
Expand All @@ -222,12 +246,12 @@ def write_region_metrics(self, epoch):
for n, m in v.items():
ltext.append(f"{n} => {m:.4f}")
cname = self.class_names.get(k, k)
self.writer.add_scalar(f"cr_{k}_{n}", m, epoch)
self.writer.add_scalar(f"{self.tag_name}_cr_{k}_{n}", m, epoch)

self.logger.info(f"Epoch[{epoch}] Metrics -- Class: {cname}; {'; '.join(ltext)}")
self.logger.info(f"{self.tag_name} => Epoch[{epoch}] Metrics -- Class: {cname}; {'; '.join(ltext)}")
else:
self.logger.info(f"Epoch[{epoch}] Metrics -- {k} => {v:.4f}")
self.writer.add_scalar(f"cr_{k}", v, epoch)
self.logger.info(f"{self.tag_name} => Epoch[{epoch}] Metrics -- {k} => {v:.4f}")
self.writer.add_scalar(f"{self.tag_name}_cr_{k}", v, epoch)

self.class_y = []
self.class_y_pred = []
Expand All @@ -237,13 +261,15 @@ def write_region_metrics(self, epoch):
metric_sum = 0
for region in self.metric_data:
metric = self.metric_data[region].mean()
self.logger.info(f"Epoch[{epoch}] Metrics -- Region: {region:0>2d}, {self.tag_name}: {metric:.4f}")
self.logger.info(
f"{self.tag_name} => Epoch[{epoch}] Metrics (Dice) -- Region: {region:0>2d}: {metric:.4f}"
)

self.writer.add_scalar(f"dice_{region:0>2d}", metric, epoch)
self.writer.add_scalar(f"{self.tag_name}_dice_{region:0>2d}", metric, epoch)
metric_sum += metric

metric_avg = metric_sum / len(self.metric_data)
self.writer.add_scalar("dice_regions_avg", metric_avg, epoch)
self.writer.add_scalar(f"{self.tag_name}_dice_regions_avg", metric_avg, epoch)

self.writer.flush()
self.metric_data = {}
10 changes: 5 additions & 5 deletions sample-apps/pathology/lib/infers/classification_nuclei.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from typing import Any, Callable, Dict, Sequence

import numpy as np
from lib.transforms import FixNuclickClassd, LoadImagePatchd
from lib.nuclick import AddLabelAsGuidanced
from lib.transforms import LoadImagePatchd
from monai.inferers import Inferer, SimpleInferer
from monai.transforms import Activationsd, AsChannelFirstd, EnsureTyped, ScaleIntensityRangeD
from monai.transforms import Activationsd, EnsureChannelFirstd, ScaleIntensityRangeD

from monailabel.interfaces.tasks.infer_v2 import InferType
from monailabel.tasks.infer.basic_infer import BasicInferTask
Expand Down Expand Up @@ -62,10 +63,9 @@ def pre_transforms(self, data=None) -> Sequence[Callable]:
return [
LoadImagePatchd(keys="image", dtype=np.uint8),
LoadImagePatchd(keys="label", dtype=np.uint8, mode="L"),
EnsureTyped(keys=("image", "label")),
AsChannelFirstd(keys="image"),
EnsureChannelFirstd(keys=("image", "label")),
ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
FixNuclickClassd(image="image", label="label", offset=-1),
AddLabelAsGuidanced(keys="image", source="label"),
]

def inferer(self, data=None) -> Inferer:
Expand Down
23 changes: 13 additions & 10 deletions sample-apps/pathology/lib/infers/nuclick.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@

import numpy as np
import torch
from lib.transforms import FixNuclickClassd, LoadImagePatchd, NuClickPostFilterLabelExd
from monai.apps.nuclick.transforms import AddClickSignalsd, NuclickKeys
from lib.nuclick import AddClickSignalsd, AddLabelAsGuidanced, NuclickKeys, PostFilterLabeld
from lib.transforms import LoadImagePatchd
from monai.config import KeysCollection
from monai.transforms import (
Activationsd,
AsChannelFirstd,
AsDiscreted,
EnsureTyped,
MapTransform,
ScaleIntensityRangeD,
SqueezeDimd,
ToNumpyd,
)
Expand Down Expand Up @@ -113,36 +114,38 @@ def info(self) -> Dict[str, Any]:
def pre_transforms(self, data=None):
return [
LoadImagePatchd(keys="image", mode="RGB", dtype=np.uint8, padding=False),
EnsureTyped(keys="image", device=data.get("device") if data else None),
AsChannelFirstd(keys="image"),
ConvertInteractiveClickSignals(
source_annotation_keys="nuclick points",
target_data_keys=NuclickKeys.FOREGROUND,
allow_missing_keys=True,
),
AddClickSignalsd(image="image", foreground=NuclickKeys.FOREGROUND),
EnsureTyped(keys="image", device=data.get("device") if data else None),
ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
AddClickSignalsd(image="image", foreground=NuclickKeys.FOREGROUND, gaussian=False),
]

def run_inferer(self, data, convert_to_batch=True, device="cuda"):
output = super().run_inferer(data, False, device)
if self.task_classification:
data2 = copy.deepcopy(self.task_classification.config())
pred1 = output["pred"]
pred1 = torch.sigmoid(pred1)
pred1 = pred1 >= 0.5

data2 = copy.deepcopy(self.task_classification.config())
data2.update({"image": output["image"][:, :3], "label": pred1, "device": device})

data2 = self.task_classification.run_pre_transforms(data2, [FixNuclickClassd(image="image", label="label")])
data2 = self.task_classification.run_pre_transforms(
data2, [AddLabelAsGuidanced(keys="image", source="label")]
)

output2 = self.task_classification.run_inferer(data2, False, device)
pred2 = output2["pred"]
pred2 = torch.softmax(pred2, dim=1)
pred2 = torch.argmax(pred2, dim=1)
pred2 = [int(p) for p in pred2]

output["pred_classes"] = [v + 1 for v in pred2]
logger.info(f"Predicted Classes: {output['pred_classes']}")
output[NuclickKeys.PRED_CLASSES] = [v + 1 for v in pred2]
logger.info(f"Predicted Classes: {output[NuclickKeys.PRED_CLASSES]}")
return output

def post_transforms(self, data=None) -> Sequence[Callable]:
Expand All @@ -152,7 +155,7 @@ def post_transforms(self, data=None) -> Sequence[Callable]:
AsDiscreted(keys="pred", threshold=0.5),
SqueezeDimd(keys="pred", dim=1),
ToNumpyd(keys=("image", "pred")),
NuClickPostFilterLabelExd(keys="pred"),
PostFilterLabeld(keys="pred"),
FindContoursd(keys="pred", labels=self.labels),
]

Expand Down
Loading

0 comments on commit 6007e0c

Please sign in to comment.