Skip to content

Commit

Permalink
Merge branch 'full_refactor' of https://github.com/htem/raygun into j…
Browse files Browse the repository at this point in the history
…ax_refactor
  • Loading branch information
brianreicher committed Aug 9, 2022
2 parents b3b8ca4 + 525de2a commit 149c80e
Show file tree
Hide file tree
Showing 20 changed files with 1,921 additions and 876 deletions.
57 changes: 57 additions & 0 deletions raygun/torch/default_configs/default_cycleGAN_conf.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
{
"common_voxel_size": null, // voxel size to resample A and B into for training
"ndims": null,
"A_name": "raw",
"B_name": "raw",
"mask_A_name": null, // expects mask to be in same place as real zarr
"mask_B_name": null,
"A_out_path": null,
"B_out_path": null,
"model_name": "CycleGAN",
"gnet_type": "unet",
"dnet_type": "classic",
"dnet_kwargs": {
"input_nc": 1,
"downsampling_kw": 2, // downsampling factor
"kw": 3, // kernel size
"n_layers": 3, // number of layers in Discriminator networks
"ngf": 64
},
"loss_type": "cycle", // supports "link" or "split"
"loss_kwargs": {"g_lambda_dict": {"A": {
"l1_loss": {"cycled": 10, "identity": 0.5}, // Default from CycleGAN paper
"gan_loss": {"fake": 1, "cycled": 0}
},
"B": {
"l1_loss": {"cycled": 10, "identity": 0.5}, // Default from CycleGAN paper
"gan_loss": {"fake": 1, "cycled": 0}
}
},
"d_lambda_dict": {"A": {"real": 1, "fake": 1, "cycled": 0},
"B": {"real": 1, "fake": 1, "cycled": 0}
}
},
"sampling_bottleneck": false,
"optim_type": "Adam",
"optim_kwargs": {"betas": [0.9, 0.999],
"weight_decay": 0
},
"g_init_learning_rate": 1e-5,
"d_init_learning_rate": 1e-5,
"min_coefvar": null,
"interp_order": null,
"side_length": 64, // in common sized voxels
"batch_size": 1,
"num_workers": 11,
"cache_size": 50,
"spawn_subprocess": false,
"num_epochs": 20000,
"log_every": 20,
"save_every": 2000,
"model_path": "./models/",
"tensorboard_path": "./tensorboard/",
"verbose": true,
"checkpoint": null, // Used for prediction/rendering, training always starts from latest
"pretrain_gnet": false,
"random_seed": 42
}

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

63 changes: 63 additions & 0 deletions raygun/torch/losses/GANLoss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# ORIGINALLY WRITTEN BY TRI NGUYEN (HARVARD, 2021)
import torch

class GANLoss(torch.nn.Module):
"""Define different GAN objectives.
The GANLoss class abstracts away the need to create the target label tensor
that has the same size as the input.
"""

def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
""" Initialize the GANLoss class.
Parameters:
gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
target_real_label (bool) - - label for a real image
target_fake_label (bool) - - label of a fake image
Note: Do not use sigmoid as the last layer of Discriminator.
LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
"""
super(GANLoss, self).__init__()
self.register_buffer('real_label', torch.tensor(target_real_label))
self.register_buffer('fake_label', torch.tensor(target_fake_label))
self.gan_mode = gan_mode
if gan_mode == 'lsgan':
self.loss = torch.nn.MSELoss()
elif gan_mode == 'vanilla':
self.loss = torch.nn.BCEWithLogitsLoss()
elif gan_mode in ['wgangp']:
self.loss = None
else:
raise NotImplementedError('gan mode %s not implemented' % gan_mode)

def get_target_tensor(self, prediction, target_is_real):
"""Create label tensors with the same size as the input.
Parameters:
prediction (tensor) - - typically the prediction from a discriminator
target_is_real (bool) - - if the ground truth label is for real images or fake images
Returns:
A label tensor filled with ground truth label, and with the size of the input
"""

if target_is_real:
target_tensor = self.real_label
else:
target_tensor = self.fake_label
return target_tensor.expand_as(prediction)

def __call__(self, prediction, target_is_real):
"""Calculate loss given Discriminator's output and grount truth labels.
Parameters:
prediction (tensor) - - typically the prediction output from a discriminator
target_is_real (bool) - - if the ground truth label is for real images or fake images
Returns:
the calculated loss.
"""
if self.gan_mode in ['lsgan', 'vanilla']:
target_tensor = self.get_target_tensor(prediction, target_is_real)
loss = self.loss(prediction, target_tensor)
elif self.gan_mode == 'wgangp':
if target_is_real:
loss = -prediction.mean()
else:
loss = prediction.mean()
return loss
Loading

0 comments on commit 149c80e

Please sign in to comment.