-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'full_refactor' of https://github.com/htem/raygun into j…
…ax_refactor
- Loading branch information
Showing
20 changed files
with
1,921 additions
and
876 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
{ | ||
"common_voxel_size": null, // voxel size to resample A and B into for training | ||
"ndims": null, | ||
"A_name": "raw", | ||
"B_name": "raw", | ||
"mask_A_name": null, // expects mask to be in same place as real zarr | ||
"mask_B_name": null, | ||
"A_out_path": null, | ||
"B_out_path": null, | ||
"model_name": "CycleGAN", | ||
"gnet_type": "unet", | ||
"dnet_type": "classic", | ||
"dnet_kwargs": { | ||
"input_nc": 1, | ||
"downsampling_kw": 2, // downsampling factor | ||
"kw": 3, // kernel size | ||
"n_layers": 3, // number of layers in Discriminator networks | ||
"ngf": 64 | ||
}, | ||
"loss_type": "cycle", // supports "link" or "split" | ||
"loss_kwargs": {"g_lambda_dict": {"A": { | ||
"l1_loss": {"cycled": 10, "identity": 0.5}, // Default from CycleGAN paper | ||
"gan_loss": {"fake": 1, "cycled": 0} | ||
}, | ||
"B": { | ||
"l1_loss": {"cycled": 10, "identity": 0.5}, // Default from CycleGAN paper | ||
"gan_loss": {"fake": 1, "cycled": 0} | ||
} | ||
}, | ||
"d_lambda_dict": {"A": {"real": 1, "fake": 1, "cycled": 0}, | ||
"B": {"real": 1, "fake": 1, "cycled": 0} | ||
} | ||
}, | ||
"sampling_bottleneck": false, | ||
"optim_type": "Adam", | ||
"optim_kwargs": {"betas": [0.9, 0.999], | ||
"weight_decay": 0 | ||
}, | ||
"g_init_learning_rate": 1e-5, | ||
"d_init_learning_rate": 1e-5, | ||
"min_coefvar": null, | ||
"interp_order": null, | ||
"side_length": 64, // in common sized voxels | ||
"batch_size": 1, | ||
"num_workers": 11, | ||
"cache_size": 50, | ||
"spawn_subprocess": false, | ||
"num_epochs": 20000, | ||
"log_every": 20, | ||
"save_every": 2000, | ||
"model_path": "./models/", | ||
"tensorboard_path": "./tensorboard/", | ||
"verbose": true, | ||
"checkpoint": null, // Used for prediction/rendering, training always starts from latest | ||
"pretrain_gnet": false, | ||
"random_seed": 42 | ||
} |
4 changes: 0 additions & 4 deletions
4
raygun/torch/examples/batch_training/ieee-isbi-2022/split/seed13/train_conf.json
This file was deleted.
Oops, something went wrong.
4 changes: 0 additions & 4 deletions
4
raygun/torch/examples/batch_training/ieee-isbi-2022/split/seed3/train_conf.json
This file was deleted.
Oops, something went wrong.
4 changes: 0 additions & 4 deletions
4
raygun/torch/examples/batch_training/ieee-isbi-2022/split/seed42/train_conf.json
This file was deleted.
Oops, something went wrong.
18 changes: 0 additions & 18 deletions
18
raygun/torch/examples/batch_training/ieee-isbi-2022/split/train_conf.json
This file was deleted.
Oops, something went wrong.
61 changes: 0 additions & 61 deletions
61
raygun/torch/examples/batch_training/ieee-isbi-2022/train_conf.json
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# ORIGINALLY WRITTEN BY TRI NGUYEN (HARVARD, 2021) | ||
import torch | ||
|
||
class GANLoss(torch.nn.Module): | ||
"""Define different GAN objectives. | ||
The GANLoss class abstracts away the need to create the target label tensor | ||
that has the same size as the input. | ||
""" | ||
|
||
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): | ||
""" Initialize the GANLoss class. | ||
Parameters: | ||
gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. | ||
target_real_label (bool) - - label for a real image | ||
target_fake_label (bool) - - label of a fake image | ||
Note: Do not use sigmoid as the last layer of Discriminator. | ||
LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. | ||
""" | ||
super(GANLoss, self).__init__() | ||
self.register_buffer('real_label', torch.tensor(target_real_label)) | ||
self.register_buffer('fake_label', torch.tensor(target_fake_label)) | ||
self.gan_mode = gan_mode | ||
if gan_mode == 'lsgan': | ||
self.loss = torch.nn.MSELoss() | ||
elif gan_mode == 'vanilla': | ||
self.loss = torch.nn.BCEWithLogitsLoss() | ||
elif gan_mode in ['wgangp']: | ||
self.loss = None | ||
else: | ||
raise NotImplementedError('gan mode %s not implemented' % gan_mode) | ||
|
||
def get_target_tensor(self, prediction, target_is_real): | ||
"""Create label tensors with the same size as the input. | ||
Parameters: | ||
prediction (tensor) - - typically the prediction from a discriminator | ||
target_is_real (bool) - - if the ground truth label is for real images or fake images | ||
Returns: | ||
A label tensor filled with ground truth label, and with the size of the input | ||
""" | ||
|
||
if target_is_real: | ||
target_tensor = self.real_label | ||
else: | ||
target_tensor = self.fake_label | ||
return target_tensor.expand_as(prediction) | ||
|
||
def __call__(self, prediction, target_is_real): | ||
"""Calculate loss given Discriminator's output and grount truth labels. | ||
Parameters: | ||
prediction (tensor) - - typically the prediction output from a discriminator | ||
target_is_real (bool) - - if the ground truth label is for real images or fake images | ||
Returns: | ||
the calculated loss. | ||
""" | ||
if self.gan_mode in ['lsgan', 'vanilla']: | ||
target_tensor = self.get_target_tensor(prediction, target_is_real) | ||
loss = self.loss(prediction, target_tensor) | ||
elif self.gan_mode == 'wgangp': | ||
if target_is_real: | ||
loss = -prediction.mean() | ||
else: | ||
loss = prediction.mean() | ||
return loss |
Oops, something went wrong.