Skip to content

Commit

Permalink
Bug fixes from copying code over
Browse files Browse the repository at this point in the history
  • Loading branch information
Gilbert, Andrew committed Nov 23, 2020
1 parent 6107519 commit 89c92be
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 491 deletions.
13 changes: 9 additions & 4 deletions public-segmentation/README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Medical Segmentation

![SegTask](docs/seg_task.png)

A repository for training segmentation algorithms for medical images. This repository is primarily developed for
segmentation of echocardiography images from the [Camus](https://www.creatis.insa-lyon.fr/Challenge/camus/) and
[EchoNet](https://echonet.github.io/dynamic/) datasets as well as other data with similar structure.
segmentation of apical four and two chamber echocardiography images like those from the
[Camus](https://www.creatis.insa-lyon.fr/Challenge/camus/) and
[EchoNet](https://echonet.github.io/dynamic/) datasets as well as other data with similar structure.

Some parts of the code are specifically adapted for this (e.g. the [metrics](evaluators/__init__.py)) which assume
specific regions are associated with specific classes. These will need to be updated for a new segmentation task.

## Install

Expand Down Expand Up @@ -38,5 +39,9 @@ Run `python inference.py -h` to see available options. The main options match th

## Notes

Some parts of the code are specifically adapted for apical segmentation of the left ventricle
(e.g. the [metrics](evaluators/__init__.py)) which assume
specific regions are associated with specific classes. These will need to be updated for a new segmentation task.

Some of the structure/code from this repository is based on
[CycleGAN_and_pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
6 changes: 4 additions & 2 deletions public-segmentation/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,14 @@ def load_data(self):

def __len__(self):
"""Return the number of data in the dataset"""
return min(len(self.dataset), self.opt.max_dataset_size)
if self.opt.max_dataset_size is not None:
return min(len(self.dataset), self.opt.max_dataset_size)
return len(self.dataset)

def __iter__(self):
"""Return a batch of data"""
for i, data in enumerate(self.dataloader):
if i * self.opt.batch_size >= self.opt.max_dataset_size:
if i * self.opt.batch_size >= len(self):
break
yield data

Expand Down
Binary file added public-segmentation/docs/seg_task.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 1 addition & 5 deletions public-segmentation/evaluators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from .loss import Losses, Dice, BCEWithLogits, MSE, Pix2PixNoInputOtherLabels, P2PDiscLossWrapper, CrossEntropy, WCE
from .loss import Losses, Dice, BCEWithLogits, MSE, CrossEntropy, WCE
from .metrics import IoU, Metrics, Curvature, Hausdorff, DiceScore, Simplicity, Convexity, CurvatureIndividual, \
CosineSim, SliceMerger, Bias, SurfaceDist

Expand Down Expand Up @@ -35,8 +35,6 @@ def collect_losses(opt):
loss_fcn.add_loss(loss, BCEWithLogits(loss_weight, out_type="segs"))
elif loss == "bbox":
loss_fcn.add_loss(loss, MSE(loss_weight, out_type="bboxes"))
elif loss == "adversarial":
loss_fcn.add_loss(loss, Pix2PixNoInputOtherLabels(loss_weight, opt, out_type="segs"))
elif loss == "weight_decay":
pass # this is handled in the optimizer initialization (above) so nothing needed here
# Add others here as needed
Expand All @@ -50,8 +48,6 @@ def add_losses_to_metrics(losses: loss.Losses):
metric_dict = dict()
for name, loss in losses.loss_dict.items():
metric_dict["loss_" + name] = loss # LossWrapper(loss)
if name == "adversarial": # add second metric for discriminator using a special helper class
metric_dict["loss_adversarial_D"] = P2PDiscLossWrapper(loss)
return metric_dict


Expand Down
148 changes: 1 addition & 147 deletions public-segmentation/evaluators/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,9 @@

import torch
from kornia.utils import one_hot as kornia_one_hot
from torch import nn
# Below are imported for access from other files to keep everything in the same place
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn.functional import softmax, one_hot

from segmentation.UltrasoundSegmentation.models.discriminator import NLayerDiscriminator
from segmentation.UltrasoundSegmentation.models.net import init_net, get_norm_layer
from transformation.CycleGAN_and_pix2pix.models.networks import GANLoss
from torch.nn.functional import softmax
from .metrics import MetricBase
from .weighted_cross_entropy import WeightedCrossEntropyLoss

Expand Down Expand Up @@ -178,147 +173,6 @@ def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return self.loss(output, target)


class Pix2PixNoInputOtherLabels(LossBase):
"""
A class just like Pix2Pix loss except for it does not use an input image (so discriminator only
learns from labels) and also the real labels can be derived from a different source.
"""

def __init__(self, loss_weight, opt, out_type="segs"):
super(Pix2PixNoInputOtherLabels, self).__init__(loss_weight, out_type)
self.device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')
self.separate_discs = opt.separate_discs
if self.separate_discs:
self.netD = [self.define_D(1, 64, opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)
for _ in range(opt.output_nc)]

else:
self.netD = [self.define_D(opt.output_nc, 64, opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)]
if opt.isTrain:
self.optimizer_D = [torch.optim.Adam(nD.parameters(), lr=opt.adv_lr, betas=(opt.beta1, 0.999))
for nD in self.netD]
self.criterionGAN = GANLoss("lsgan").to(self.device)
self.weight = loss_weight
self.loss_D = torch.tensor(-.1) # watch discriminator loss as well
self.isTrain = opt.isTrain
self.sigmoid = nn.Sigmoid()
self.num_classes = opt.output_nc
if opt.adv_lr_policy == "step_higher":
if opt.num_epochs > 20:
raise NotImplementedError("scheduler currently not configured")
# logging.info("configuring Step Higher for discriminator optimizer. lr=lrx2 at epoch 10 and 20")
# self.scheduler = lr_scheduler.MultiStepLR(self.optimizer_D, milestones=[10, 20], gamma=2.)
else:
logging.warning("adv_lr_policy step only configured for >20 epochs. Not using.")

@staticmethod
def define_D(input_nc, ndf, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=()):
norm_layer = get_norm_layer(norm_type=norm)
net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
return init_net(net, init_type, init_gain, gpu_ids)

def backward_D(self, outputs, real_AB, netD):
"""Calculate GAN loss for the discriminator"""
# Fake; stop backprop to the generator by detaching fake_B
fake_AB = self.sigmoid(outputs) # only outputs
pred_fake = netD(fake_AB.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
# Real
real_AB = real_AB.type(torch.float) # cast to float to match output from network
pred_real = netD(real_AB)
loss_D_real = self.criterionGAN(pred_real, True)
# combine loss and calculate gradients
self.loss_D = (loss_D_fake + loss_D_real) * 0.5
self.loss_D *= self.weight # need to multiple loss_D times weight here since it won't be outside.
# if outputs.requires_grad and self.isTrain: # else just calculating for metrics
self.loss_D.backward()

def backward_G(self, outputs, netD):
"""Calculate GAN and L1 loss for the generator"""
# G(A) should fake the discriminator
fake_AB = self.sigmoid(outputs) # outputs # torch.cat((inputs, outputs), 1)
pred_fake = netD(fake_AB)
loss_G_GAN = self.criterionGAN(pred_fake, True)
return loss_G_GAN # actual G update is performed in train.py

def _forward(self, output: torch.Tensor, target: torch.Tensor, netD, optimizer_D) -> torch.Tensor:
""" separate out forward call for case of multiple discriminators"""
# update D
# when this function is called from metrics requires_grad will be false, and self.loss_D will be set already
# so no need to redo the update D step
if output.requires_grad and self.isTrain:
self.set_requires_grad(netD, True) # enable backprop for D
optimizer_D.zero_grad() # set D's gradients to zero
self.backward_D(output, target, netD) # calculate gradients for D
# if output.requires_grad and self.isTrain:
optimizer_D.step() # update D's weights

# update G
self.set_requires_grad(netD, False) # D requires no gradients when optimizing G
return self.backward_G(output, netD) # no direct targets, just trying to full discriminator

def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
# need one hot here to match the fake outputs - permute gets dimensions right from one_hot output
if self.num_classes > 1:
target = one_hot(target, num_classes=self.num_classes).squeeze(1).permute((0, 3, 1, 2))
if self.separate_discs:
loss = 0
for c in range(self.num_classes):
o = output[:, c, :, :].unsqueeze(1)
t = target[:, c, :, :].unsqueeze(1)
loss += self._forward(o, t, self.netD[c], self.optimizer_D[c])
return loss / self.num_classes
else:
return self._forward(output, target, self.netD[0], self.optimizer_D[0])

def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Parameters:
nets (network list) -- a list of networks
requires_grad (bool) -- whether the networks require gradients or not
"""
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad

def state_dict(self):
""" Get params to save """
return dict(
model_dict=[nD.state_dict() for nD in self.netD],
optim_dict=[oD.state_dict() for oD in self.optimizer_D],
)

def load_state_dict(self, state_dict):
""" Reload params """
for nD, oD, model_params, optim_params in zip(self.netD, self.optimizer_D, state_dict["model_dict"],
state_dict["optim_dict"]):
nD.load_state_dict(model_params)
oD.load_state_dict(optim_params)


class P2PDiscLossWrapper(MetricBase):
"""
A helper class to provide access to the discriminative loss of Pix2PixNoInputOtherLabels
in metrics.
"""

def __init__(self, p2p_loss: Pix2PixNoInputOtherLabels):
super(P2PDiscLossWrapper, self).__init__(lower_is_better=True)
self.p2p_loss = p2p_loss

def process_single(self, output: torch.Tensor, target: torch.Tensor):
pass # not necessary here since we overwrite call function

def __call__(self, outputs: dict, targets: dict, confidence=None):
# Relies on p2p loss being already called and updating it's own loss
# If this function is called before p2p_loss is called we will get the loss from
# the previous call. However, that case shouldn't matter much for the purpose of metrics.
self.results.append(self.p2p_loss.loss_D.item())


class Losses:
"""
Class to handle combine all losses.
Expand Down
141 changes: 0 additions & 141 deletions public-segmentation/models/camus_unet1.py

This file was deleted.

Loading

0 comments on commit 89c92be

Please sign in to comment.