From 7e3da8a7caeadc15dedc1df38ae0491b5ce0a0e4 Mon Sep 17 00:00:00 2001
From: rhoadesScholar
Date: Mon, 8 Aug 2022 12:00:38 -0400
Subject: [PATCH 1/3] Condensing -> gonna try inheritance

---
 raygun/torch/losses/GANLoss.py                |  63 ++++++
 raygun/torch/losses/LinkCycleLoss.py          | 185 ++++++++++++++++
 raygun/torch/losses/SplitCycleLoss.py         | 198 ++++++++++++++++++
 .../{Link_Cycle_Model.py => CycleModel.py}    |  25 ++-
 raygun/torch/models/Split_Cycle_Model.py      |  57 -----
 ...mmy_Optimizer.py => BaseDummyOptimizer.py} |   2 +-
 6 files changed, 463 insertions(+), 67 deletions(-)
 create mode 100644 raygun/torch/losses/GANLoss.py
 create mode 100644 raygun/torch/losses/LinkCycleLoss.py
 create mode 100644 raygun/torch/losses/SplitCycleLoss.py
 rename raygun/torch/models/{Link_Cycle_Model.py => CycleModel.py} (71%)
 delete mode 100644 raygun/torch/models/Split_Cycle_Model.py
 rename raygun/torch/optimizers/{Base_Dummy_Optimizer.py => BaseDummyOptimizer.py} (85%)

diff --git a/raygun/torch/losses/GANLoss.py b/raygun/torch/losses/GANLoss.py
new file mode 100644
index 00000000..ac6a0043
--- /dev/null
+++ b/raygun/torch/losses/GANLoss.py
@@ -0,0 +1,63 @@
+# ORIGINALLY WRITTEN BY TRI NGUYEN (HARVARD, 2021)
+import torch
+
+class GANLoss(torch.nn.Module):
+    """Define different GAN objectives.
+    The GANLoss class abstracts away the need to create the target label tensor
+    that has the same size as the input.
+    """
+
+    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
+        """Initialize the GANLoss class.
+        Parameters:
+            gan_mode (str)             -- the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
+            target_real_label (float)  -- label for a real image
+            target_fake_label (float)  -- label for a fake image
+        Note: Do not use sigmoid as the last layer of Discriminator.
+        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
+        """
+        super(GANLoss, self).__init__()
+        self.register_buffer('real_label', torch.tensor(target_real_label))
+        self.register_buffer('fake_label', torch.tensor(target_fake_label))
+        self.gan_mode = gan_mode
+        if gan_mode == 'lsgan':
+            self.loss = torch.nn.MSELoss()
+        elif gan_mode == 'vanilla':
+            self.loss = torch.nn.BCEWithLogitsLoss()
+        elif gan_mode in ['wgangp']:
+            self.loss = None
+        else:
+            raise NotImplementedError('gan mode %s not implemented' % gan_mode)
+
+    def get_target_tensor(self, prediction, target_is_real):
+        """Create label tensors with the same size as the input.
+        Parameters:
+            prediction (tensor)    -- typically the prediction from a discriminator
+            target_is_real (bool)  -- if the ground truth label is for real images or fake images
+        Returns:
+            A label tensor filled with ground truth label, and with the size of the input
+        """
+
+        if target_is_real:
+            target_tensor = self.real_label
+        else:
+            target_tensor = self.fake_label
+        return target_tensor.expand_as(prediction)
+
+    def __call__(self, prediction, target_is_real):
+        """Calculate loss given Discriminator's output and ground truth labels.
+        Parameters:
+            prediction (tensor)    -- typically the prediction output from a discriminator
+            target_is_real (bool)  -- if the ground truth label is for real images or fake images
+        Returns:
+            the calculated loss.
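+        Example (illustrative sketch; `netD` and `fake` are assumed to be a
+        discriminator and a batch of generated images, as in the cycle losses
+        added below):
+            criterion = GANLoss(gan_mode='lsgan')
+            loss_D_fake = criterion(netD(fake.detach()), False)  # D: push fakes toward fake_label
+            loss_G = criterion(netD(fake), True)                 # G: fool D toward real_label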
+        """
+        if self.gan_mode in ['lsgan', 'vanilla']:
+            target_tensor = self.get_target_tensor(prediction, target_is_real)
+            loss = self.loss(prediction, target_tensor)
+        elif self.gan_mode == 'wgangp':
+            if target_is_real:
+                loss = -prediction.mean()
+            else:
+                loss = prediction.mean()
+        return loss
diff --git a/raygun/torch/losses/LinkCycleLoss.py b/raygun/torch/losses/LinkCycleLoss.py
new file mode 100644
index 00000000..8e766e17
--- /dev/null
+++ b/raygun/torch/losses/LinkCycleLoss.py
@@ -0,0 +1,185 @@
+
+import torch
+from losses import GANLoss
+
+import logging
+logger = logging.Logger('CycleGANLoss', 'INFO')
+
+class LinkCycleLoss(torch.nn.Module):
+    """CycleGAN loss function"""
+    def __init__(self,
+                 netD1,
+                 netG1,
+                 netD2,
+                 netG2,
+                 optimizer_G,
+                 optimizer_D,
+                 dims,
+                 l1_loss=torch.nn.SmoothL1Loss(),
+                 g_lambda_dict={'A': {'l1_loss': {'cycled': 10, 'identity': 0},
+                                      'gan_loss': {'fake': 1, 'cycled': 0},
+                                      },
+                                'B': {'l1_loss': {'cycled': 10, 'identity': 0},
+                                      'gan_loss': {'fake': 1, 'cycled': 0},
+                                      },
+                                },
+                 d_lambda_dict={'A': {'real': 1, 'fake': 1, 'cycled': 0},
+                                'B': {'real': 1, 'fake': 1, 'cycled': 0},
+                                },
+                 gan_mode='lsgan'
+                 ):
+        super().__init__()
+        self.l1_loss = l1_loss
+        self.gan_loss = GANLoss(gan_mode=gan_mode)
+        self.netD1 = netD1  # differentiates between fake and real Bs
+        self.netG1 = netG1  # turns As into Bs
+        self.netD2 = netD2  # differentiates between fake and real As
+        self.netG2 = netG2  # turns Bs into As
+        self.optimizer_G = optimizer_G
+        self.optimizer_D = optimizer_D
+        self.g_lambda_dict = g_lambda_dict
+        self.d_lambda_dict = d_lambda_dict
+        self.gan_mode = gan_mode
+        self.dims = dims
+        self.loss_dict = {}
+
+    def set_requires_grad(self, nets, requires_grad=False):
+        """Set requires_grad=False for all the networks to avoid unnecessary computations
+        Parameters:
+            nets (network list)   -- a list of networks
+            requires_grad (bool)  -- whether the networks require gradients or not
+        """
+        if not isinstance(nets, list):
+            nets = [nets]
+        for net in nets:
+            if net is not None:
+                for param in net.parameters():
+                    param.requires_grad = requires_grad
+
+    def crop(self, x, shape):
+        '''Center-crop x to match spatial dimensions given by shape.'''
+
+        x_target_size = x.size()[:-self.dims] + shape
+
+        offset = tuple(
+            (a - b)//2
+            for a, b in zip(x.size(), x_target_size))
+
+        slices = tuple(
+            slice(o, o + s)
+            for o, s in zip(offset, x_target_size))
+
+        return x[slices]
+
+    def clamp_weights(self, net, min=-0.01, max=0.01):
+        for module in net.model:
+            if hasattr(module, 'weight') and hasattr(module.weight, 'data'):
+                temp = module.weight.data
+                module.weight.data = temp.clamp(min, max)
+
+    def backward_D(self, side, dnet, data_dict):
+        """Calculate losses for a discriminator"""
+        loss = 0
+        for key, lambda_ in self.d_lambda_dict[side].items():
+            if lambda_ != 0:
+                # if key == 'identity':  # TODO: ADD IDENTITY SUPPORT
+                #     pred = gnet(data_dict['real'])
+                # else:
+                #     pred = data_dict[key]
+
+                this_loss = self.gan_loss(dnet(data_dict[key].detach()), key == 'real')
+
+                self.loss_dict.update({f'Discriminator_{side}/{key}': this_loss})
+                loss += lambda_ * this_loss
+
+        loss.backward()
+        return loss
+
+    def backward_Ds(self, data_dict, n_loop=5):
+        self.set_requires_grad([self.netG1, self.netG2], False)  # G does not require gradients when optimizing D
+        self.set_requires_grad([self.netD1, self.netD2], True)  # enable backprop for D
+        self.optimizer_D.zero_grad(set_to_none=True)  # set D's gradients to zero
+
+        if self.gan_mode.lower() == 'wgangp':  # Wasserstein Loss
+            for _ in range(n_loop):
+                loss_D1 = self.backward_D('B', self.netD1, data_dict['B'])
+                loss_D2 = self.backward_D('A', self.netD2, data_dict['A'])
+                self.optimizer_D.step()  # update D's weights
+                self.clamp_weights(self.netD1)
+                self.clamp_weights(self.netD2)
+        else:
+            loss_D1 = self.backward_D('B', self.netD1, data_dict['B'])
+            loss_D2 = self.backward_D('A', self.netD2, data_dict['A'])
+            self.optimizer_D.step()  # update D's weights
+
+        # return losses
+        return loss_D1, loss_D2
+
+    def backward_G(self, side, gnet, dnet, data_dict):
+        """Calculate losses for a generator"""
+        loss = 0
+        real = data_dict['real']
+        for fcn_name, lambdas in self.g_lambda_dict[side].items():
+            loss_fcn = getattr(self, fcn_name)
+            for key, lambda_ in lambdas.items():
+                if lambda_ != 0:
+                    if key == 'identity' and key not in data_dict:
+                        data_dict['identity'] = gnet(real)
+                    pred = data_dict[key]
+
+                    if fcn_name == 'l1_loss':
+                        if real.size()[-self.dims:] != pred.size()[-self.dims:]:
+                            this_loss = loss_fcn(self.crop(real, pred.size()[-self.dims:]), pred)
+                        else:
+                            this_loss = loss_fcn(real, pred)
+                    elif fcn_name == 'gan_loss':
+                        this_loss = loss_fcn(dnet(pred), True)
+
+                    self.loss_dict.update({f'{fcn_name}/{key}_{side}': this_loss})
+                    loss += lambda_ * this_loss
+
+        # calculate gradients
+        loss.backward()
+        return loss
+
+    def backward_Gs(self, data_dict):
+        self.set_requires_grad([self.netD1, self.netD2], False)  # D requires no gradients when optimizing G
+        self.set_requires_grad([self.netG1, self.netG2], True)  # Turn G gradients back on
+
+        self.optimizer_G.zero_grad(set_to_none=True)  # set Gs' gradients to zero
+        loss_G1 = self.backward_G('B', self.netG1, self.netD1, data_dict['B'])  # calculate gradient for G
+        loss_G2 = self.backward_G('A', self.netG2, self.netD2, data_dict['A'])  # calculate gradient for G
+        self.optimizer_G.step()  # update Gs' weights
+
+        # return losses
+        return loss_G1, loss_G2
+
+    def forward(self, real_A, fake_A, cycled_A, real_B, fake_B, cycled_B):
+
+        # crop if necessary
+        if real_A.size()[-self.dims:] != fake_B.size()[-self.dims:]:
+            real_A = self.crop(real_A, fake_A.size()[-self.dims:])
+            real_B = self.crop(real_B, fake_B.size()[-self.dims:])
+
+        data_dict = {'A': {'real': real_A, 'fake': fake_A, 'cycled': cycled_A},
+                     'B': {'real': real_B, 'fake': fake_B, 'cycled': cycled_B}
+                     }
+        # update Gs
+        loss_G1, loss_G2 = self.backward_Gs(data_dict)
+
+        # update Ds
+        loss_D1, loss_D2 = self.backward_Ds(data_dict)
+
+        self.loss_dict.update({
+            'Total_Loss/D1': float(loss_D1),
+            'Total_Loss/D2': float(loss_D2),
+            'Total_Loss/G1': float(loss_G1),
+            'Total_Loss/G2': float(loss_G2),
+        })
+
+        total_loss = loss_G1 + loss_G2 + loss_D1 + loss_D2
+        # define dummy backward pass to disable Gunpowder's Train node loss.backward() call
+        total_loss.backward = lambda: None
+
+        logger.info(self.loss_dict)
+        return total_loss
\ No newline at end of file
diff --git a/raygun/torch/losses/SplitCycleLoss.py b/raygun/torch/losses/SplitCycleLoss.py
new file mode 100644
index 00000000..34992de4
--- /dev/null
+++ b/raygun/torch/losses/SplitCycleLoss.py
@@ -0,0 +1,198 @@
+
+import torch
+from losses import GANLoss
+
+import logging
+logger = logging.Logger('CycleGANLoss', 'INFO')
+
+class SplitCycleLoss(torch.nn.Module):
+    """CycleGAN loss function"""
+    def __init__(self,
+                 netD1,
+                 netG1,
+                 netD2,
+                 netG2,
+                 optimizer_G1,
+                 optimizer_G2,
+                 optimizer_D,
+                 dims,
+                 l1_loss=torch.nn.SmoothL1Loss(),
+                 g_lambda_dict={'A': {'l1_loss': {'cycled': 10, 'identity': 0},
+                                      'gan_loss': {'fake': 1, 'cycled': 0},
+                                      },
+                                'B': {'l1_loss': {'cycled': 10, 'identity': 0},
+                                      'gan_loss': {'fake': 1, 'cycled': 0},
+                                      },
+                                },
+                 d_lambda_dict={'A': {'real': 1, 'fake': 1, 'cycled': 0},
+                                'B': {'real': 1, 'fake': 1, 'cycled': 0},
+                                },
+                 gan_mode='lsgan'
+                 ):
+        super().__init__()
+        self.l1_loss = l1_loss
+        self.gan_loss = GANLoss(gan_mode=gan_mode)
+        self.netD1 = netD1  # differentiates between fake and real Bs
+        self.netG1 = netG1  # turns As into Bs
+        self.netD2 = netD2  # differentiates between fake and real As
+        self.netG2 = netG2  # turns Bs into As
+        self.optimizer_G1 = optimizer_G1
+        self.optimizer_G2 = optimizer_G2
+        self.optimizer_D = optimizer_D
+        self.g_lambda_dict = g_lambda_dict
+        self.d_lambda_dict = d_lambda_dict
+        self.gan_mode = gan_mode
+        self.dims = dims
+        self.loss_dict = {}
+
+    def set_requires_grad(self, nets, requires_grad=False):
+        """Set requires_grad=False for all the networks to avoid unnecessary computations
+        Parameters:
+            nets (network list)   -- a list of networks
+            requires_grad (bool)  -- whether the networks require gradients or not
+        """
+        if not isinstance(nets, list):
+            nets = [nets]
+        for net in nets:
+            if net is not None:
+                for param in net.parameters():
+                    param.requires_grad = requires_grad
+
+    def crop(self, x, shape):
+        '''Center-crop x to match spatial dimensions given by shape.'''
+
+        x_target_size = x.size()[:-self.dims] + shape
+
+        offset = tuple(
+            (a - b)//2
+            for a, b in zip(x.size(), x_target_size))
+
+        slices = tuple(
+            slice(o, o + s)
+            for o, s in zip(offset, x_target_size))
+
+        return x[slices]
+
+    def clamp_weights(self, net, min=-0.01, max=0.01):
+        for module in net.model:
+            if hasattr(module, 'weight') and hasattr(module.weight, 'data'):
+                temp = module.weight.data
+                module.weight.data = temp.clamp(min, max)
+
+    def backward_D(self, side, dnet, data_dict):
+        """Calculate losses for a discriminator"""
+        loss = 0
+        for key, lambda_ in self.d_lambda_dict[side].items():
+            if lambda_ != 0:
+                # if key == 'identity':  # TODO: ADD IDENTITY SUPPORT
+                #     pred = gnet(data_dict['real'])
+                # else:
+                #     pred = data_dict[key]
+
+                this_loss = self.gan_loss(dnet(data_dict[key].detach()), key == 'real')
+
+                self.loss_dict.update({f'Discriminator_{side}/{key}': this_loss})
+                loss += lambda_ * this_loss
+
+        loss.backward()
+        return loss
+
+    def backward_Ds(self, data_dict, n_loop=5):
+        self.set_requires_grad([self.netG1, self.netG2], False)  # G does not require gradients when optimizing D
+        self.set_requires_grad([self.netD1, self.netD2], True)  # enable backprop for D
+        self.optimizer_D.zero_grad(set_to_none=True)  # set D's gradients to zero
+
+        if self.gan_mode.lower() == 'wgangp':  # Wasserstein Loss
+            for _ in range(n_loop):
+                loss_D1 = self.backward_D('B', self.netD1, data_dict['B'])
+                loss_D2 = self.backward_D('A', self.netD2, data_dict['A'])
+                self.optimizer_D.step()  # update D's weights
+                self.clamp_weights(self.netD1)
+                self.clamp_weights(self.netD2)
+        else:
+            loss_D1 = self.backward_D('B', self.netD1, data_dict['B'])
+            loss_D2 = self.backward_D('A', self.netD2, data_dict['A'])
+            self.optimizer_D.step()  # update D's weights
+
+        # return losses
+        return loss_D1, loss_D2
+
+    def backward_G(self, side, gnet, dnet, data_dict):
+        """Calculate losses for a generator"""
+        loss = 0
+        real = data_dict['real']
+        for fcn_name, lambdas in self.g_lambda_dict[side].items():
+            loss_fcn = getattr(self, fcn_name)
+            for key, lambda_ in lambdas.items():
+                if lambda_ != 0:
+                    if key == 'identity' and key not in data_dict:
+                        data_dict['identity'] = gnet(real)
+                    pred = data_dict[key]
+
+                    if fcn_name == 'l1_loss':
+                        if real.size()[-self.dims:] != pred.size()[-self.dims:]:
+                            this_loss = loss_fcn(self.crop(real, pred.size()[-self.dims:]), pred)
+                        else:
+                            this_loss = loss_fcn(real, pred)
+                    elif fcn_name == 'gan_loss':
+                        this_loss = loss_fcn(dnet(pred), True)
+
+                    self.loss_dict.update({f'{fcn_name}/{key}_{side}': this_loss})
+                    loss += lambda_ * this_loss
+
+        # calculate gradients
+        loss.backward()
+        return loss
+
+    def backward_Gs(self, data_dict):
+        self.set_requires_grad([self.netD1, self.netD2], False)  # D requires no gradients when optimizing G
+        self.set_requires_grad([self.netG1, self.netG2], True)  # Turn G gradients back on
+
+        # G1 first
+        self.set_requires_grad([self.netG1], True)  # G1 requires gradients when optimizing
+        self.set_requires_grad([self.netG2], False)  # G2 requires no gradients when optimizing G1
+        self.optimizer_G1.zero_grad(set_to_none=True)  # set G1's gradients to zero
+        loss_G1 = self.backward_G('B', self.netG1, self.netD1, data_dict['B'])  # calculate gradient for G
+        self.optimizer_G1.step()  # update G1's weights
+
+        # Then G2
+        self.set_requires_grad([self.netG2], True)  # G2 requires gradients when optimizing
+        self.set_requires_grad([self.netG1], False)  # G1 requires no gradients when optimizing G2
+        self.optimizer_G2.zero_grad(set_to_none=True)  # set G2's gradients to zero
+        loss_G2 = self.backward_G('A', self.netG2, self.netD2, data_dict['A'])  # calculate gradient for G
+        self.optimizer_G2.step()  # update G2's weights
+
+        # Turn gradients back on
+        self.set_requires_grad([self.netG1], True)
+        # return losses
+        return loss_G1, loss_G2
+
+    def forward(self, real_A, fake_A, cycled_A, real_B, fake_B, cycled_B):
+
+        # crop if necessary
+        if real_A.size()[-self.dims:] != fake_B.size()[-self.dims:]:
+            real_A = self.crop(real_A, fake_A.size()[-self.dims:])
+            real_B = self.crop(real_B, fake_B.size()[-self.dims:])
+
+        data_dict = {'A': {'real': real_A, 'fake': fake_A, 'cycled': cycled_A},
+                     'B': {'real': real_B, 'fake': fake_B, 'cycled': cycled_B}
+                     }
+        # update Gs
+        loss_G1, loss_G2 = self.backward_Gs(data_dict)
+
+        # update Ds
+        loss_D1, loss_D2 = self.backward_Ds(data_dict)
+
+        self.loss_dict.update({
+            'Total_Loss/D1': float(loss_D1),
+            'Total_Loss/D2': float(loss_D2),
+            'Total_Loss/G1': float(loss_G1),
+            'Total_Loss/G2': float(loss_G2),
+        })
+
+        total_loss = loss_G1 + loss_G2 + loss_D1 + loss_D2
+        # define dummy backward pass to disable Gunpowder's Train node loss.backward() call
+        total_loss.backward = lambda: None
+
+        logger.info(self.loss_dict)
+        return total_loss
\ No newline at end of file
diff --git a/raygun/torch/models/Link_Cycle_Model.py b/raygun/torch/models/CycleModel.py
similarity index 71%
rename from raygun/torch/models/Link_Cycle_Model.py
rename to raygun/torch/models/CycleModel.py
index 00b99a05..d571db72 100644
--- a/raygun/torch/models/Link_Cycle_Model.py
+++ b/raygun/torch/models/CycleModel.py
@@ -1,15 +1,16 @@
 import torch
 import torch.nn.functional as F
 
-class Link_Cycle_Model(torch.nn.Module):
-    def __init__(self, netG1, netG2, scale_factor_A=None, scale_factor_B=None):
+class CycleModel(torch.nn.Module):
+    def __init__(self, netG1, netG2, scale_factor_A=None, scale_factor_B=None, split=False):
         super().__init__()
         self.netG1 = netG1
         self.netG2 = netG2
         self.scale_factor_A = scale_factor_A
         self.scale_factor_B = scale_factor_B
+        self.split = split
         self.cycle = True
-        self.crop_pad = False
+        self.crop_pad = None  # TODO: Determine if this is deprecated
 
     def sampling_bottleneck(self, array, scale_factor):
         size = array.shape[-len(scale_factor):]
@@ -26,12 +27,15 @@ def
forward(self, real_A=None, real_B=None): if real_A is not None: #allow calling for single direction pass (i.e. prediction) self.fake_B = self.netG1(real_A) - if self.crop_pad is not False: + if self.crop_pad is not None: self.fake_B = self.fake_B[self.crop_pad] if self.scale_factor_B: self.fake_B = self.sampling_bottleneck(self.fake_B, self.scale_factor_B) #apply sampling bottleneck if self.cycle: - self.cycled_A = self.netG2(self.fake_B) - if self.crop_pad is not False: + if self.split: + self.cycled_A = self.netG2(self.fake_B.detach()) # detach to prevent backprop to first generator + else: + self.cycled_A = self.netG2(self.fake_B) + if self.crop_pad is not None: self.cycled_A = self.cycled_A[self.crop_pad] else: self.cycled_A = None @@ -41,12 +45,15 @@ def forward(self, real_A=None, real_B=None): if real_B is not None: self.fake_A = self.netG2(real_B) - if self.crop_pad is not False: + if self.crop_pad is not None: self.fake_A = self.fake_A[self.crop_pad] if self.scale_factor_A: self.fake_A = self.sampling_bottleneck(self.fake_A, self.scale_factor_A) #apply sampling bottleneck if self.cycle: - self.cycled_B = self.netG1(self.fake_A) - if self.crop_pad is not False: + if self.split: + self.cycled_B = self.netG1(self.fake_A.detach()) # detach to prevent backprop to first generator + else: + self.cycled_B = self.netG1(self.fake_A) + if self.crop_pad is not None: self.cycled_B = self.cycled_B[self.crop_pad] else: self.cycled_B = None diff --git a/raygun/torch/models/Split_Cycle_Model.py b/raygun/torch/models/Split_Cycle_Model.py deleted file mode 100644 index b19cd3d8..00000000 --- a/raygun/torch/models/Split_Cycle_Model.py +++ /dev/null @@ -1,57 +0,0 @@ -import torch -import torch.nn.functional as F - -class CycleGAN_Split_Model(torch.nn.Module): - def __init__(self, netG1, netG2, scale_factor_A=None, scale_factor_B=None): - super().__init__() - self.netG1 = netG1 - self.netG2 = netG2 - self.scale_factor_A = scale_factor_A - self.scale_factor_B = scale_factor_B - self.cycle = True - self.crop_pad = False - - def sampling_bottleneck(self, array, scale_factor): - size = array.shape[-len(scale_factor):] - mode = {2: 'bilinear', 3: 'trilinear'}[len(size)] - down = F.interpolate(array, scale_factor=scale_factor, mode=mode, align_corners=True) - return F.interpolate(down, size=size, mode=mode, align_corners=True) - - def set_crop_pad(self, crop_pad, ndims): - self.crop_pad = (slice(None,None,None),)*2 + (slice(crop_pad,-crop_pad),)*ndims - - def forward(self, real_A=None, real_B=None): - self.real_A = real_A - self.real_B = real_B - - if real_A is not None: #allow calling for single direction pass (i.e. 
prediction) - self.fake_B = self.netG1(real_A) - if self.crop_pad is not False: - self.fake_B = self.fake_B[self.crop_pad] - if self.scale_factor_B: self.fake_B = self.sampling_bottleneck(self.fake_B, self.scale_factor_B) #apply sampling bottleneck - if self.cycle: - self.cycled_A = self.netG2(self.fake_B.detach()) # detach to prevent backprop to first generator - if self.crop_pad is not False: - self.cycled_A = self.cycled_A[self.crop_pad] - else: - self.cycled_A = None - else: - self.fake_B = None - self.cycled_A = None - - if real_B is not None: - self.fake_A = self.netG2(real_B) - if self.crop_pad is not False: - self.fake_A = self.fake_A[self.crop_pad] - if self.scale_factor_A: self.fake_A = self.sampling_bottleneck(self.fake_A, self.scale_factor_A) #apply sampling bottleneck - if self.cycle: - self.cycled_B = self.netG1(self.fake_A.detach()) # detach to prevent backprop to first generator - if self.crop_pad is not False: - self.cycled_B = self.cycled_B[self.crop_pad] - else: - self.cycled_B = None - else: - self.fake_A = None - self.cycled_B = None - - return self.fake_B, self.cycled_B, self.fake_A, self.cycled_A diff --git a/raygun/torch/optimizers/Base_Dummy_Optimizer.py b/raygun/torch/optimizers/BaseDummyOptimizer.py similarity index 85% rename from raygun/torch/optimizers/Base_Dummy_Optimizer.py rename to raygun/torch/optimizers/BaseDummyOptimizer.py index 5416b044..e860c69b 100644 --- a/raygun/torch/optimizers/Base_Dummy_Optimizer.py +++ b/raygun/torch/optimizers/BaseDummyOptimizer.py @@ -1,6 +1,6 @@ import torch -class Base_Dummy_Optimizer(torch.nn.Module): +class BaseDummyOptimizer(torch.nn.Module): def __init__(self, **optimizers): super().__init__() for name, optimizer in optimizers.items(): From dbc1954f0c53d30854dff3caeba9dcd23b4f5a73 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Mon, 8 Aug 2022 13:59:59 -0400 Subject: [PATCH 2/3] Consolidated ResNet (tested), etc. 
Added freeze norm utils --- raygun/torch/networks/NLayerDiscriminator.py | 143 ++++ raygun/torch/networks/ResNet.py | 339 +++++++++ raygun/torch/networks/ResidualUNet.py | 2 +- raygun/torch/networks/UNet.py | 2 +- raygun/torch/networks/utils.py | 743 +------------------ 5 files changed, 513 insertions(+), 716 deletions(-) create mode 100644 raygun/torch/networks/NLayerDiscriminator.py create mode 100644 raygun/torch/networks/ResNet.py diff --git a/raygun/torch/networks/NLayerDiscriminator.py b/raygun/torch/networks/NLayerDiscriminator.py new file mode 100644 index 00000000..2ccd9fa9 --- /dev/null +++ b/raygun/torch/networks/NLayerDiscriminator.py @@ -0,0 +1,143 @@ +import torch +import functools + +class NLayerDiscriminator2D(torch.nn.Module): + """Defines a PatchGAN discriminator""" + + def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=torch.nn.BatchNorm2d, + kw=4, downsampling_kw=None): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ngf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super().__init__() + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func == torch.nn.InstanceNorm2d + else: + use_bias = norm_layer == torch.nn.InstanceNorm2d + + if downsampling_kw is None: + downsampling_kw = kw + + padw = 1 + ds_kw = downsampling_kw + sequence = [torch.nn.Conv2d(input_nc, ngf, kernel_size=ds_kw, stride=2, padding=padw), torch.nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + torch.nn.Conv2d(ngf * nf_mult_prev, ngf * nf_mult, kernel_size=ds_kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ngf * nf_mult), + torch.nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + torch.nn.Conv2d(ngf * nf_mult_prev, ngf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ngf * nf_mult), + torch.nn.LeakyReLU(0.2, True) + ] + + sequence += [torch.nn.Conv2d(ngf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.model = torch.nn.Sequential(*sequence) + + @property + def FOV(self): + # Returns the receptive field of one output neuron for a network (written for patch discriminators) + # See https://distill.pub/2019/computing-receptive-fields/#solving-receptive-field-region for formula derivation + + L = 0 # num of layers + k = [] # [kernel width at layer l] + s = [] # [stride at layer i] + for layer in self.model: + if hasattr(layer, 'kernel_size'): + L += 1 + k += [layer.kernel_size[-1]] + s += [layer.stride[-1]] + + r = 1 + for l in range(L-1, 0, -1): + r = s[l]*r + (k[l] - s[l]) + + return r + + def forward(self, input): + """Standard forward.""" + return self.model(input) + + +class NLayerDiscriminator3D(torch.nn.Module): + """Defines a PatchGAN discriminator""" + + def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=torch.nn.BatchNorm3d, + kw=4, downsampling_kw=None, + ): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ngf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + 
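+        # Illustrative note (not in the original patch): InstanceNorm carries no
+        # affine parameters by default, so convolutions feeding it keep their
+        # bias; BatchNorm's learnable shift makes a conv bias redundant, which
+        # is why use_bias is only set for InstanceNorm variants below.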
+        super().__init__()
+        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm3d has affine parameters
+            use_bias = norm_layer.func == torch.nn.InstanceNorm3d
+        else:
+            use_bias = norm_layer == torch.nn.InstanceNorm3d
+
+        if downsampling_kw is None:
+            downsampling_kw = kw
+
+        padw = 1
+        ds_kw = downsampling_kw
+        sequence = [torch.nn.Conv3d(input_nc, ngf, kernel_size=ds_kw, stride=2, padding=padw), torch.nn.LeakyReLU(0.2, True)]
+        nf_mult = 1
+        nf_mult_prev = 1
+        for n in range(1, n_layers):  # gradually increase the number of filters
+            nf_mult_prev = nf_mult
+            nf_mult = min(2 ** n, 8)
+            sequence += [
+                torch.nn.Conv3d(ngf * nf_mult_prev, ngf * nf_mult, kernel_size=ds_kw, stride=2, padding=padw, bias=use_bias),
+                norm_layer(ngf * nf_mult),
+                torch.nn.LeakyReLU(0.2, True)
+            ]
+
+        nf_mult_prev = nf_mult
+        nf_mult = min(2 ** n_layers, 8)
+        sequence += [
+            torch.nn.Conv3d(ngf * nf_mult_prev, ngf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+            norm_layer(ngf * nf_mult),
+            torch.nn.LeakyReLU(0.2, True)
+        ]
+
+        sequence += [torch.nn.Conv3d(ngf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
+        self.model = torch.nn.Sequential(*sequence)
+
+    def forward(self, input):
+        """Standard forward."""
+        return self.model(input)
+
+
+class NLayerDiscriminator(torch.nn.Module):
+    """Defines a PatchGAN discriminator"""
+
+    def __new__(cls, ndims, **kwargs):
+        """Construct a PatchGAN discriminator of the requested dimensionality
+        Parameters:
+            ndims (int)      -- dimensionality of the input images: 2 | 3
+            input_nc (int)   -- the number of channels in input images
+            ngf (int)        -- the number of filters in the last conv layer
+            n_layers (int)   -- the number of conv layers in the discriminator
+            norm_layer       -- normalization layer
+        """
+        # Dispatch via __new__ and construct the 2D/3D variant directly: calling
+        # NLayerDiscriminator2D.__init__(self, ...) on a plain torch.nn.Module
+        # subclass breaks that class's super().__init__() chain.
+        if ndims == 2:
+            return NLayerDiscriminator2D(**kwargs)
+        elif ndims == 3:
+            return NLayerDiscriminator3D(**kwargs)
+        else:
+            raise ValueError(ndims, 'Only 2D or 3D currently implemented. Feel free to contribute more!')
diff --git a/raygun/torch/networks/ResNet.py b/raygun/torch/networks/ResNet.py
new file mode 100644
index 00000000..dfd4b8bb
--- /dev/null
+++ b/raygun/torch/networks/ResNet.py
@@ -0,0 +1,339 @@
+
+import functools
+import torch
+from networks.utils import NoiseBlock, ParameterizedNoiseBlock
+
+class ResnetGenerator2D(torch.nn.Module):
+    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations, and (optionally) the injection of a feature map of random noise into the first upsampling layer.
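+    Set add_noise=True to concatenate a feature map of unit-gaussian noise, or
+    add_noise='param' to draw the noise from a distribution parameterized by the
+    first two feature maps (see NoiseBlock and ParameterizedNoiseBlock in
+    networks.utils).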
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) + """ + + def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer=torch.nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', activation=torch.nn.ReLU, add_noise=False): + """Construct a Resnet-based generator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers + n_blocks (int) -- the number of ResNet blocks + padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zeros | valid + activation -- non-linearity layer to apply (default is ReLU) + add_noise -- whether to append a noise feature to the data prior to upsampling layers: True | False | 'param' + """ + assert(n_blocks >= 0) + super().__init__() + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == torch.nn.InstanceNorm2d + else: + use_bias = norm_layer == torch.nn.InstanceNorm2d + + p = 0 + updown_p = 1 + padder = [] + if padding_type.lower() == 'reflect': + padder = [torch.nn.ReflectionPad2d(3)] + elif padding_type.lower() == 'replicate': + padder = [torch.nn.ReplicationPad2d(3)] + elif padding_type.lower() == 'zeros': + p = 3 + elif padding_type.lower() == 'valid': + p = 'valid' + updown_p = 0 + + model = [] + model += padder.copy() + model += [torch.nn.Conv2d(input_nc, ngf, kernel_size=7, padding=p, bias=use_bias), + norm_layer(ngf), + activation()] + + n_downsampling = 2 + for i in range(n_downsampling): # add downsampling layers + mult = 2 ** i + model += [torch.nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=updown_p, bias=use_bias), + norm_layer(ngf * mult * 2), + activation()] + + mult = 2 ** n_downsampling + for i in range(n_blocks): # add ResNet blocks + + model += [ResnetBlock2D(ngf * mult, padding_type=padding_type.lower(), norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, activation=activation)] + + if add_noise == 'param': # add noise feature if necessary + model += [ParameterizedNoiseBlock()] + elif add_noise: + model += [NoiseBlock()] + + for i in range(n_downsampling): # add upsampling layers + mult = 2 ** (n_downsampling - i) + model += [torch.nn.ConvTranspose2d(ngf * mult + (i==0 and (add_noise is not False)), + int(ngf * mult / 2), + kernel_size=3, stride=2, + padding=updown_p, output_padding=updown_p, + bias=use_bias), + norm_layer(int(ngf * mult / 2)), + activation()] + model += padder.copy() + model += [torch.nn.Conv2d(ngf, output_nc, kernel_size=7, padding=p)] + model += [torch.nn.Tanh()] + + self.model = torch.nn.Sequential(*model) + + def forward(self, input): + """Standard forward""" + return self.model(input) + +class ResnetBlock2D(torch.nn.Module): + """Define a Resnet block""" + + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, activation=torch.nn.ReLU): + """Initialize the Resnet block + A resnet block is a conv block with skip connections + We construct a conv block with build_conv_block function, + and implement skip connections in function. 
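+        With 'valid' padding, forward() center-crops the input to the convolved
+        output's spatial size before adding the skip connection (see crop()).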
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf + """ + super().__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, activation) + self.padding_type = padding_type + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, activation=torch.nn.ReLU): + """Construct a convolutional block. + Parameters: + dim (int) -- the number of channels in the conv layer. + padding_type (str) -- the name of padding layer: reflect | replicate | zeros | valid + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + use_bias (bool) -- if the conv layer uses bias or not + activation -- non-linearity layer to apply (default is ReLU) + Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer) + """ + p = 0 + padder = [] + if padding_type == 'reflect': + padder = [torch.nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + padder = [torch.nn.ReplicationPad2d(1)] + elif padding_type == 'zeros': + p = 1 + elif padding_type == 'valid': + p = 'valid' + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block = [] + conv_block += padder.copy() + + conv_block += [torch.nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), activation()] + if use_dropout: + conv_block += [torch.nn.Dropout(0.2)] + + conv_block += padder.copy() + conv_block += [torch.nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] + + return torch.nn.Sequential(*conv_block) + + def crop(self, x, shape): + '''Center-crop x to match spatial dimensions given by shape.''' + + x_target_size = x.size()[:-2] + shape + + offset = tuple( + torch.div((a - b), 2, rounding_mode='trunc') + for a, b in zip(x.size(), x_target_size)) + + slices = tuple( + slice(o, o + s) + for o, s in zip(offset, x_target_size)) + + return x[slices] + + def forward(self, x): + """Forward function (with skip connections)""" + if self.padding_type == 'valid': # crop for valid networks + res = self.conv_block(x) + out = self.crop(x, res.size()[-2:]) + res + else: + out = x + self.conv_block(x) # add skip connections + return out + + +class ResnetGenerator3D(torch.nn.Module): + """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations, and (optionally) the injection of a feature map of random noise into the first upsampling layer. 
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) + """ + + def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer=torch.nn.BatchNorm3d, use_dropout=False, n_blocks=6, padding_type='reflect', activation=torch.nn.ReLU, add_noise=False): + """Construct a Resnet-based generator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers + n_blocks (int) -- the number of ResNet blocks + padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zeros | valid + activation -- non-linearity layer to apply (default is ReLU) + add_noise -- whether to append a noise feature to the data prior to upsampling layers: True | False | 'param' + """ + assert(n_blocks >= 0) + super().__init__() + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == torch.nn.InstanceNorm3d + else: + use_bias = norm_layer == torch.nn.InstanceNorm3d + + p = 0 + updown_p = 1 + padder = [] + if padding_type.lower() == 'reflect': + padder = [torch.nn.ReflectionPad3d(3)] + elif padding_type.lower() == 'replicate': + padder = [torch.nn.ReplicationPad3d(3)] + elif padding_type.lower() == 'zeros': + p = 3 + elif padding_type.lower() == 'valid': + p = 'valid' + updown_p = 0 + + model = [] + model += padder.copy() + model += [torch.nn.Conv3d(input_nc, ngf, kernel_size=7, padding=p, bias=use_bias), + norm_layer(ngf), + activation()] + + n_downsampling = 2 + for i in range(n_downsampling): # add downsampling layers + mult = 2 ** i + model += [torch.nn.Conv3d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=updown_p, bias=use_bias), #TODO: Make actually use padding_type for every convolution (currently does zeros if not valid) + norm_layer(ngf * mult * 2), + activation()] + + mult = 2 ** n_downsampling + for i in range(n_blocks): # add ResNet blocks + + model += [ResnetBlock3D(ngf * mult, padding_type=padding_type.lower(), norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, activation=activation)] + + if add_noise == 'param': # add noise feature if necessary + model += [ParameterizedNoiseBlock()] + elif add_noise: + model += [NoiseBlock()] + + for i in range(n_downsampling): # add upsampling layers + mult = 2 ** (n_downsampling - i) + model += [torch.nn.ConvTranspose3d(ngf * mult + (i==0 and (add_noise is not False)), + int(ngf * mult / 2), + kernel_size=3, stride=2, + padding=updown_p, output_padding=updown_p, + bias=use_bias), + norm_layer(int(ngf * mult / 2)), + activation()] + model += padder.copy() + model += [torch.nn.Conv3d(ngf, output_nc, kernel_size=7, padding=p)] + model += [torch.nn.Tanh()] + + self.model = torch.nn.Sequential(*model) + + def forward(self, input): + """Standard forward""" + return self.model(input) + +class ResnetBlock3D(torch.nn.Module): + """Define a Resnet block""" + + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, activation=torch.nn.ReLU): + """Initialize the Resnet block + A resnet block is a conv block with skip connections + We construct a conv block with build_conv_block function, + and implement skip connections in function. 
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf + """ + super().__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, activation) + self.padding_type = padding_type + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, activation=torch.nn.ReLU): + """Construct a convolutional block. + Parameters: + dim (int) -- the number of channels in the conv layer. + padding_type (str) -- the name of padding layer: reflect | replicate | zeros | valid + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + use_bias (bool) -- if the conv layer uses bias or not + activation -- non-linearity layer to apply (default is ReLU) + Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer) + """ + p = 0 + padder = [] + if padding_type == 'reflect': + padder = [torch.nn.ReflectionPad3d(1)] + elif padding_type == 'replicate': + padder = [torch.nn.ReplicationPad3d(1)] + elif padding_type == 'zeros': + p = 1 + elif padding_type == 'valid': + p = 'valid' + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block = [] + conv_block += padder.copy() + + conv_block += [torch.nn.Conv3d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), activation()] + if use_dropout: + conv_block += [torch.nn.Dropout(0.2)] + + conv_block += padder.copy() + conv_block += [torch.nn.Conv3d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] + + return torch.nn.Sequential(*conv_block) + + def crop(self, x, shape): + '''Center-crop x to match spatial dimensions given by shape.''' + + x_target_size = x.size()[:-3] + shape + + offset = tuple( + torch.div((a - b), 2, rounding_mode='trunc') + for a, b in zip(x.size(), x_target_size)) + + slices = tuple( + slice(o, o + s) + for o, s in zip(offset, x_target_size)) + + return x[slices] + + def forward(self, x): + """Forward function (with skip connections)""" + if self.padding_type == 'valid': # crop for valid networks + res = self.conv_block(x) + out = self.crop(x, res.size()[-3:]) + res + else: + out = x + self.conv_block(x) # add skip connections + return out + + +class ResNet(ResnetGenerator2D, ResnetGenerator3D): + """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations, and (optionally) the injection of a feature map of random noise into the first upsampling layer. + We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) + """ + + def __init__(self, ndims, **kwargs): + """Construct a Resnet-based generator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers + n_blocks (int) -- the number of ResNet blocks + padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zeros | valid + activation -- non-linearity layer to apply (default is ReLU) + add_noise -- whether to append a noise feature to the data prior to upsampling layers: True | False | 'param' + """ + if ndims == 2: + ResnetGenerator2D.__init__(self, **kwargs) + elif ndims == 3: + ResnetGenerator3D.__init__(self, **kwargs) + else: + raise ValueError(ndims, 'Only 2D or 3D currently implemented. 
Feel free to contribute more!') diff --git a/raygun/torch/networks/ResidualUNet.py b/raygun/torch/networks/ResidualUNet.py index 5ea7da29..d1e57c7f 100644 --- a/raygun/torch/networks/ResidualUNet.py +++ b/raygun/torch/networks/ResidualUNet.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn -from utils import NoiseBlock, ParameterizedNoiseBlock +from networks.utils import NoiseBlock, ParameterizedNoiseBlock class ConvPass(torch.nn.Module): diff --git a/raygun/torch/networks/UNet.py b/raygun/torch/networks/UNet.py index b475d325..4289fca8 100644 --- a/raygun/torch/networks/UNet.py +++ b/raygun/torch/networks/UNet.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn -from utils import NoiseBlock, ParameterizedNoiseBlock +from networks.utils import NoiseBlock, ParameterizedNoiseBlock class ConvPass(torch.nn.Module): diff --git a/raygun/torch/networks/utils.py b/raygun/torch/networks/utils.py index 20c29c9d..0bb46988 100644 --- a/raygun/torch/networks/utils.py +++ b/raygun/torch/networks/utils.py @@ -1,144 +1,34 @@ # ORIGINALLY WRITTEN BY TRI NGUYEN (HARVARD, 2021) -import functools import torch -import torch.nn as nn -import torch.functional as F -from torch import Tensor from torch.nn import init -class NLayerDiscriminator(nn.Module): - """Defines a PatchGAN discriminator""" - - def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d, - kw=4, downsampling_kw=None): - """Construct a PatchGAN discriminator - Parameters: - input_nc (int) -- the number of channels in input images - ngf (int) -- the number of filters in the last conv layer - n_layers (int) -- the number of conv layers in the discriminator - norm_layer -- normalization layer - """ - super(NLayerDiscriminator, self).__init__() - if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters - use_bias = norm_layer.func == nn.InstanceNorm2d - else: - use_bias = norm_layer == nn.InstanceNorm2d - - if downsampling_kw is None: - downsampling_kw = kw - - padw = 1 - ds_kw = downsampling_kw - sequence = [nn.Conv2d(input_nc, ngf, kernel_size=ds_kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] - nf_mult = 1 - nf_mult_prev = 1 - for n in range(1, n_layers): # gradually increase the number of filters - nf_mult_prev = nf_mult - nf_mult = min(2 ** n, 8) - sequence += [ - nn.Conv2d(ngf * nf_mult_prev, ngf * nf_mult, kernel_size=ds_kw, stride=2, padding=padw, bias=use_bias), - norm_layer(ngf * nf_mult), - nn.LeakyReLU(0.2, True) - ] - - nf_mult_prev = nf_mult - nf_mult = min(2 ** n_layers, 8) - sequence += [ - nn.Conv2d(ngf * nf_mult_prev, ngf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), - norm_layer(ngf * nf_mult), - nn.LeakyReLU(0.2, True) - ] - - sequence += [nn.Conv2d(ngf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map - self.model = nn.Sequential(*sequence) - - @property - def FOV(self): - # Returns the receptive field of one output neuron for a network (written for patch discriminators) - # See https://distill.pub/2019/computing-receptive-fields/#solving-receptive-field-region for formula derivation - - L = 0 # num of layers - k = [] # [kernel width at layer l] - s = [] # [stride at layer i] - for layer in self.model: - if hasattr(layer, 'kernel_size'): - L += 1 - k += [layer.kernel_size[-1]] - s += [layer.stride[-1]] - - r = 1 - for l in range(L-1, 0, -1): - r = s[l]*r + (k[l] - s[l]) - - return r - - def forward(self, input): - """Standard forward.""" - return self.model(input) - -class GANLoss(nn.Module): - 
"""Define different GAN objectives. - The GANLoss class abstracts away the need to create the target label tensor - that has the same size as the input. - """ - - def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): - """ Initialize the GANLoss class. - Parameters: - gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. - target_real_label (bool) - - label for a real image - target_fake_label (bool) - - label of a fake image - Note: Do not use sigmoid as the last layer of Discriminator. - LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. - """ - super(GANLoss, self).__init__() - self.register_buffer('real_label', torch.tensor(target_real_label)) - self.register_buffer('fake_label', torch.tensor(target_fake_label)) - self.gan_mode = gan_mode - if gan_mode == 'lsgan': - self.loss = nn.MSELoss() - elif gan_mode == 'vanilla': - self.loss = nn.BCEWithLogitsLoss() - elif gan_mode in ['wgangp']: - self.loss = None - else: - raise NotImplementedError('gan mode %s not implemented' % gan_mode) - - def get_target_tensor(self, prediction, target_is_real): - """Create label tensors with the same size as the input. - Parameters: - prediction (tensor) - - typically the prediction from a discriminator - target_is_real (bool) - - if the ground truth label is for real images or fake images - Returns: - A label tensor filled with ground truth label, and with the size of the input - """ - - if target_is_real: - target_tensor = self.real_label - else: - target_tensor = self.fake_label - return target_tensor.expand_as(prediction) - - def __call__(self, prediction, target_is_real): - """Calculate loss given Discriminator's output and grount truth labels. - Parameters: - prediction (tensor) - - typically the prediction output from a discriminator - target_is_real (bool) - - if the ground truth label is for real images or fake images - Returns: - the calculated loss. - """ - if self.gan_mode in ['lsgan', 'vanilla']: - target_tensor = self.get_target_tensor(prediction, target_is_real) - loss = self.loss(prediction, target_tensor) - elif self.gan_mode == 'wgangp': - if target_is_real: - loss = -prediction.mean() - else: - loss = prediction.mean() - return loss - +def get_norm_layers(net): + return [n for n in net.modules() if 'norm' in type(n).__name__.lower()] + +def get_running_norm_stats(net): + means = [] + vars = [] + for norm in get_norm_layers(net): + means.append(norm.running_mean) + vars.append(norm.running_var) + means = torch.cat(means) + vars = torch.cat(vars) + return means, vars + +def set_mode(net, mode='train'): + if mode == 'fix_stats': + net.train() + for m in net.modules(): + if 'norm' in type(m).__name__.lower(): + m.eval() + + if mode == 'train': + net.train() + + if mode == 'eval': + net.eval() def init_weights(net, init_type='normal', init_gain=0.02, nonlinearity='relu'): """Initialize network weights. @@ -171,93 +61,11 @@ def init_func(m): # define the initialization function print('initialize network with %s' % init_type) net.apply(init_func) # apply the initialization function - -class ResnetGenerator(nn.Module): - """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations, and (optionally) the injection of a feature map of random noise into the first upsampling layer. 
- We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) - """ - - def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', activation=nn.ReLU, add_noise=False): - """Construct a Resnet-based generator - Parameters: - input_nc (int) -- the number of channels in input images - output_nc (int) -- the number of channels in output images - ngf (int) -- the number of filters in the last conv layer - norm_layer -- normalization layer - use_dropout (bool) -- if use dropout layers - n_blocks (int) -- the number of ResNet blocks - padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zeros | valid - activation -- non-linearity layer to apply (default is ReLU) - add_noise -- whether to append a noise feature to the data prior to upsampling layers: True | False | 'param' - """ - assert(n_blocks >= 0) - super(ResnetGenerator, self).__init__() - if type(norm_layer) == functools.partial: - use_bias = norm_layer.func == nn.InstanceNorm2d - else: - use_bias = norm_layer == nn.InstanceNorm2d - - p = 0 - updown_p = 1 - padder = [] - if padding_type.lower() == 'reflect': - padder = [nn.ReflectionPad2d(3)] - elif padding_type.lower() == 'replicate': - padder = [nn.ReplicationPad2d(3)] - elif padding_type.lower() == 'zeros': - p = 3 - elif padding_type.lower() == 'valid': - p = 'valid' - updown_p = 0 - - model = [] - model += padder.copy() - model += [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=p, bias=use_bias), - norm_layer(ngf), - activation()] - - n_downsampling = 2 - for i in range(n_downsampling): # add downsampling layers - mult = 2 ** i - model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=updown_p, bias=use_bias), - norm_layer(ngf * mult * 2), - activation()] - - mult = 2 ** n_downsampling - for i in range(n_blocks): # add ResNet blocks - - model += [ResnetBlock(ngf * mult, padding_type=padding_type.lower(), norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, activation=activation)] - - if add_noise == 'param': # add noise feature if necessary - model += [ParameterizedNoiseBlock()] - elif add_noise: - model += [NoiseBlock()] - - for i in range(n_downsampling): # add upsampling layers - mult = 2 ** (n_downsampling - i) - model += [nn.ConvTranspose2d(ngf * mult + (i==0 and (add_noise is not False)), - int(ngf * mult / 2), - kernel_size=3, stride=2, - padding=updown_p, output_padding=updown_p, - bias=use_bias), - norm_layer(int(ngf * mult / 2)), - activation()] - model += padder.copy() - model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=p)] - model += [nn.Tanh()] - - self.model = nn.Sequential(*model) - - def forward(self, input): - """Standard forward""" - return self.model(input) - - -class NoiseBlock(nn.Module): +class NoiseBlock(torch.nn.Module): """Definies a block for producing and appending a feature map of gaussian noise with mean=0 and stdev=1""" def __init__(self): - super(NoiseBlock, self).__init__() + super().__init__() def forward(self, x): shape = list(x.shape) @@ -265,505 +73,12 @@ def forward(self, x): noise = torch.empty(shape, device=x.device).normal_() return torch.cat([x, noise.requires_grad_()], 1) - -class ParameterizedNoiseBlock(nn.Module): +class ParameterizedNoiseBlock(torch.nn.Module): """Definies a block for producing and appending a feature map of gaussian noise with mean and stdev defined by the first two feature maps of the incoming tensor""" 
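+    # Channel 0 of the input is taken as the noise mean and channel 1 (passed
+    # through ReLU so it is non-negative) as the standard deviation; the
+    # sampled map is appended to the input as a new channel.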
def __init__(self): - super(ParameterizedNoiseBlock, self).__init__() + super().__init__() def forward(self, x): noise = torch.normal(x[:,0,...], torch.relu(x[:,1,...])).unsqueeze(1) return torch.cat([x, noise.requires_grad_()], 1) - -class ResnetBlock(nn.Module): - """Define a Resnet block""" - - def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, activation=nn.ReLU): - """Initialize the Resnet block - A resnet block is a conv block with skip connections - We construct a conv block with build_conv_block function, - and implement skip connections in function. - Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf - """ - super(ResnetBlock, self).__init__() - self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, activation) - self.padding_type = padding_type - - def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, activation=nn.ReLU): - """Construct a convolutional block. - Parameters: - dim (int) -- the number of channels in the conv layer. - padding_type (str) -- the name of padding layer: reflect | replicate | zeros | valid - norm_layer -- normalization layer - use_dropout (bool) -- if use dropout layers. - use_bias (bool) -- if the conv layer uses bias or not - activation -- non-linearity layer to apply (default is ReLU) - Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer) - """ - p = 0 - padder = [] - if padding_type == 'reflect': - padder = [nn.ReflectionPad2d(1)] - elif padding_type == 'replicate': - padder = [nn.ReplicationPad2d(1)] - elif padding_type == 'zeros': - p = 1 - elif padding_type == 'valid': - p = 'valid' - else: - raise NotImplementedError('padding [%s] is not implemented' % padding_type) - - conv_block = [] - conv_block += padder.copy() - - conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), activation()] - if use_dropout: - conv_block += [nn.Dropout(0.2)] - - conv_block += padder.copy() - conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] - - return nn.Sequential(*conv_block) - - def crop(self, x, shape): - '''Center-crop x to match spatial dimensions given by shape.''' - - x_target_size = x.size()[:-2] + shape - - offset = tuple( - torch.div((a - b), 2, rounding_mode='trunc') - for a, b in zip(x.size(), x_target_size)) - - slices = tuple( - slice(o, o + s) - for o, s in zip(offset, x_target_size)) - - return x[slices] - - def forward(self, x): - """Forward function (with skip connections)""" - if self.padding_type == 'valid': # crop for valid networks - res = self.conv_block(x) - out = self.crop(x, res.size()[-2:]) + res - else: - out = x + self.conv_block(x) # add skip connections - return out - - -class UnetGenerator(nn.Module): - """Create a Unet-based generator""" - - def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): - """Construct a Unet generator - Parameters: - input_nc (int) -- the number of channels in input images - output_nc (int) -- the number of channels in output images - num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, - image of size 128x128 will become of size 1x1 # at the bottleneck - ngf (int) -- the number of filters in the last conv layer - norm_layer -- normalization layer - We construct the U-Net from the innermost layer to the outermost layer. - It is a recursive process. 
- """ - super(UnetGenerator, self).__init__() - # construct unet structure - unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer - for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters - unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) - # gradually reduce the number of filters from ngf * 8 to ngf - unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) - unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) - unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) - self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer - - def forward(self, input): - """Standard forward""" - return self.model(input) - - -class UnetSkipConnectionBlock(nn.Module): - """Defines the Unet submodule with skip connection. - X -------------------identity---------------------- - |-- downsampling -- |submodule| -- upsampling --| - """ - - def __init__(self, outer_nc, inner_nc, input_nc=None, - submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): - """Construct a Unet submodule with skip connections. - Parameters: - outer_nc (int) -- the number of filters in the outer conv layer - inner_nc (int) -- the number of filters in the inner conv layer - input_nc (int) -- the number of channels in input images/features - submodule (UnetSkipConnectionBlock) -- previously defined submodules - outermost (bool) -- if this module is the outermost module - innermost (bool) -- if this module is the innermost module - norm_layer -- normalization layer - use_dropout (bool) -- if use dropout layers. 
- """ - super(UnetSkipConnectionBlock, self).__init__() - self.outermost = outermost - if type(norm_layer) == functools.partial: - use_bias = norm_layer.func == nn.InstanceNorm2d - else: - use_bias = norm_layer == nn.InstanceNorm2d - if input_nc is None: - input_nc = outer_nc - downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, - stride=2, padding=1, bias=use_bias) - downrelu = nn.LeakyReLU(0.2, True) - downnorm = norm_layer(inner_nc) - uprelu = nn.ReLU(True) - upnorm = norm_layer(outer_nc) - - if outermost: - upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, - kernel_size=4, stride=2, - padding=1) - down = [downconv] - up = [uprelu, upconv, nn.Tanh()] - model = down + [submodule] + up - elif innermost: - upconv = nn.ConvTranspose2d(inner_nc, outer_nc, - kernel_size=4, stride=2, - padding=1, bias=use_bias) - down = [downrelu, downconv] - up = [uprelu, upconv, upnorm] - model = down + up - else: - upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, - kernel_size=4, stride=2, - padding=1, bias=use_bias) - down = [downrelu, downconv, downnorm] - up = [uprelu, upconv, upnorm] - - if use_dropout: - model = down + [submodule] + up + [nn.Dropout(0.2)] - else: - model = down + [submodule] + up - - self.model = nn.Sequential(*model) - - def forward(self, x): - if self.outermost: - # print('outermost') - # print(x.size()) - return self.model(x) - else: # add skip connections - # print(x.size()) - # print(self.model(x).size()) - return torch.cat([x, self.model(x)], 1) - - -class UnetGenerator3D(nn.Module): - """Create a Unet-based generator""" - - def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm3d, use_dropout=False): - """Construct a Unet generator - Parameters: - input_nc (int) -- the number of channels in input images - output_nc (int) -- the number of channels in output images - num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, - image of size 128x128 will become of size 1x1 # at the bottleneck - ngf (int) -- the number of filters in the last conv layer - norm_layer -- normalization layer - We construct the U-Net from the innermost layer to the outermost layer. - It is a recursive process. - """ - super(UnetGenerator3D, self).__init__() - # construct unet structure - unet_block = UnetSkipConnectionBlock3D(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer - for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters - unet_block = UnetSkipConnectionBlock3D(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) - # gradually reduce the number of filters from ngf * 8 to ngf - unet_block = UnetSkipConnectionBlock3D(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) - unet_block = UnetSkipConnectionBlock3D(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) - unet_block = UnetSkipConnectionBlock3D(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) - self.model = UnetSkipConnectionBlock3D(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer - - def forward(self, input): - """Standard forward""" - return self.model(input) - - -class UnetSkipConnectionBlock3D(nn.Module): - """Defines the Unet submodule with skip connection. 
- X -------------------identity---------------------- - |-- downsampling -- |submodule| -- upsampling --| - """ - - def __init__(self, outer_nc, inner_nc, input_nc=None, - submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm3d, use_dropout=False): - """Construct a Unet submodule with skip connections. - Parameters: - outer_nc (int) -- the number of filters in the outer conv layer - inner_nc (int) -- the number of filters in the inner conv layer - input_nc (int) -- the number of channels in input images/features - submodule (UnetSkipConnectionBlock3D) -- previously defined submodules - outermost (bool) -- if this module is the outermost module - innermost (bool) -- if this module is the innermost module - norm_layer -- normalization layer - use_dropout (bool) -- if use dropout layers. - """ - super(UnetSkipConnectionBlock3D, self).__init__() - self.outermost = outermost - if type(norm_layer) == functools.partial: - use_bias = norm_layer.func == nn.InstanceNorm3d - else: - use_bias = norm_layer == nn.InstanceNorm3d - if input_nc is None: - input_nc = outer_nc - downconv = nn.Conv3d(input_nc, inner_nc, kernel_size=4, - stride=2, padding=1, bias=use_bias) - downrelu = nn.LeakyReLU(0.2, True) - downnorm = norm_layer(inner_nc) - uprelu = nn.ReLU(True) - upnorm = norm_layer(outer_nc) - - if outermost: - upconv = nn.ConvTranspose3d(inner_nc * 2, outer_nc, - kernel_size=4, stride=2, - padding=1) - down = [downconv] - up = [uprelu, upconv, nn.Tanh()] - model = down + [submodule] + up - elif innermost: - upconv = nn.ConvTranspose3d(inner_nc, outer_nc, - kernel_size=4, stride=2, - padding=1, bias=use_bias) - down = [downrelu, downconv] - up = [uprelu, upconv, upnorm] - model = down + up - else: - upconv = nn.ConvTranspose3d(inner_nc * 2, outer_nc, - kernel_size=4, stride=2, - padding=1, bias=use_bias) - down = [downrelu, downconv, downnorm] - up = [uprelu, upconv, upnorm] - - if use_dropout: - model = down + [submodule] + up + [nn.Dropout(0.2)] - else: - model = down + [submodule] + up - - self.model = nn.Sequential(*model) - - def forward(self, x): - if self.outermost: - # print('outermost') - # print(x.size()) - return self.model(x) - else: # add skip connections - # print(x.size()) - # print(self.model(x).size()) - return torch.cat([x, self.model(x)], 1) - - -class NLayerDiscriminator3D(nn.Module): - """Defines a PatchGAN discriminator""" - - def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm3d, - kw=4, downsampling_kw=None, - ): - """Construct a PatchGAN discriminator - Parameters: - input_nc (int) -- the number of channels in input images - ngf (int) -- the number of filters in the last conv layer - n_layers (int) -- the number of conv layers in the discriminator - norm_layer -- normalization layer - """ - super(NLayerDiscriminator3D, self).__init__() - if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm3d has affine parameters - use_bias = norm_layer.func == nn.InstanceNorm3d - else: - use_bias = norm_layer == nn.InstanceNorm3d - - if downsampling_kw is None: - downsampling_kw = kw - - padw = 1 - ds_kw = downsampling_kw - sequence = [nn.Conv3d(input_nc, ngf, kernel_size=ds_kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] - nf_mult = 1 - nf_mult_prev = 1 - for n in range(1, n_layers): # gradually increase the number of filters - nf_mult_prev = nf_mult - nf_mult = min(2 ** n, 8) - sequence += [ - nn.Conv3d(ngf * nf_mult_prev, ngf * nf_mult, kernel_size=ds_kw, stride=2, padding=padw, bias=use_bias), - norm_layer(ngf * 
nf_mult), - nn.LeakyReLU(0.2, True) - ] - - nf_mult_prev = nf_mult - nf_mult = min(2 ** n_layers, 8) - sequence += [ - nn.Conv3d(ngf * nf_mult_prev, ngf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), - norm_layer(ngf * nf_mult), - nn.LeakyReLU(0.2, True) - ] - - sequence += [nn.Conv3d(ngf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map - self.model = nn.Sequential(*sequence) - - def forward(self, input): - """Standard forward.""" - return self.model(input) - - -class ResnetGenerator3D(nn.Module): - """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations, and (optionally) the injection of a feature map of random noise into the first upsampling layer. - We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) - """ - - def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm3d, use_dropout=False, n_blocks=6, padding_type='reflect', activation=nn.ReLU, add_noise=False): - """Construct a Resnet-based generator - Parameters: - input_nc (int) -- the number of channels in input images - output_nc (int) -- the number of channels in output images - ngf (int) -- the number of filters in the last conv layer - norm_layer -- normalization layer - use_dropout (bool) -- if use dropout layers - n_blocks (int) -- the number of ResNet blocks - padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zeros | valid - activation -- non-linearity layer to apply (default is ReLU) - add_noise -- whether to append a noise feature to the data prior to upsampling layers: True | False | 'param' - """ - assert(n_blocks >= 0) - super(ResnetGenerator3D, self).__init__() - if type(norm_layer) == functools.partial: - use_bias = norm_layer.func == nn.InstanceNorm3d - else: - use_bias = norm_layer == nn.InstanceNorm3d - - p = 0 - updown_p = 1 - padder = [] - if padding_type.lower() == 'reflect': - padder = [nn.ReflectionPad3d(3)] - elif padding_type.lower() == 'replicate': - padder = [nn.ReplicationPad3d(3)] - elif padding_type.lower() == 'zeros': - p = 3 - elif padding_type.lower() == 'valid': - p = 'valid' - updown_p = 0 - - model = [] - model += padder.copy() - model += [nn.Conv3d(input_nc, ngf, kernel_size=7, padding=p, bias=use_bias), - norm_layer(ngf), - activation()] - - n_downsampling = 2 - for i in range(n_downsampling): # add downsampling layers - mult = 2 ** i - model += [nn.Conv3d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=updown_p, bias=use_bias), #TODO: Make actually use padding_type for every convolution (currently does zeros if not valid) - norm_layer(ngf * mult * 2), - activation()] - - mult = 2 ** n_downsampling - for i in range(n_blocks): # add ResNet blocks - - model += [ResnetBlock3D(ngf * mult, padding_type=padding_type.lower(), norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, activation=activation)] - - if add_noise == 'param': # add noise feature if necessary - model += [ParameterizedNoiseBlock()] - elif add_noise: - model += [NoiseBlock()] - - for i in range(n_downsampling): # add upsampling layers - mult = 2 ** (n_downsampling - i) - model += [nn.ConvTranspose3d(ngf * mult + (i==0 and (add_noise is not False)), - int(ngf * mult / 2), - kernel_size=3, stride=2, - padding=updown_p, output_padding=updown_p, - bias=use_bias), - norm_layer(int(ngf * mult / 2)), - activation()] - model += padder.copy() - model += [nn.Conv3d(ngf, output_nc, 
kernel_size=7, padding=p)] - model += [nn.Tanh()] - - self.model = nn.Sequential(*model) - - def forward(self, input): - """Standard forward""" - return self.model(input) - - -class ResnetBlock3D(nn.Module): - """Define a Resnet block""" - - def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, activation=nn.ReLU): - """Initialize the Resnet block - A resnet block is a conv block with skip connections - We construct a conv block with build_conv_block function, - and implement skip connections in function. - Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf - """ - super(ResnetBlock3D, self).__init__() - self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, activation) - self.padding_type = padding_type - - def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, activation=nn.ReLU): - """Construct a convolutional block. - Parameters: - dim (int) -- the number of channels in the conv layer. - padding_type (str) -- the name of padding layer: reflect | replicate | zeros | valid - norm_layer -- normalization layer - use_dropout (bool) -- if use dropout layers. - use_bias (bool) -- if the conv layer uses bias or not - activation -- non-linearity layer to apply (default is ReLU) - Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer) - """ - p = 0 - padder = [] - if padding_type == 'reflect': - padder = [nn.ReflectionPad3d(1)] - elif padding_type == 'replicate': - padder = [nn.ReplicationPad3d(1)] - elif padding_type == 'zeros': - p = 1 - elif padding_type == 'valid': - p = 'valid' - else: - raise NotImplementedError('padding [%s] is not implemented' % padding_type) - - conv_block = [] - conv_block += padder.copy() - - conv_block += [nn.Conv3d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), activation()] - if use_dropout: - conv_block += [nn.Dropout(0.2)] - - conv_block += padder.copy() - conv_block += [nn.Conv3d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] - - return nn.Sequential(*conv_block) - - def crop(self, x, shape): - '''Center-crop x to match spatial dimensions given by shape.''' - - x_target_size = x.size()[:-3] + shape - - offset = tuple( - torch.div((a - b), 2, rounding_mode='trunc') - for a, b in zip(x.size(), x_target_size)) - - slices = tuple( - slice(o, o + s) - for o, s in zip(offset, x_target_size)) - - return x[slices] - - def forward(self, x): - """Forward function (with skip connections)""" - if self.padding_type == 'valid': # crop for valid networks - res = self.conv_block(x) - out = self.crop(x, res.size()[-3:]) + res - else: - out = x + self.conv_block(x) # add skip connections - return out - - From 525de2a68f41f6aa014007de2e94f24fcbf33f33 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Mon, 8 Aug 2022 17:12:01 -0400 Subject: [PATCH 3/3] reorganize --- .../default_cycleGAN_conf.json | 57 ++ .../split/seed13/train_conf.json | 4 - .../split/seed3/train_conf.json | 4 - .../split/seed42/train_conf.json | 4 - .../ieee-isbi-2022/split/train_conf.json | 18 - .../ieee-isbi-2022/train_conf.json | 61 -- raygun/torch/systems/CycleGAN.py | 884 ++++++++++++++++++ raygun/torch/utils/read_config.py | 4 +- scratch/freezeNorm.py | 2 +- 9 files changed, 945 insertions(+), 93 deletions(-) create mode 100644 raygun/torch/default_configs/default_cycleGAN_conf.json delete mode 100644 raygun/torch/examples/batch_training/ieee-isbi-2022/split/seed13/train_conf.json delete mode 100644 
raygun/torch/examples/batch_training/ieee-isbi-2022/split/seed3/train_conf.json delete mode 100644 raygun/torch/examples/batch_training/ieee-isbi-2022/split/seed42/train_conf.json delete mode 100644 raygun/torch/examples/batch_training/ieee-isbi-2022/split/train_conf.json delete mode 100644 raygun/torch/examples/batch_training/ieee-isbi-2022/train_conf.json create mode 100644 raygun/torch/systems/CycleGAN.py diff --git a/raygun/torch/default_configs/default_cycleGAN_conf.json b/raygun/torch/default_configs/default_cycleGAN_conf.json new file mode 100644 index 00000000..ee5a95d1 --- /dev/null +++ b/raygun/torch/default_configs/default_cycleGAN_conf.json @@ -0,0 +1,57 @@ +{ + "common_voxel_size": null, // voxel size to resample A and B into for training + "ndims": null, + "A_name": "raw", + "B_name": "raw", + "mask_A_name": null, // expects mask to be in same place as real zarr + "mask_B_name": null, + "A_out_path": null, + "B_out_path": null, + "model_name": "CycleGAN", + "gnet_type": "unet", + "dnet_type": "classic", + "dnet_kwargs": { + "input_nc": 1, + "downsampling_kw": 2, // downsampling factor + "kw": 3, // kernel size + "n_layers": 3, // number of layers in Discriminator networks + "ngf": 64 + }, + "loss_type": "cycle", // supports "link" or "split" + "loss_kwargs": {"g_lambda_dict": {"A": { + "l1_loss": {"cycled": 10, "identity": 0.5}, // Default from CycleGAN paper + "gan_loss": {"fake": 1, "cycled": 0} + }, + "B": { + "l1_loss": {"cycled": 10, "identity": 0.5}, // Default from CycleGAN paper + "gan_loss": {"fake": 1, "cycled": 0} + } + }, + "d_lambda_dict": {"A": {"real": 1, "fake": 1, "cycled": 0}, + "B": {"real": 1, "fake": 1, "cycled": 0} + } + }, + "sampling_bottleneck": false, + "optim_type": "Adam", + "optim_kwargs": {"betas": [0.9, 0.999], + "weight_decay": 0 + }, + "g_init_learning_rate": 1e-5, + "d_init_learning_rate": 1e-5, + "min_coefvar": null, + "interp_order": null, + "side_length": 64, // in common sized voxels + "batch_size": 1, + "num_workers": 11, + "cache_size": 50, + "spawn_subprocess": false, + "num_epochs": 20000, + "log_every": 20, + "save_every": 2000, + "model_path": "./models/", + "tensorboard_path": "./tensorboard/", + "verbose": true, + "checkpoint": null, // Used for prediction/rendering, training always starts from latest + "pretrain_gnet": false, + "random_seed": 42 +} \ No newline at end of file diff --git a/raygun/torch/examples/batch_training/ieee-isbi-2022/split/seed13/train_conf.json b/raygun/torch/examples/batch_training/ieee-isbi-2022/split/seed13/train_conf.json deleted file mode 100644 index 06403cc3..00000000 --- a/raygun/torch/examples/batch_training/ieee-isbi-2022/split/seed13/train_conf.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "include_config": "../train_kwargs.conf", - "random_seed": 13 -} diff --git a/raygun/torch/examples/batch_training/ieee-isbi-2022/split/seed3/train_conf.json b/raygun/torch/examples/batch_training/ieee-isbi-2022/split/seed3/train_conf.json deleted file mode 100644 index 71561978..00000000 --- a/raygun/torch/examples/batch_training/ieee-isbi-2022/split/seed3/train_conf.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "include_config": "../train_kwargs.conf", - "random_seed": 3 -} diff --git a/raygun/torch/examples/batch_training/ieee-isbi-2022/split/seed42/train_conf.json b/raygun/torch/examples/batch_training/ieee-isbi-2022/split/seed42/train_conf.json deleted file mode 100644 index bc99d075..00000000 --- a/raygun/torch/examples/batch_training/ieee-isbi-2022/split/seed42/train_conf.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - 
"include_config": "../train_kwargs.conf", - "random_seed": 42 -} diff --git a/raygun/torch/examples/batch_training/ieee-isbi-2022/split/train_conf.json b/raygun/torch/examples/batch_training/ieee-isbi-2022/split/train_conf.json deleted file mode 100644 index f539cc4f..00000000 --- a/raygun/torch/examples/batch_training/ieee-isbi-2022/split/train_conf.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "include_config": "../train_kwargs.conf", - - "loss_style": "custom", - "loss_kwargs": {"g_lambda_dict": {"A": { - "l1_loss": {"cycled": 3, "identity": 0}, - "gan_loss": {"fake": 1, "cycled": 0} - }, - "B": { - "l1_loss": {"cycled": 3, "identity": 0}, - "gan_loss": {"fake": 1, "cycled": 0} - } - }, - "d_lambda_dict": {"A": {"real": 1, "fake": 1, "cycled": 0}, - "B": {"real": 1, "fake": 1, "cycled": 0} - } - } -} \ No newline at end of file diff --git a/raygun/torch/examples/batch_training/ieee-isbi-2022/train_conf.json b/raygun/torch/examples/batch_training/ieee-isbi-2022/train_conf.json deleted file mode 100644 index 69331260..00000000 --- a/raygun/torch/examples/batch_training/ieee-isbi-2022/train_conf.json +++ /dev/null @@ -1,61 +0,0 @@ -{ - "src_A": "/n/groups/htem/ESRF_id16a/tomo_ML/ResolutionEnhancement/jlr54_tests/volumes/CBxs_lobV_overview_90nm_rec5iter_db9_l20p15_.n5", - "src_B": "/n/groups/htem/ESRF_id16a/tomo_ML/ResolutionEnhancement/jlr54_tests/volumes/CBxs_lobV_bottomp100um_30nm_rec_db9_.n5", - "A_voxel_size": [90, 90, 90], //voxel_size of A - "B_voxel_size": [30, 30, 30], //voxel_size of B - "common_voxel_size": [30, 30, 30], //voxel size to cast all data into - "ndims": 3, - "A_name": "volumes/raw", - "mask_A_name": "volumes/training_mask", - "B_name": "volumes/raw", - "mask_B_name": "volumes/volume_mask", - "batch_size": 1, - "num_workers": 16, - "cache_size": 96, - "min_coefvar": 1e-02, - "adam_betas": [0.5, 0.999], - - // "sampling_bottleneck": True, - // "loss_style": "cycle", - // "loss_kwargs": { "l1_lambda": 3, - // "identity_lambda": 0, - // // "gan_mode": "wgangp", - // }, - - "gnet_type": "resnet", - "gnet_kwargs": { - "input_nc": 1, - "output_nc": 1, - "norm_layer": "#partial(torch.nn.InstanceNorm2d, track_running_stats=True, momentum=0.01)#", - // "padding_type": "valid", - // "activation": torch.nn.SELU,//torch.nn.Tanh, //torch.nn.SiLU,// - // "add_noise": True, - "ngf": 64, - // "constant_upsample": True, // unet specific - // "fmap_inc_factor": 2, // unet specific - // "residual": True, // unet specific - "n_blocks": 9 // resnet specific - }, - "g_init_learning_rate": 0.00004, - - "pretrain_gnet": false, - - "d_init_learning_rate": 0.000004, - "dnet_type": "classic", //"resnet", - "dnet_kwargs": { - "input_nc": 1, - // "downsampling_kw": 2, - // "kw": 3, - "n_layers": 4, - // "output_nc": 1, - // "norm_layer": torch.nn.InstanceNorm3d,//partial(torch.nn.InstanceNorm3d, affine=True),//, track_running_stats=True), - // // "activation": torch.nn.SELU, - "ngf": 64 - // "n_blocks": 9, // resnet specific - }, - - "spawn_subprocess": true, - "side_length": 96, // requires odd number for valid resnet9 (which gives odd output) - "num_epochs": 300000, - "log_every": 20, -} \ No newline at end of file diff --git a/raygun/torch/systems/CycleGAN.py b/raygun/torch/systems/CycleGAN.py new file mode 100644 index 00000000..0f3e83a7 --- /dev/null +++ b/raygun/torch/systems/CycleGAN.py @@ -0,0 +1,884 @@ +from copy import deepcopy +import itertools +import random +from matplotlib import pyplot as plt +import torch +import glob +import re +import zarr +import daisy +import os + +import gunpowder 
as gp + +import logging +logger = logging.Logger('CycleGAN', 'INFO') + +import math +import functools +from tqdm import tqdm +import numpy as np + +torch.backends.cudnn.benchmark = True + +from networks import * +from models import CycleModel +from losses import LinkCycleLoss, SplitCycleLoss +from optimizers import BaseDummyOptimizer +from utils import read_config + +class CycleGAN(): #TODO: Just pass config file or dictionary + def __init__(self, config_file): + #Add default params + for key, value in read_config('../default_configs/default_cycleGAN_conf.json').items(): + setattr(self, key, value) + + #Get this configuration + for key, value in read_config(config_file).items(): + setattr(self, key, value) + + if self.common_voxel_size is None: + self.common_voxel_size = gp.Coordinate(daisy.open_ds(self.src_B, self.B_name).voxel_size) + else: + self.common_voxel_size = gp.Coordinate(self.common_voxel_size) + if self.ndims is None: + self.ndims = sum(np.array(self.common_voxel_size) == np.min(self.common_voxel_size)) + if self.A_out_path is None: + self.A_out_path = self.src_A + if self.B_out_path is None: + self.B_out_path = self.src_B + self.gnet_type = self.gnet_type.lower() + + self._set_verbose() + if self.checkpoint is None: + try: + self.checkpoint, self.iteration = self._get_latest_checkpoint() + except: + print('Checkpoint not found. Starting from scratch.') + self.checkpoint = None + + if self.random_seed is not None: + torch.manual_seed(self.random_seed) + random.seed(self.random_seed) + np.random.seed(self.random_seed) + + self.build_machine() + self.training_pipeline = None + self.test_training_pipeline = None + + def set_device(self, id=0): + self.device_id = id + os.environ["CUDA_VISIBLE_DEVICES"] = str(id) + + def set_verbose(self, verbose=True): + self.verbose = verbose + self._set_verbose() + + def _set_verbose(self): + if self.verbose: + logging.basicConfig(level=logging.INFO) + else: + logging.basicConfig(level=logging.WARNING) + + def batch_show(self, batch=None, i=0, show_mask=False): + if batch is None: + batch = self.batch + if not hasattr(self, 'col_dict'): + self.col_dict = {'REAL':0, 'FAKE':1, 'CYCL':2} + if show_mask: self.col_dict['MASK'] = 3 + rows = (self.real_A in batch.arrays) + (self.real_B in batch.arrays) + cols = 0 + for key in self.col_dict.keys(): + cols += key in [array.identifier[:4] for array in batch.arrays] + fig, axes = plt.subplots(rows, cols, figsize=(10*cols, 10*rows)) + for array, value in batch.items(): + label = array.identifier + if label[:4] in self.col_dict: + c = self.col_dict[label[:4]] + r = (int('_B' in label) + int('FAKE' in label)) % 2 + if len(value.data.shape) > 3: # pick one from the batch + img = value.data[i].squeeze() + else: + img = value.data.squeeze() + if len(img.shape) == 3: + mid = img.shape[0] // 2 # for 3D volume + data = img[mid] + else: + data = img + if rows == 1: + axes[c].imshow(data, cmap='gray', vmin=0, vmax=1) + axes[c].set_title(label) + else: + axes[r, c].imshow(data, cmap='gray', vmin=0, vmax=1) + axes[r, c].set_title(label) + + def write_tBoard_graph(self, batch=None): + if batch is None: + batch = self.batch + + ex_inputs = [] + if self.real_A in batch: + ex_inputs += [torch.tensor(batch[self.real_A].data)] + if self.real_B in batch: + ex_inputs += [torch.tensor(batch[self.real_B].data)] + + for i, ex_input in enumerate(ex_inputs): + if self.ndims == len(self.common_voxel_size): # add channel dimension if necessary + ex_input = ex_input.unsqueeze(axis=1) + if self.batch_size == 1: # ensure batch dimension 
is present
+                ex_input = ex_input.unsqueeze(axis=0)
+            ex_inputs[i] = ex_input
+
+        try:
+            self.trainer.summary_writer.add_graph(self.model, ex_inputs)
+        except:
+            logger.warning('Failed to add model graph to tensorboard.')
+
+    def batch_tBoard_write(self, i=0):
+        self.trainer.summary_writer.flush()
+        self.n_iter = self.trainer.iteration
+
+    def _get_latest_checkpoint(self):
+        basename = self.model_path + self.model_name
+
+        def atoi(text):
+            return int(text) if text.isdigit() else text
+
+        def natural_keys(text):
+            return [atoi(c) for c in re.split(r'(\d+)', text)]
+
+        checkpoints = glob.glob(basename + '_checkpoint_*')
+        checkpoints.sort(key=natural_keys)
+
+        if len(checkpoints) > 0:
+            checkpoint = checkpoints[-1]
+            iteration = int(checkpoint.split('_')[-1])
+            return checkpoint, iteration
+
+        return None, 0
+
+    def get_extents(self, side_length=None, array_name=None):
+        if side_length is None:
+            side_length = self.side_length
+
+        if ('padding_type' in self.gnet_kwargs) and (self.gnet_kwargs['padding_type'].lower() == 'valid'):
+            if array_name is not None and not ('real' in array_name.lower() or 'mask' in array_name.lower()):
+                shape = (1, 1) + (side_length,) * self.ndims
+                pars = [par for par in self.netG1.parameters()]
+                result = self.netG1(torch.zeros(*shape, device=pars[0].device))
+                if 'fake' in array_name.lower():
+                    side_length = result.shape[-1]
+                elif 'cycle' in array_name.lower():
+                    result = self.netG1(result)
+                    side_length = result.shape[-1]
+
+        extents = np.ones((len(self.common_voxel_size)))
+        extents[-self.ndims:] = side_length # assumes first dimension is z (i.e. the dimension breaking isotropy)
+        return gp.Coordinate(extents)
+
+    def get_valid_context(self, side_length=None):
+        # returns number of pixels to crop from a side to trim network outputs to valid FOV
+        if side_length is None:
+            side_length = self.side_length
+
+        gnet_kwargs = self.gnet_kwargs.copy()
+        gnet_kwargs['padding_type'] = 'valid'
+        gnet = self.get_generator(gnet_kwargs=gnet_kwargs)
+
+        shape = (1, 1) + (side_length,) * self.ndims
+        pars = [par for par in gnet.parameters()]
+        result = gnet(torch.zeros(*shape, device=pars[0].device))
+        return ((gp.Coordinate(shape) - gp.Coordinate(result.shape)) / 2)[-self.ndims:]
+
+    def get_valid_crop(self, side_length=None):
+        # returns number of pixels to crop from a side to trim network outputs to valid FOV
+        if side_length is None:
+            side_length = self.side_length
+
+        gnet_kwargs = self.gnet_kwargs.copy()
+        gnet_kwargs['padding_type'] = 'valid'
+        gnet = self.get_generator(gnet_kwargs=gnet_kwargs)
+
+        shape = (1, 1) + (side_length,) * self.ndims
+        pars = [par for par in gnet.parameters()]
+        result = gnet(torch.zeros(*shape, device=pars[0].device))
+        pad = np.floor((gp.Coordinate(shape) - gp.Coordinate(result.shape)) / 2)
+        raise NotImplementedError('get_valid_crop() not fully implemented (in conflict with other usage of cropping)')
+        return gp.Coordinate(pad[-self.ndims:])
+
+    def get_downsample_factors(self, net_kwargs):
+        if 'downsample_factors' not in net_kwargs:
+            down_factor = 2 if 'down_factor' not in net_kwargs else net_kwargs.pop('down_factor')
+            num_downs = 3 if 'num_downs' not in net_kwargs else net_kwargs.pop('num_downs')
+            net_kwargs.update({'downsample_factors': [(down_factor,)*self.ndims,] * (num_downs - 1)})
+        return net_kwargs
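+    # For example (illustrative values): with the defaults down_factor=2 and
+    # num_downs=3 in 3D, get_downsample_factors() fills in
+    #     {'downsample_factors': [(2, 2, 2), (2, 2, 2)]}
+    # which the UNet/ResidualUNet constructors in get_generator() below consume.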
+    def pretrain_generator(self, gnet=None, iter=1000, accel_factor=100, loss_fn=torch.nn.HuberLoss()):
+        if gnet is None:
+            gnet = self.get_generator()
+        optimizer = torch.optim.Adam(gnet.parameters(), lr=self.g_init_learning_rate*accel_factor, betas=self.adam_betas, weight_decay=self.adam_decay)
+        shape = (1, 1) + (self.side_length,) * self.ndims
+        pars = [par for par in gnet.parameters()]
+        test = torch.rand(*shape, device=pars[0].device, requires_grad=True) * 2 - 1
+        pbar = tqdm(range(iter))
+        for i in pbar:
+            optimizer.zero_grad() # reset gradients each step so they don't accumulate
+            out = gnet(test)
+            loss = loss_fn(out, test)
+            pbar.set_postfix({'loss': loss.item()})
+            loss.backward()
+            optimizer.step()
+            test = torch.rand(*shape, requires_grad=True, device=pars[0].device) * 2 - 1
+
+        print(f'Final loss: {loss.item()}')
+        #TODO: Figure out how to get this to work for spawning workers later
+        # gnet = gnet.to('cpu')
+        # torch.cuda.empty_cache()
+        # cudart.cudaDeviceReset()
+        return gnet
+
+    def get_generator(self, gnet_kwargs=None):
+        if gnet_kwargs is None:
+            gnet_kwargs = self.gnet_kwargs
+
+        if self.gnet_type == 'unet':
+            # if self.ndims == 2:
+            #     generator = UnetGenerator(**self.gnet_kwargs)
+            # elif self.ndims == 3:
+            #     generator = UnetGenerator3D(**self.gnet_kwargs)
+            # else:
+            #     raise f'Unet generators only specified for 2D or 3D, not {self.ndims}D'
+            gnet_kwargs = self.get_downsample_factors(gnet_kwargs)
+
+            generator = torch.nn.Sequential(
+                UNet(**gnet_kwargs),
+                # ConvPass(self.gnet_kwargs['ngf'], self.gnet_kwargs['output_nc'], [(1,)*self.ndims], activation=None, padding=self.gnet_kwargs['padding_type']),
+                torch.nn.Tanh()
+            )
+
+        elif self.gnet_type == 'residual_unet':
+            gnet_kwargs = self.get_downsample_factors(gnet_kwargs)
+
+            generator = torch.nn.Sequential(
+                ResidualUNet(**gnet_kwargs),
+                torch.nn.Tanh()
+            )
+
+        elif self.gnet_type == 'resnet':
+            if self.ndims == 2:
+                generator = ResnetGenerator(**gnet_kwargs)
+            elif self.ndims == 3:
+                generator = ResnetGenerator3D(**gnet_kwargs)
+            else:
+                raise NotImplementedError(f'Resnet generators only specified for 2D or 3D, not {self.ndims}D')
+
+        else:
+            raise NotImplementedError(f'Unknown generator type requested: {self.gnet_type}')
+
+        activation = gnet_kwargs['activation'] if 'activation' in gnet_kwargs else nn.ReLU
+
+        if activation is not None:
+            # if activation == nn.SELU:
+            #     init_weights(generator, init_type='kaiming', nonlinearity='linear') # For Self-Normalizing Neural Networks
+            # else:
+            init_weights(generator, init_type='kaiming', nonlinearity=activation.__name__.lower()) # activation holds a class (e.g. nn.ReLU), so use __name__, not __class__.__name__
+        else:
+            init_weights(generator, init_type='normal', init_gain=0.05) #TODO: MAY WANT TO ADD TO CONFIG FILE
+        return generator
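+    # A representative gnet_kwargs for the 'resnet' path, taken from the example
+    # configs in this repo (illustrative, not exhaustive):
+    #     {"input_nc": 1, "output_nc": 1, "ngf": 64, "n_blocks": 9}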
+    def get_discriminator(self, dnet_kwargs=None):
+        if dnet_kwargs is None:
+            dnet_kwargs = self.dnet_kwargs
+
+        if self.dnet_type == 'unet':
+            # if self.ndims == 2:
+            #     discriminator = UnetGenerator(**self.dnet_kwargs)
+            # elif self.ndims == 3:
+            #     discriminator = UnetGenerator3D(**self.dnet_kwargs)
+            # else:
+            #     raise f'Unet discriminators only specified for 2D or 3D, not {self.ndims}D'
+            dnet_kwargs = self.get_downsample_factors(dnet_kwargs)
+
+            discriminator = torch.nn.Sequential(
+                UNet(**dnet_kwargs),
+                # ConvPass(self.dnet_kwargs['ngf'], self.dnet_kwargs['output_nc'], [(1,)*self.ndims], activation=None, padding=self.dnet_kwargs['padding_type']),
+                torch.nn.Tanh()
+            )
+
+        elif self.dnet_type == 'residualunet':
+            dnet_kwargs = self.get_downsample_factors(dnet_kwargs)
+
+            discriminator = torch.nn.Sequential(
+                ResidualUNet(**dnet_kwargs),
+                torch.nn.Tanh()
+            )
+
+        elif self.dnet_type == 'resnet':
+            if self.ndims == 2:
+                discriminator = ResnetGenerator(**dnet_kwargs)
+            elif self.ndims == 3:
+                discriminator = ResnetGenerator3D(**dnet_kwargs)
+            else:
+                raise NotImplementedError(f'Resnet discriminators only specified for 2D or 3D, not {self.ndims}D')
+
+        elif self.dnet_type == 'classic':
+            if self.ndims == 3: # 3D case
+                norm_instance = torch.nn.InstanceNorm3d
+                discriminator_maker = NLayerDiscriminator3D
+            elif self.ndims == 2:
+                norm_instance = torch.nn.InstanceNorm2d
+                discriminator_maker = NLayerDiscriminator
+
+            dnet_kwargs['norm_layer'] = functools.partial(norm_instance, affine=False, track_running_stats=False)
+            discriminator = discriminator_maker(**dnet_kwargs)
+
+            init_weights(discriminator, init_type='kaiming')
+            return discriminator
+
+        else:
+            raise NotImplementedError(f'Unknown discriminator type requested: {self.dnet_type}')
+
+        activation = dnet_kwargs['activation'] if 'activation' in dnet_kwargs else nn.ReLU
+
+        if activation is not None:
+            # if activation == nn.SELU:
+            #     init_weights(discriminator, init_type='kaiming', nonlinearity='linear') # For Self-Normalizing Neural Networks
+            # else:
+            init_weights(discriminator, init_type='kaiming', nonlinearity=activation.__name__.lower()) # activation holds a class, so use __name__
+        else:
+            init_weights(discriminator, init_type='normal', init_gain=0.05) #TODO: MAY WANT TO ADD TO CONFIG FILE
+        return discriminator
+
+    def setup_networks(self):
+        if self.pretrain_gnet:
+            self.netG1 = self.pretrain_generator()
+            self.netG2 = deepcopy(self.netG1)
+        else:
+            self.netG1 = self.get_generator()
+            self.netG2 = self.get_generator()
+
+        self.netD1 = self.get_discriminator()
+        self.netD2 = self.get_discriminator()
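+    # End-to-end sketch (illustrative; 'train_conf.json' stands for any config
+    # following default_cycleGAN_conf.json above):
+    #     system = CycleGAN('train_conf.json')
+    #     batch = system.test_train() # one profiled batch, displayed via batch_show()
+    #     system.train()              # full training loop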
+    def setup_model(self):
+        if not hasattr(self, 'netG1'):
+            self.setup_networks()
+
+        if self.sampling_bottleneck:
+            scale_factor_A = tuple(np.divide(self.common_voxel_size, self.A_voxel_size)[-self.ndims:])
+            if not any([s < 1 for s in scale_factor_A]): scale_factor_A = None
+            scale_factor_B = tuple(np.divide(self.common_voxel_size, self.B_voxel_size)[-self.ndims:])
+            if not any([s < 1 for s in scale_factor_B]): scale_factor_B = None
+        else:
+            scale_factor_A, scale_factor_B = None, None
+
+        if self.loss_style.lower() == 'cycle':
+            self.model = CycleGAN_Model(self.netG1, self.netD1, self.netG2, self.netD2, scale_factor_A, scale_factor_B)
+            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG1.parameters(), self.netG2.parameters()), lr=self.g_init_learning_rate, betas=self.adam_betas, weight_decay=self.adam_decay)
+            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD1.parameters(), self.netD2.parameters()), lr=self.d_init_learning_rate, betas=self.adam_betas, weight_decay=self.adam_decay)
+            self.optimizer = CycleGAN_Optimizer(self.optimizer_G, self.optimizer_D)
+
+            self.loss = CycleGAN_Loss(self.netD1, self.netG1, self.netD2, self.netG2, self.optimizer_D, self.optimizer_G, self.ndims, **self.loss_kwargs)
+
+        elif self.loss_style.lower() == 'split':
+            self.model = CycleGAN_Split_Model(self.netG1, self.netD1, self.netG2, self.netD2, scale_factor_A, scale_factor_B)
+            self.optimizer_G1 = torch.optim.Adam(self.netG1.parameters(), lr=self.g_init_learning_rate, betas=self.adam_betas, weight_decay=self.adam_decay)
+            self.optimizer_G2 = torch.optim.Adam(self.netG2.parameters(), lr=self.g_init_learning_rate, betas=self.adam_betas, weight_decay=self.adam_decay)
+            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD1.parameters(), self.netD2.parameters()), lr=self.d_init_learning_rate, betas=self.adam_betas, weight_decay=self.adam_decay)
+            self.optimizer = Split_CycleGAN_Optimizer(self.optimizer_G1, self.optimizer_G2, self.optimizer_D)
+
+            self.loss = SplitGAN_Loss(self.netD1, self.netG1, self.netD2, self.netG2, self.optimizer_G1, self.optimizer_G2, self.optimizer_D, self.ndims, **self.loss_kwargs)
+
+        elif self.loss_style.lower() == 'custom':
+            self.model = CycleGAN_Split_Model(self.netG1, self.netD1, self.netG2, self.netD2, scale_factor_A, scale_factor_B)
+            self.optimizer_G1 = torch.optim.Adam(self.netG1.parameters(), lr=self.g_init_learning_rate, betas=self.adam_betas, weight_decay=self.adam_decay)
+            self.optimizer_G2 = torch.optim.Adam(self.netG2.parameters(), lr=self.g_init_learning_rate, betas=self.adam_betas, weight_decay=self.adam_decay)
+            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD1.parameters(), self.netD2.parameters()), lr=self.d_init_learning_rate, betas=self.adam_betas, weight_decay=self.adam_decay)
+            self.optimizer = Split_CycleGAN_Optimizer(self.optimizer_G1, self.optimizer_G2, self.optimizer_D)
+
+            self.loss = Custom_Loss(self.netD1, self.netG1, self.netD2, self.netG2, self.optimizer_G1, self.optimizer_G2, self.optimizer_D, self.ndims, **self.loss_kwargs)
+
+        else:
+            raise ValueError(f"Unexpected loss style '{self.loss_style}'. Accepted options are 'cycle', 'split', or 'custom'.")
+
+    def build_machine(self):
+        # initialize needed variables
+        self.arrays = []
+
+        # define our network model for training
+        self.setup_networks()
+        self.setup_model()
+
+        # get performance stats
+        self.performance = gp.PrintProfilingStats(every=self.log_every)
+
+        # setup a cache
+        self.cache = gp.PreCache(num_workers=self.num_workers, cache_size=self.cache_size) # os.cpu_count()
+
+        # define axes for mirroring and transpositions
+        self.augment_axes = list(np.arange(3)[-self.ndims:]) #TODO: Maybe limit to xy?
+
+        # build datapipes
+        self.datapipe_A = self.get_datapipe('A')
+        self.datapipe_B = self.get_datapipe('B') # datapipe has: train_pipe, source, reject, resample, augment, unsqueeze, etc.
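+    # The datapipe returned below is a plain namespace object; roughly
+    # (illustrative):
+    #     dp = system.get_datapipe('A')
+    #     dp.real, dp.fake, dp.cycled # gunpowder ArrayKeys for this side
+    #     dp.train_pipe               # source + reject/resample + augment + stack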
+    def get_datapipe(self, side):
+        datapipe = type('DataPipe', (object,), {}) # make simple object to smoothly store variables
+        side = side.upper() # ensure uppercase
+
+        datapipe.src_path = getattr(self, 'src_'+side) # the zarr container
+        datapipe.real_name = getattr(self, side+'_name')
+        datapipe.src_voxel_size = daisy.open_ds(datapipe.src_path, datapipe.real_name).voxel_size
+
+        # declare arrays to use in the pipelines
+        array_names = ['real',
+                       'fake',
+                       'cycled']
+        if getattr(self, f'mask_{side}_name') is not None:
+            array_names += ['mask']
+            datapipe.masked = True
+        else:
+            datapipe.masked = False
+
+        for array in array_names:
+            if 'fake' in array:
+                other_side = ['A','B']
+                other_side.remove(side)
+                array_name = array + '_' + other_side[0]
+            else:
+                array_name = array + '_' + side
+            array_key = gp.ArrayKey(array_name.upper())
+            setattr(datapipe, array, array_key) # add ArrayKeys to object
+            setattr(self, array_name, array_key)
+            self.arrays += [array_key]
+            # add normalizations and scaling, if appropriate
+            if 'mask' not in array:
+                setattr(datapipe, 'scaletanh2img_'+array, gp.IntensityScaleShift(array_key, 0.5, 0.5))
+                setattr(self, 'scaletanh2img_'+array_name, gp.IntensityScaleShift(array_key, 0.5, 0.5))
+
+            if 'real' in array:
+                setattr(datapipe, 'normalize_'+array, gp.Normalize(array_key))
+                setattr(self, 'normalize_'+array_name, gp.Normalize(array_key))
+                setattr(datapipe, 'scaleimg2tanh_'+array, gp.IntensityScaleShift(array_key, 2, -1))
+                setattr(self, 'scaleimg2tanh_'+array_name, gp.IntensityScaleShift(array_key, 2, -1))
+
+        # Setup sources and resampling nodes
+        if self.common_voxel_size != datapipe.src_voxel_size:
+            datapipe.real_src = gp.ArrayKey(f'REAL_{side}_SRC')
+            setattr(self, f'real_{side}_src', datapipe.real_src)
+            datapipe.resample = gp.Resample(datapipe.real_src, self.common_voxel_size, datapipe.real, ndim=self.ndims, interp_order=self.interp_order)
+            if datapipe.masked:
+                datapipe.mask_src = gp.ArrayKey(f'MASK_{side}_SRC')
+                setattr(self, f'mask_{side}_src', datapipe.mask_src)
+                datapipe.resample += gp.Resample(datapipe.mask_src, self.common_voxel_size, datapipe.mask, ndim=self.ndims, interp_order=self.interp_order)
+        else:
+            datapipe.real_src = datapipe.real
+            datapipe.resample = None
+            if datapipe.masked:
+                datapipe.mask_src = datapipe.mask
+
+        # setup data sources
+        datapipe.out_path = getattr(self, side+'_out_path')
+        datapipe.src_names = {datapipe.real_src: datapipe.real_name}
+        datapipe.src_specs = {datapipe.real_src: gp.ArraySpec(interpolatable=True, voxel_size=datapipe.src_voxel_size)}
+        if datapipe.masked:
+            datapipe.mask_name = getattr(self, f'mask_{side}_name')
+            datapipe.src_names[datapipe.mask_src] = datapipe.mask_name
+            datapipe.src_specs[datapipe.mask_src] = gp.ArraySpec(interpolatable=False)
+        datapipe.source = gp.ZarrSource( # add the data source
+            datapipe.src_path,
+            datapipe.src_names, # which dataset to associate to the array key
+            datapipe.src_specs # meta-information
+        )
+
+        # setup rejections
+        datapipe.reject = None
+        if datapipe.masked:
+            datapipe.reject = gp.Reject(mask=datapipe.mask_src, min_masked=0.999)
+
+        if self.min_coefvar:
+            if datapipe.reject is None:
+                datapipe.reject = gp.RejectConstant(datapipe.real_src, min_coefvar=self.min_coefvar)
+            else:
+                datapipe.reject += gp.RejectConstant(datapipe.real_src, min_coefvar=self.min_coefvar)
+
+        datapipe.augment = gp.SimpleAugment(mirror_only=self.augment_axes, transpose_only=self.augment_axes)
+        datapipe.augment += datapipe.normalize_real
+        datapipe.augment += datapipe.scaleimg2tanh_real
+        datapipe.augment += gp.ElasticAugment( #TODO: MAKE THESE SPECS PART OF CONFIG
+            control_point_spacing=100, # self.side_length//2,
+            # jitter_sigma=(5.0,)*self.ndims,
+            jitter_sigma=(0., 5.0, 5.0,)[-self.ndims:],
+            rotation_interval=(0, math.pi/2),
+            subsample=4,
+            spatial_dims=self.ndims
+        )
+
+        # add "channel" dimensions if necessary, else use z dimension as channel
+        if self.ndims == len(self.common_voxel_size):
+            datapipe.unsqueeze = gp.Unsqueeze([datapipe.real])
+        else:
+            datapipe.unsqueeze = None
+
+        # Make post-net data pipes
+        # remove "channel" dimensions if necessary
+        datapipe.postnet_pipe = type('SubDataPipe', (object,), {})
+        datapipe.postnet_pipe.nocycle = datapipe.scaletanh2img_real + datapipe.scaletanh2img_fake
+        datapipe.postnet_pipe.cycle = datapipe.scaletanh2img_real + datapipe.scaletanh2img_fake + datapipe.scaletanh2img_cycled
+        if self.ndims == len(self.common_voxel_size):
+            datapipe.postnet_pipe.nocycle += gp.Squeeze([datapipe.real,
+                                                         datapipe.fake,
+                                                         ], axis=1) # remove channel dimension for grayscale
+            datapipe.postnet_pipe.cycle += gp.Squeeze([datapipe.real,
+                                                       datapipe.fake,
+                                                       datapipe.cycled,
+                                                       ], axis=1) # remove channel dimension for grayscale
+
+        # Make training datapipe
+        datapipe.train_pipe = datapipe.source + gp.RandomLocation()
+        if datapipe.reject:
+            datapipe.train_pipe += datapipe.reject
+        if datapipe.resample:
+            datapipe.train_pipe += datapipe.resample
+        datapipe.train_pipe += datapipe.augment
+        if datapipe.unsqueeze:
+            datapipe.train_pipe += datapipe.unsqueeze # add "channel" dimensions if necessary, else use z dimension as channel
+        datapipe.train_pipe += gp.Stack(self.batch_size) # add "batch" dimension
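+        # Node order matters in gunpowder: random location -> reject -> resample
+        # -> augment -> unsqueeze -> stack. The prediction pipe assembled below
+        # mirrors the training pipe but omits random sampling and augmentation.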
+        # Make predicting datapipe
+        datapipe.predict_pipe = datapipe.source
+        if datapipe.reject:
+            datapipe.predict_pipe += datapipe.reject
+        if datapipe.resample:
+            datapipe.predict_pipe += datapipe.resample
+        datapipe.predict_pipe += datapipe.normalize_real + datapipe.scaleimg2tanh_real
+        if datapipe.unsqueeze:
+            datapipe.predict_pipe += datapipe.unsqueeze # add "channel" dimensions if necessary, else use z dimension as channel
+        datapipe.predict_pipe += gp.Stack(1) # add "batch" dimension
+        setattr(self, 'pipe_'+side, datapipe.train_pipe)
+
+        return datapipe
+
+    def build_training_pipeline(self):
+        # create a train node using our model, loss, and optimizer
+        self.trainer = gp.torch.Train(
+                            self.model,
+                            self.loss,
+                            self.optimizer,
+                            inputs = {
+                                'real_A': self.real_A,
+                                'real_B': self.real_B
+                            },
+                            outputs = {
+                                0: self.fake_B,
+                                1: self.cycled_B,
+                                2: self.fake_A,
+                                3: self.cycled_A
+                            },
+                            loss_inputs = {
+                                0: self.real_A,
+                                1: self.fake_A,
+                                2: self.cycled_A,
+                                3: self.real_B,
+                                4: self.fake_B,
+                                5: self.cycled_B,
+                            },
+                            log_dir=self.tensorboard_path,
+                            log_every=self.log_every,
+                            checkpoint_basename=self.model_path+self.model_name,
+                            save_every=self.save_every,
+                            spawn_subprocess=self.spawn_subprocess
+                            )
+
+        # assemble pipeline
+        self.training_pipeline = (self.pipe_A, self.pipe_B) + gp.MergeProvider() # merge upstream pipelines for the two sources
+        self.training_pipeline += self.trainer
+        self.training_pipeline += self.datapipe_A.postnet_pipe.cycle + self.datapipe_B.postnet_pipe.cycle
+        if self.batch_size == 1:
+            self.training_pipeline += gp.Squeeze([self.real_A,
+                                                  self.fake_A,
+                                                  self.cycled_A,
+                                                  self.real_B,
+                                                  self.fake_B,
+                                                  self.cycled_B
+                                                  ], axis=0)
+        self.test_training_pipeline = self.training_pipeline.copy() + self.performance
+        self.training_pipeline += self.cache
+
+        # create request
+        self.train_request = gp.BatchRequest()
+        for array in self.arrays:
+            extents = self.get_extents(array_name=array.identifier)
+            self.train_request.add(array, self.common_voxel_size * extents, self.common_voxel_size)
+
+    def test_train(self):
+        if self.test_training_pipeline is None:
+            self.build_training_pipeline()
+        self.model.train()
+        with gp.build(self.test_training_pipeline):
+            self.batch = self.test_training_pipeline.request_batch(self.train_request)
+        self.batch_show()
+        return self.batch
+
+    def train(self):
+        if self.training_pipeline is None:
+            self.build_training_pipeline()
+        self.model.train()
+        with gp.build(self.training_pipeline):
+            for i in tqdm(range(self.num_epochs)):
+                # this_request = copy.deepcopy(self.train_request)
+                # this_request._random_seed = random.randint(0, 2**32)
+                # self.batch = self.training_pipeline.request_batch(this_request)
+                self.batch = self.training_pipeline.request_batch(self.train_request)
+                if i == 1:
+                    self.write_tBoard_graph()
+                if hasattr(self.loss, 'loss_dict'):
+                    print(self.loss.loss_dict)
+                if i % self.log_every == 0:
+                    self.batch_tBoard_write()
+        return self.batch
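+    # e.g. (illustrative): batch = system.test_prediction(side='A', cycle=True)
+    # runs one forward pass (netG1 for the fake, plus netG2 for the cycle when
+    # side='A') and plots the result.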
+    def test_prediction(self, side='A', side_length=None, cycle=True):
+        # set model into evaluation mode
+        self.model.eval()
+        self.model.cycle = cycle
+        # model_outputs = {
+        #     0: self.fake_B,
+        #     1: self.cycled_B,
+        #     2: self.fake_A,
+        #     3: self.cycled_A}
+
+        # datapipe has: train_pipe, source, reject, resample, augment, unsqueeze, etc.
+        datapipe = getattr(self, 'datapipe_'+side)
+        arrays = [datapipe.real, datapipe.fake]
+        if cycle:
+            arrays += [datapipe.cycled]
+        squeeze_arrays = arrays.copy()
+        if datapipe.masked:
+            arrays += [datapipe.mask]
+
+        input_dict = {'real_'+side: datapipe.real}
+
+        if side == 'A':
+            output_dict = {0: datapipe.fake}
+            if cycle:
+                output_dict[3] = datapipe.cycled
+        else:
+            output_dict = {2: datapipe.fake}
+            if cycle:
+                output_dict[1] = datapipe.cycled
+
+        predict_pipe = datapipe.source + gp.RandomLocation()
+        if datapipe.reject: predict_pipe += datapipe.reject
+        if datapipe.resample: predict_pipe += datapipe.resample
+        predict_pipe += datapipe.normalize_real
+        predict_pipe += datapipe.scaleimg2tanh_real
+
+        if datapipe.unsqueeze: # add "channel" dimensions if necessary, else use z dimension as channel
+            predict_pipe += datapipe.unsqueeze
+        predict_pipe += gp.Unsqueeze([datapipe.real]) # add batch dimension
+
+        predict_pipe += gp.torch.Predict(self.model,
+                                         inputs = input_dict,
+                                         outputs = output_dict,
+                                         checkpoint = self.checkpoint
+                                         )
+
+        if cycle:
+            predict_pipe += datapipe.postnet_pipe.cycle
+        else:
+            predict_pipe += datapipe.postnet_pipe.nocycle
+
+        predict_pipe += gp.Squeeze(squeeze_arrays, axis=0) # remove batch dimension
+
+        request = gp.BatchRequest(random_seed=random.randint(0, 4294967295))
+        for array in arrays:
+            extents = self.get_extents(side_length, array_name=array.identifier)
+            request.add(array, self.common_voxel_size * extents, self.common_voxel_size)
+
+        with gp.build(predict_pipe):
+            self.batch = predict_pipe.request_batch(request)
+
+        self.batch_show()
+        return self.batch
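+    # Note: render_full() below raises DeprecationWarning on entry in this
+    # revision; the full-volume rendering pipeline is kept for reference only.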
+    def render_full(self, side='A', side_length=None, cycle=False, crop_to_valid=False, test=False, label_dict=None):
+        raise DeprecationWarning()
+        #CYCLED CURRENTLY SAVED IN UPSAMPLED FORM (i.e. not original voxel size)
+        # set model into evaluation mode
+        self.model.eval()
+        self.model.cycle = cycle
+        # model_outputs = {
+        #     0: self.fake_B,
+        #     1: self.cycled_B,
+        #     2: self.fake_A,
+        #     3: self.cycled_A}
+
+        side_length = self.side_length if side_length is None else side_length
+
+        # datapipe has: train_pipe, source, reject, resample, augment, unsqueeze, etc.
+        datapipe = getattr(self, 'datapipe_'+side)
+        arrays = [datapipe.real, datapipe.fake]
+        if cycle:
+            arrays += [datapipe.cycled]
+
+        input_dict = {'real_'+side: datapipe.real}
+
+        if side == 'A':
+            output_dict = {0: datapipe.fake}
+            if cycle:
+                output_dict[3] = datapipe.cycled
+        else:
+            output_dict = {2: datapipe.fake}
+            if cycle:
+                output_dict[1] = datapipe.cycled
+
+        # set prediction spec
+        if datapipe.source.spec is None:
+            data_file = zarr.open(datapipe.src_path)
+            pred_spec = datapipe.source._Hdf5LikeSource__read_spec(datapipe.real, data_file, datapipe.real_name).copy()
+        else:
+            pred_spec = datapipe.source.spec[datapipe.real].copy()
+        pred_spec.voxel_size = self.common_voxel_size
+
+        if label_dict is None:
+            dataset_names = {datapipe.fake: 'volumes/'+self.model_name+'_enFAKE'}
+            if cycle:
+                dataset_names[datapipe.cycled] = 'volumes/'+self.model_name+'_enCYCLED'
+        else:
+            dataset_names = {datapipe.fake: 'volumes/'+label_dict['fake']}
+            if cycle:
+                dataset_names[datapipe.cycled] = 'volumes/'+label_dict['cycled']
+
+        # Calculate padding if necessary:
+        if crop_to_valid:
+            if crop_to_valid is True:
+                px_pad = self.get_valid_crop(side_length=side_length)
+            else:
+                px_pad = crop_to_valid
+            self.model.set_crop_pad(px_pad, self.ndims)
+            coor_pad = np.zeros((len(self.common_voxel_size)))
+            coor_pad[-self.ndims:] = px_pad # assumes first dimension is z (i.e. the dimension breaking isotropy)
+            coor_pad = gp.Coordinate(coor_pad)
+
+        scan_request = gp.BatchRequest()
+        for array in arrays:
+            if array is not datapipe.real:
+                extents = self.get_extents(side_length, array_name=array.identifier)
+                if crop_to_valid:
+                    extents -= coor_pad * 2
+                    if array is datapipe.cycled:
+                        extents -= coor_pad * 2
+                scan_request.add(array, self.common_voxel_size * extents, self.common_voxel_size)
+
+        render_pipe = datapipe.source
+        if datapipe.resample: render_pipe += datapipe.resample
+
+        render_pipe += datapipe.normalize_real
+        render_pipe += datapipe.scaleimg2tanh_real
+
+        if datapipe.unsqueeze: # add "channel" dimensions if necessary, else use z dimension as channel
+            render_pipe += datapipe.unsqueeze
+        render_pipe += gp.Unsqueeze([datapipe.real]) # add batch dimension
+
+        render_pipe += gp.torch.Predict(self.model,
+                                        inputs = input_dict,
+                                        outputs = output_dict,
+                                        checkpoint = self.checkpoint,
+                                        spawn_subprocess=self.spawn_subprocess
+                                        )
+
+        if cycle:
+            render_pipe += datapipe.postnet_pipe.cycle
+        else:
+            render_pipe += datapipe.postnet_pipe.nocycle
+        render_pipe += gp.Squeeze(arrays[1:], axis=0) # remove batch dimension
+
+        # Convert float32 on [0,1] to uint8 on [0,255]
+        render_pipe += gp.IntensityScaleShift(datapipe.fake, 255, 0)
+        render_pipe += gp.AsType(datapipe.fake, np.uint8)
+        if cycle:
+            render_pipe += gp.IntensityScaleShift(datapipe.cycled, 255, 0)
+            render_pipe += gp.AsType(datapipe.cycled, np.uint8)
+
+        if test:
+            return scan_request, render_pipe
+
+        # Declare new array to write to
+        if not hasattr(self, 'compressor'):
+            self.compressor = {'id': 'blosc',
+                               'clevel': 3,
+                               'cname': 'blosclz',
+                               # 'blocksize': 64
+                               }
+        source_ds = daisy.open_ds(datapipe.src_path, datapipe.real_name)
+        datapipe.total_roi = source_ds.data_roi.snap_to_grid(self.common_voxel_size, 'shrink')
+        for key, name in dataset_names.items():
+            write_size = scan_request[key].roi.get_shape()
+            daisy.prepare_ds(
+                datapipe.out_path,
+                name,
+                datapipe.total_roi,
+                daisy.Coordinate(self.common_voxel_size),
+                np.uint8,
+                write_size=write_size,
+                num_channels=1,
+                compressor=self.compressor,
+                delete=True)
+
+        render_pipe += gp.ZarrWrite(
+                        dataset_names = dataset_names,
+                        output_filename = datapipe.out_path,
+                        compression_type = self.compressor
+                        )
+
+        render_pipe += gp.Scan(scan_request, num_workers=self.num_workers, cache_size=self.cache_size)
+
+        request = gp.BatchRequest()
+
+        print(f'Full rendering pipeline declared for input type {side}. 
Building...') + with gp.build(render_pipe): + print('Starting full volume render...') + render_pipe.request_batch(request) + print('Finished.') + + def load_saved_model(self, checkpoint=None, cuda_available=None): + if cuda_available is None: + cuda_available = torch.cuda.is_available() + if checkpoint is None: + checkpoint = self.checkpoint + else: + self.checkpoint = checkpoint + + if checkpoint is not None: + if not cuda_available: + checkpoint = torch.load(checkpoint, map_location=torch.device('cpu')) + else: + checkpoint = torch.load(checkpoint) + + if "model_state_dict" in checkpoint: + self.model.load_state_dict(checkpoint["model_state_dict"]) + else: + self.model.load_state_dict(checkpoint) + else: + logger.warning('No saved checkpoint found.') diff --git a/raygun/torch/utils/read_config.py b/raygun/torch/utils/read_config.py index 4b15df98..dd7eae41 100644 --- a/raygun/torch/utils/read_config.py +++ b/raygun/torch/utils/read_config.py @@ -34,10 +34,12 @@ def eval_args(config, file): elif isinstance(v, str): if '$working_dir' in v: v = v.replace('$working_dir', os.path.dirname(file)) + if v[0] == '#' and v[-1] == '#': v = eval(v[1:-1]) elif v.count('#') > 0 and v.count('#') % 2 == 0: - v = _eval_args(v) + v = _eval_args(v) + config[k] = v return config diff --git a/scratch/freezeNorm.py b/scratch/freezeNorm.py index 31dacc53..9d40ae3b 100644 --- a/scratch/freezeNorm.py +++ b/scratch/freezeNorm.py @@ -177,7 +177,7 @@ def eval_models(data_src, models): #%% model_kwargs = { - 'activation': torch.nn.SELU + # 'activation': torch.nn.SELU } model_names = ['allTrain',