diff --git a/DualGAN_test.py b/DualGAN_test.py
new file mode 100644
index 0000000..0ba7435
--- /dev/null
+++ b/DualGAN_test.py
@@ -0,0 +1,79 @@
+import torch
+from torchvision import transforms
+from torch.autograd import Variable
+from dataset import DatasetFromFolder
+from model import Generator
+import utils
+import argparse
+import os
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--dataset', required=False, default='horse2zebra', help='input dataset')
+parser.add_argument('--batch_size', type=int, default=1, help='test batch size')
+parser.add_argument('--ngf', type=int, default=64)  # must match the ngf used for training
+parser.add_argument('--num_channel', type=int, default=3, help='number of channels for input image')
+parser.add_argument('--input_size', type=int, default=256, help='input size')
+params = parser.parse_args()
+print(params)
+
+# Directories for loading data and saving results
+data_dir = '../Data/' + params.dataset + '/'
+save_dir = params.dataset + '_test_results/'
+model_dir = params.dataset + '_model/'
+
+if not os.path.exists(save_dir):
+    os.mkdir(save_dir)
+if not os.path.exists(model_dir):
+    os.mkdir(model_dir)
+
+# Data pre-processing
+transform = transforms.Compose([transforms.Scale(params.input_size),  # renamed transforms.Resize in newer torchvision
+                                transforms.ToTensor(),
+                                transforms.Normalize(mean=[0.5] * params.num_channel, std=[0.5] * params.num_channel)])
+
+# Test data
+test_data_A = DatasetFromFolder(data_dir, subfolder='testA', transform=transform)
+test_data_loader_A = torch.utils.data.DataLoader(dataset=test_data_A,
+                                                 batch_size=params.batch_size,
+                                                 shuffle=False)
+test_data_B = DatasetFromFolder(data_dir, subfolder='testB', transform=transform)
+test_data_loader_B = torch.utils.data.DataLoader(dataset=test_data_B,
+                                                 batch_size=params.batch_size,
+                                                 shuffle=False)
+
+# Load model (the U-Net Generator takes input_dim, num_filter, output_dim; it has no resnet blocks)
+G_A = Generator(params.num_channel, params.ngf, params.num_channel)
+G_B = Generator(params.num_channel, params.ngf, params.num_channel)
+G_A.cuda()
+G_B.cuda()
+G_A.load_state_dict(torch.load(model_dir + 'generator_A_param.pkl'))
+G_B.load_state_dict(torch.load(model_dir + 'generator_B_param.pkl'))
+
+# Test
+for i, real_A in enumerate(test_data_loader_A):
+
+    # input image data
+    real_A = Variable(real_A.cuda())
+
+    # A -> B -> A
+    fake_B = G_A(real_A)
+    recon_A = G_B(fake_B)
+
+    # Show result for test data
+    utils.plot_test_result(real_A, fake_B, recon_A, i, save=True, save_dir=save_dir + 'AtoB/')
+
+    print('%d images are generated.' % (i + 1))
+
+for i, real_B in enumerate(test_data_loader_B):
+
+    # input image data
+    real_B = Variable(real_B.cuda())
+
+    # B -> A -> B
+    fake_A = G_B(real_B)
+    recon_B = G_A(fake_A)
+
+    # Show result for test data
+    utils.plot_test_result(real_B, fake_A, recon_B, i, save=True, save_dir=save_dir + 'BtoA/')
+
+    print('%d images are generated.' % (i + 1))
\ No newline at end of file
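As a quick sanity check on the wiring above, here is a minimal sketch (not part of the diff) of the same A -> B -> A round trip, run on CPU with freshly initialized weights; it assumes `model.py` from this patch is importable, and all names and sizes are illustrative:

```python
# Round-trip shape check for the test loop above (CPU, random weights).
import torch
from torch.autograd import Variable

from model import Generator

G_A = Generator(3, 64, 3)  # input_dim, num_filter, output_dim
G_B = Generator(3, 64, 3)

x = Variable(torch.randn(1, 3, 256, 256))  # one fake 256x256 RGB image
fake_B = G_A(x)
recon_A = G_B(fake_B)
assert recon_A.size() == x.size()  # the U-Net round trip preserves shape
print(fake_B.size(), recon_A.size())
```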
diff --git a/DualGAN_train.py b/DualGAN_train.py
new file mode 100644
index 0000000..837629d
--- /dev/null
+++ b/DualGAN_train.py
@@ -0,0 +1,285 @@
+import torch
+from torchvision import transforms
+from torch.autograd import Variable
+from dataset import DatasetFromFolder
+from model import Generator, Discriminator
+import utils
+import argparse
+import os, itertools
+from logger import Logger
+import numpy as np
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--dataset', required=False, default='sketch-photo', help='input dataset')
+parser.add_argument('--batch_size', type=int, default=1, help='train batch size')
+parser.add_argument('--ngf', type=int, default=64)
+parser.add_argument('--ndf', type=int, default=64)
+parser.add_argument('--input_size', type=int, default=256, help='input size')
+parser.add_argument('--num_channel', type=int, default=1, help='number of channels for input image')
+parser.add_argument('--fliplr', type=bool, default=True, help='random fliplr, True or False')  # note: any non-empty string parses as True
+parser.add_argument('--num_epochs', type=int, default=45, help='number of train epochs')
+parser.add_argument('--num_iter_G', type=int, default=2, help='number of generator iterations per discriminator iteration')
+parser.add_argument('--lrG', type=float, default=0.00005, help='learning rate for generator, default=0.00005')
+parser.add_argument('--lrD', type=float, default=0.00005, help='learning rate for discriminator, default=0.00005')
+parser.add_argument('--decay', type=float, default=0.1, help='weight decay for RMSprop optimizer')
+parser.add_argument('--lambdaA', type=float, default=20, help='lambdaA for L1 loss')
+parser.add_argument('--lambdaB', type=float, default=20, help='lambdaB for L1 loss')
+params = parser.parse_args()
+print(params)
+
+# Directories for loading data and saving results
+data_dir = '../Data/' + params.dataset + '/'
+save_dir = params.dataset + '_results/'
+model_dir = params.dataset + '_model/'
+
+if not os.path.exists(save_dir):
+    os.mkdir(save_dir)
+if not os.path.exists(model_dir):
+    os.mkdir(model_dir)
+
+# Data pre-processing
+transform = transforms.Compose([transforms.Scale(params.input_size),
+                                transforms.ToTensor(),
+                                transforms.Normalize(mean=[0.5] * params.num_channel, std=[0.5] * params.num_channel)])  # one entry per channel
+
+# Train data
+train_data_A = DatasetFromFolder(data_dir, subfolder='train/A', transform=transform, fliplr=params.fliplr, is_color=False)
+train_data_loader_A = torch.utils.data.DataLoader(dataset=train_data_A,
+                                                  batch_size=params.batch_size,
+                                                  shuffle=True)
+train_data_B = DatasetFromFolder(data_dir, subfolder='train/B', transform=transform, fliplr=params.fliplr, is_color=False)
+train_data_loader_B = torch.utils.data.DataLoader(dataset=train_data_B,
+                                                  batch_size=params.batch_size,
+                                                  shuffle=True)
+
+# Test data
+test_data_A = DatasetFromFolder(data_dir, subfolder='val/A', transform=transform, is_color=False)
+test_data_loader_A = torch.utils.data.DataLoader(dataset=test_data_A,
+                                                 batch_size=params.batch_size,
+                                                 shuffle=False)
+test_data_B = DatasetFromFolder(data_dir, subfolder='val/B', transform=transform, is_color=False)
+test_data_loader_B = torch.utils.data.DataLoader(dataset=test_data_B,
+                                                 batch_size=params.batch_size,
+                                                 shuffle=False)
+
+# Get specific test images
+test_real_A_data = test_data_A[0].unsqueeze(0)  # Convert to 4d tensor (BxCxHxW)
+test_real_B_data = test_data_B[0].unsqueeze(0)
+
+# Models
+G_A = Generator(params.num_channel, params.ngf, params.num_channel)
+G_B = Generator(params.num_channel, params.ngf, params.num_channel)
+D_A = Discriminator(params.num_channel, params.ndf, 1)
+D_B = Discriminator(params.num_channel, params.ndf, 1)
+G_A.normal_weight_init(mean=0.0, std=0.02)
+G_B.normal_weight_init(mean=0.0, std=0.02)
+D_A.normal_weight_init(mean=0.0, std=0.02)
+D_B.normal_weight_init(mean=0.0, std=0.02)
+G_A.cuda()
+G_B.cuda()
+D_A.cuda()
+D_B.cuda()
+
+
+# Set the logger
+D_A_log_dir = save_dir + 'D_A_logs'
+D_B_log_dir = save_dir + 'D_B_logs'
+if not os.path.exists(D_A_log_dir):
+    os.mkdir(D_A_log_dir)
+D_A_logger = Logger(D_A_log_dir)
+if not os.path.exists(D_B_log_dir):
+    os.mkdir(D_B_log_dir)
+D_B_logger = Logger(D_B_log_dir)
+
+G_A_log_dir = save_dir + 'G_A_logs'
+G_B_log_dir = save_dir + 'G_B_logs'
+if not os.path.exists(G_A_log_dir):
+    os.mkdir(G_A_log_dir)
+G_A_logger = Logger(G_A_log_dir)
+if not os.path.exists(G_B_log_dir):
+    os.mkdir(G_B_log_dir)
+G_B_logger = Logger(G_B_log_dir)
+
+L1_A_log_dir = save_dir + 'L1_A_logs'
+L1_B_log_dir = save_dir + 'L1_B_logs'
+if not os.path.exists(L1_A_log_dir):
+    os.mkdir(L1_A_log_dir)
+L1_A_logger = Logger(L1_A_log_dir)
+if not os.path.exists(L1_B_log_dir):
+    os.mkdir(L1_B_log_dir)
+L1_B_logger = Logger(L1_B_log_dir)
+
+img_log_dir = save_dir + 'img_logs'
+if not os.path.exists(img_log_dir):
+    os.mkdir(img_log_dir)
+img_logger = Logger(img_log_dir)
+
+
+# Loss functions
+BCE_loss = torch.nn.BCELoss().cuda()
+L1_loss = torch.nn.L1Loss().cuda()
+
+# Optimizers
+G_optimizer = torch.optim.RMSprop(itertools.chain(G_A.parameters(), G_B.parameters()), lr=params.lrG, weight_decay=params.decay)
+D_A_optimizer = torch.optim.RMSprop(D_A.parameters(), lr=params.lrD, weight_decay=params.decay)
+D_B_optimizer = torch.optim.RMSprop(D_B.parameters(), lr=params.lrD, weight_decay=params.decay)
+
+# Training GAN
+D_A_avg_losses = []
+D_B_avg_losses = []
+G_A_avg_losses = []
+G_B_avg_losses = []
+L1_A_avg_losses = []
+L1_B_avg_losses = []
+
+step = 0
+for epoch in range(params.num_epochs):
+    D_A_losses = []
+    D_B_losses = []
+    G_A_losses = []
+    G_B_losses = []
+    L1_A_losses = []
+    L1_B_losses = []
+
+    # training
+    for i, (real_A, real_B) in enumerate(zip(train_data_loader_A, train_data_loader_B)):
+
+        # input image data
+        real_A = Variable(real_A.cuda())
+        real_B = Variable(real_B.cuda())
+        for _ in range(params.num_iter_G):
+            # Train generator G
+            # A -> B
+            fake_B = G_A(real_A)
+            D_B_fake_decision = D_B(fake_B)
+            G_A_loss = BCE_loss(D_B_fake_decision, Variable(torch.ones(D_B_fake_decision.size()).cuda()))
+
+            # forward L1 loss (A -> B -> A reconstruction)
+            recon_A = G_B(fake_B)
+            L1_A_loss = L1_loss(recon_A, real_A) * params.lambdaA
+
+            # B -> A
+            fake_A = G_B(real_B)
+            D_A_fake_decision = D_A(fake_A)
+            G_B_loss = BCE_loss(D_A_fake_decision, Variable(torch.ones(D_A_fake_decision.size()).cuda()))
+
+            # backward L1 loss (B -> A -> B reconstruction)
+            recon_B = G_A(fake_A)
+            L1_B_loss = L1_loss(recon_B, real_B) * params.lambdaB
+
+            # Back propagation
+            G_loss = G_A_loss + G_B_loss + L1_A_loss + L1_B_loss
+            G_optimizer.zero_grad()
+            G_loss.backward(retain_graph=True)
+            G_optimizer.step()
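+
+        # Note: G_loss.backward(retain_graph=True) above keeps the graphs behind
+        # fake_A and fake_B alive after the generator update, so the discriminator
+        # losses below can backprop through those samples without a second
+        # generator forward pass.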
+        # Train discriminator D_A
+        D_A_real_decision = D_A(real_A)
+        D_A_real_loss = BCE_loss(D_A_real_decision, Variable(torch.ones(D_A_real_decision.size()).cuda()))
+        D_A_fake_decision = D_A(fake_A)
+        D_A_fake_loss = BCE_loss(D_A_fake_decision, Variable(torch.zeros(D_A_fake_decision.size()).cuda()))
+
+        # Back propagation
+        D_A_loss = D_A_real_loss + D_A_fake_loss
+        D_A_optimizer.zero_grad()
+        D_A_loss.backward()
+        D_A_optimizer.step()
+
+        # Train discriminator D_B
+        D_B_real_decision = D_B(real_B)
+        D_B_real_loss = BCE_loss(D_B_real_decision, Variable(torch.ones(D_B_real_decision.size()).cuda()))
+        D_B_fake_decision = D_B(fake_B)
+        D_B_fake_loss = BCE_loss(D_B_fake_decision, Variable(torch.zeros(D_B_fake_decision.size()).cuda()))
+
+        # Back propagation
+        D_B_loss = D_B_real_loss + D_B_fake_loss
+        D_B_optimizer.zero_grad()
+        D_B_loss.backward()
+        D_B_optimizer.step()
+
+        # loss values
+        D_A_losses.append(D_A_loss.data[0])
+        D_B_losses.append(D_B_loss.data[0])
+        G_A_losses.append(G_A_loss.data[0])
+        G_B_losses.append(G_B_loss.data[0])
+        L1_A_losses.append(L1_A_loss.data[0])
+        L1_B_losses.append(L1_B_loss.data[0])
+
+        print('Epoch [%d/%d], Step [%d/%d], D_A_loss: %.4f, D_B_loss: %.4f, G_A_loss: %.4f, G_B_loss: %.4f'
+              % (epoch + 1, params.num_epochs, i + 1, len(train_data_loader_A), D_A_loss.data[0], D_B_loss.data[0], G_A_loss.data[0], G_B_loss.data[0]))
+
+        # ============ TensorBoard logging ============#
+        D_A_logger.scalar_summary('losses', D_A_loss.data[0], step + 1)
+        D_B_logger.scalar_summary('losses', D_B_loss.data[0], step + 1)
+        G_A_logger.scalar_summary('losses', G_A_loss.data[0], step + 1)
+        G_B_logger.scalar_summary('losses', G_B_loss.data[0], step + 1)
+        L1_A_logger.scalar_summary('losses', L1_A_loss.data[0], step + 1)
+        L1_B_logger.scalar_summary('losses', L1_B_loss.data[0], step + 1)
+        step += 1
+
+    D_A_avg_loss = torch.mean(torch.FloatTensor(D_A_losses))
+    D_B_avg_loss = torch.mean(torch.FloatTensor(D_B_losses))
+    G_A_avg_loss = torch.mean(torch.FloatTensor(G_A_losses))
+    G_B_avg_loss = torch.mean(torch.FloatTensor(G_B_losses))
+    L1_A_avg_loss = torch.mean(torch.FloatTensor(L1_A_losses))
+    L1_B_avg_loss = torch.mean(torch.FloatTensor(L1_B_losses))
+
+    # avg loss values for plot
+    D_A_avg_losses.append(D_A_avg_loss)
+    D_B_avg_losses.append(D_B_avg_loss)
+    G_A_avg_losses.append(G_A_avg_loss)
+    G_B_avg_losses.append(G_B_avg_loss)
+    L1_A_avg_losses.append(L1_A_avg_loss)
+    L1_B_avg_losses.append(L1_B_avg_loss)
+
+    # Show result for test image
+    test_real_A = Variable(test_real_A_data.cuda())
+    test_fake_B = G_A(test_real_A)
+    test_recon_A = G_B(test_fake_B)
+
+    test_real_B = Variable(test_real_B_data.cuda())
+    test_fake_A = G_B(test_real_B)
+    test_recon_B = G_A(test_fake_A)
+
+    utils.plot_train_result([test_real_A, test_real_B], [test_fake_B, test_fake_A], [test_recon_A, test_recon_B],
+                            epoch, save=True, save_dir=save_dir)
+
+    # log the images
+    result_AtoB = np.concatenate((utils.to_np(test_real_A), utils.to_np(test_fake_B), utils.to_np(test_recon_A)), axis=3)
+    result_BtoA = np.concatenate((utils.to_np(test_real_B), utils.to_np(test_fake_A), utils.to_np(test_recon_B)), axis=3)
+
+    if result_AtoB.shape[1] == 1:
+        result_AtoB = result_AtoB.squeeze(axis=1)        # for gray images, convert to BxHxW
+    else:
+        result_AtoB = result_AtoB.transpose(0, 2, 3, 1)  # for color images, convert to BxHxWxC
+    if result_BtoA.shape[1] == 1:
+        result_BtoA = result_BtoA.squeeze(axis=1)
+    else:
+        result_BtoA = result_BtoA.transpose(0, 2, 3, 1)
+
+    info = {
+        'result_AtoB': result_AtoB,
+        'result_BtoA': result_BtoA
+    }
+
+    for tag, images in info.items():
+        img_logger.image_summary(tag, images, epoch + 1)
+
+
+# Plot average losses
+avg_losses = [D_A_avg_losses, D_B_avg_losses, G_A_avg_losses, G_B_avg_losses, L1_A_avg_losses, L1_B_avg_losses]
+utils.plot_loss(avg_losses, params.num_epochs, save=True, save_dir=save_dir)
+
+# Make gif
+utils.make_gif(params.dataset, params.num_epochs, save_dir=save_dir)
+
+# Save trained parameters of model
+torch.save(G_A.state_dict(), model_dir + 'generator_A_param.pkl')
+torch.save(G_B.state_dict(), model_dir + 'generator_B_param.pkl')
+torch.save(D_A.state_dict(), model_dir + 'discriminator_A_param.pkl')
+torch.save(D_B.state_dict(), model_dir + 'discriminator_B_param.pkl')
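Two notes on the loop above. First, per direction the generator objective reduces to an adversarial BCE term plus a lambda-weighted L1 reconstruction term; a distilled sketch with stand-in tensors (the shapes, including the 30x30 discriminator patch, are illustrative):

```python
# One direction of the generator objective from DualGAN_train.py, with
# random stand-ins for the real image, reconstruction, and D_B's decision.
import torch

BCE_loss = torch.nn.BCELoss()
L1_loss = torch.nn.L1Loss()
lambdaA = 20.0  # matches the --lambdaA default

real_A = torch.rand(1, 1, 256, 256)                           # stand-in for a real A image
recon_A = torch.rand(1, 1, 256, 256)                          # stand-in for G_B(G_A(real_A))
D_B_fake_decision = torch.sigmoid(torch.randn(1, 1, 30, 30))  # stand-in for D_B(G_A(real_A))

G_A_loss = BCE_loss(D_B_fake_decision, torch.ones(D_B_fake_decision.size()))
L1_A_loss = L1_loss(recon_A, real_A) * lambdaA
print((G_A_loss + L1_A_loss).item())
```

Second, the DualGAN paper trains with a Wasserstein loss; this port keeps the paper's RMSprop optimizer (the `--lrG`/`--lrD` defaults of 0.00005 match the WGAN recipe) but substitutes a standard sigmoid + BCE GAN loss.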
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..5484997
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,45 @@
+# Custom dataset
+from PIL import Image
+import torch.utils.data as data
+import os
+import random
+
+
+class DatasetFromFolder(data.Dataset):
+    def __init__(self, image_dir, subfolder='train', transform=None, resize_scale=None, crop_size=None, fliplr=False, is_color=True):
+        super(DatasetFromFolder, self).__init__()
+        self.input_path = os.path.join(image_dir, subfolder)
+        self.image_filenames = sorted(os.listdir(self.input_path))
+        self.transform = transform
+        self.resize_scale = resize_scale
+        self.crop_size = crop_size
+        self.fliplr = fliplr
+        self.is_color = is_color
+
+    def __getitem__(self, index):
+        # Load image
+        img_fn = os.path.join(self.input_path, self.image_filenames[index])
+        if self.is_color:
+            img = Image.open(img_fn).convert('RGB')
+        else:
+            img = Image.open(img_fn)
+
+        # preprocessing
+        if self.resize_scale is not None:
+            img = img.resize((self.resize_scale, self.resize_scale), Image.BILINEAR)
+
+        if self.crop_size is not None:
+            x = random.randint(0, self.resize_scale - self.crop_size)  # randint is inclusive on both ends
+            y = random.randint(0, self.resize_scale - self.crop_size)
+            img = img.crop((x, y, x + self.crop_size, y + self.crop_size))
+
+        if self.fliplr:
+            if random.random() < 0.5:
+                img = img.transpose(Image.FLIP_LEFT_RIGHT)
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        return img
+
+    def __len__(self):
+        return len(self.image_filenames)
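A minimal usage sketch for `DatasetFromFolder` (not part of the diff); the directory, `resize_scale`, and `crop_size` values are illustrative, and note the training script above passes neither resize nor crop:

```python
# Load grayscale images from ../Data/sketch-photo/train/A, with a random
# 286 -> 256 crop and horizontal flips, normalized to [-1, 1].
import torch
from torchvision import transforms

from dataset import DatasetFromFolder

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=[0.5], std=[0.5])])
data = DatasetFromFolder('../Data/sketch-photo/', subfolder='train/A',
                         transform=transform, resize_scale=286, crop_size=256,
                         fliplr=True, is_color=False)
loader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=True)
img = data[0]  # a single CxHxW tensor
print(len(data), img.size())
```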
diff --git a/logger.py b/logger.py
new file mode 100644
index 0000000..a183d46
--- /dev/null
+++ b/logger.py
@@ -0,0 +1,73 @@
+# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
+import tensorflow as tf
+import numpy as np
+import scipy.misc
+
+try:
+    from StringIO import StringIO  # Python 2.7
+except ImportError:
+    from io import BytesIO  # Python 3.x
+
+
+class Logger(object):
+    def __init__(self, log_dir):
+        """Create a summary writer logging to log_dir."""
+        self.writer = tf.summary.FileWriter(log_dir)
+
+    def scalar_summary(self, tag, value, step):
+        """Log a scalar variable."""
+        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
+        self.writer.add_summary(summary, step)
+        self.writer.flush()
+
+    def image_summary(self, tag, images, step):
+        """Log a list of images."""
+
+        img_summaries = []
+        for i, img in enumerate(images):
+            # Write the image to an in-memory buffer
+            # (StringIO on Python 2, BytesIO on Python 3)
+            try:
+                s = StringIO()
+            except NameError:
+                s = BytesIO()
+            scipy.misc.toimage(img).save(s, format="png")
+
+            # Create an Image object
+            img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
+                                       height=img.shape[0],
+                                       width=img.shape[1])
+            # Create a Summary value
+            img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
+
+        # Create and write Summary
+        summary = tf.Summary(value=img_summaries)
+        self.writer.add_summary(summary, step)
+        self.writer.flush()
+
+    def histo_summary(self, tag, values, step, bins=1000):
+        """Log a histogram of the tensor of values."""
+
+        # Create a histogram using numpy
+        counts, bin_edges = np.histogram(values, bins=bins)
+
+        # Fill the fields of the histogram proto
+        hist = tf.HistogramProto()
+        hist.min = float(np.min(values))
+        hist.max = float(np.max(values))
+        hist.num = int(np.prod(values.shape))
+        hist.sum = float(np.sum(values))
+        hist.sum_squares = float(np.sum(values ** 2))
+
+        # Drop the start of the first bin
+        bin_edges = bin_edges[1:]
+
+        # Add bin edges and counts
+        for edge in bin_edges:
+            hist.bucket_limit.append(edge)
+        for c in counts:
+            hist.bucket.append(c)
+
+        # Create and write Summary
+        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
+        self.writer.add_summary(summary, step)
+        self.writer.flush()
\ No newline at end of file
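The logger targets the TensorFlow 1.x summary API and `scipy.misc.toimage` (removed in SciPy 1.2), so it needs that era of dependencies. A usage sketch, with `./logs` as an illustrative directory; the results show up under `tensorboard --logdir ./logs`:

```python
# Log a scalar curve, a couple of images, and a histogram.
import numpy as np

from logger import Logger

logger = Logger('./logs')
for step in range(10):
    logger.scalar_summary('losses', 1.0 / (step + 1), step + 1)
logger.image_summary('samples', np.random.rand(2, 64, 64), 1)  # iterable of HxW arrays
logger.histo_summary('weights', np.random.randn(1000), 1)
```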
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..e7eb9e1
--- /dev/null
+++ b/model.py
@@ -0,0 +1,128 @@
+import torch
+import torch.nn.functional as F
+
+class ConvBlock(torch.nn.Module):
+    def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, activation=True, batch_norm=True):
+        super(ConvBlock, self).__init__()
+        self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding)
+        self.activation = activation
+        self.lrelu = torch.nn.LeakyReLU(0.2, True)
+        self.batch_norm = batch_norm
+        self.bn = torch.nn.BatchNorm2d(output_size)
+
+    def forward(self, x):
+        # pre-activation: LeakyReLU on the input, then convolution
+        if self.activation:
+            out = self.conv(self.lrelu(x))
+        else:
+            out = self.conv(x)
+
+        if self.batch_norm:
+            return self.bn(out)
+        else:
+            return out
+
+
+class DeconvBlock(torch.nn.Module):
+    def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, batch_norm=True, dropout=False):
+        super(DeconvBlock, self).__init__()
+        self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding)
+        self.bn = torch.nn.BatchNorm2d(output_size)
+        self.relu = torch.nn.ReLU(True)
+        self.batch_norm = batch_norm
+        self.dropout = dropout
+
+    def forward(self, x):
+        # pre-activation: ReLU on the input, then transposed convolution
+        if self.batch_norm:
+            out = self.bn(self.deconv(self.relu(x)))
+        else:
+            out = self.deconv(self.relu(x))
+
+        if self.dropout:
+            # dropout stays active even at test time, as in the pix2pix U-Net recipe
+            return F.dropout(out, 0.5, training=True)
+        else:
+            return out
+
+
+class Generator(torch.nn.Module):
+    def __init__(self, input_dim, num_filter, output_dim):
+        super(Generator, self).__init__()
+
+        # Encoder
+        self.conv1 = ConvBlock(input_dim, num_filter, activation=False, batch_norm=False)
+        self.conv2 = ConvBlock(num_filter, num_filter * 2)
+        self.conv3 = ConvBlock(num_filter * 2, num_filter * 4)
+        self.conv4 = ConvBlock(num_filter * 4, num_filter * 8)
+        self.conv5 = ConvBlock(num_filter * 8, num_filter * 8)
+        self.conv6 = ConvBlock(num_filter * 8, num_filter * 8)
+        self.conv7 = ConvBlock(num_filter * 8, num_filter * 8)
+        self.conv8 = ConvBlock(num_filter * 8, num_filter * 8)
+        # Decoder
+        self.deconv1 = DeconvBlock(num_filter * 8, num_filter * 8, dropout=True)
+        self.deconv2 = DeconvBlock(num_filter * 8 * 2, num_filter * 8, dropout=True)
+        self.deconv3 = DeconvBlock(num_filter * 8 * 2, num_filter * 8, dropout=True)
+        self.deconv4 = DeconvBlock(num_filter * 8 * 2, num_filter * 8)
+        self.deconv5 = DeconvBlock(num_filter * 8 * 2, num_filter * 4)
+        self.deconv6 = DeconvBlock(num_filter * 4 * 2, num_filter * 2)
+        self.deconv7 = DeconvBlock(num_filter * 2 * 2, num_filter)
+        self.deconv8 = DeconvBlock(num_filter * 2, output_dim, batch_norm=False)
+
+    def forward(self, x):
+        # Encoder
+        enc1 = self.conv1(x)
+        enc2 = self.conv2(enc1)
+        enc3 = self.conv3(enc2)
+        enc4 = self.conv4(enc3)
+        enc5 = self.conv5(enc4)
+        enc6 = self.conv6(enc5)
+        enc7 = self.conv7(enc6)
+        enc8 = self.conv8(enc7)
+        # Decoder with skip-connections
+        dec1 = self.deconv1(enc8)
+        dec1 = torch.cat([dec1, enc7], 1)
+        dec2 = self.deconv2(dec1)
+        dec2 = torch.cat([dec2, enc6], 1)
+        dec3 = self.deconv3(dec2)
+        dec3 = torch.cat([dec3, enc5], 1)
+        dec4 = self.deconv4(dec3)
+        dec4 = torch.cat([dec4, enc4], 1)
+        dec5 = self.deconv5(dec4)
+        dec5 = torch.cat([dec5, enc3], 1)
+        dec6 = self.deconv6(dec5)
+        dec6 = torch.cat([dec6, enc2], 1)
+        dec7 = self.deconv7(dec6)
+        dec7 = torch.cat([dec7, enc1], 1)
+        dec8 = self.deconv8(dec7)
+        out = torch.nn.Tanh()(dec8)
+        return out
+
+    def normal_weight_init(self, mean=0.0, std=0.02):
+        for m in self.children():
+            if isinstance(m, ConvBlock):
+                torch.nn.init.normal(m.conv.weight, mean, std)
+            if isinstance(m, DeconvBlock):
+                torch.nn.init.normal(m.deconv.weight, mean, std)
+
+
+class Discriminator(torch.nn.Module):
+    def __init__(self, input_dim, num_filter, output_dim):
+        super(Discriminator, self).__init__()
+
+        self.conv1 = ConvBlock(input_dim, num_filter, activation=False, batch_norm=False)
+        self.conv2 = ConvBlock(num_filter, num_filter * 2)
+        self.conv3 = ConvBlock(num_filter * 2, num_filter * 4)
+        self.conv4 = ConvBlock(num_filter * 4, num_filter * 8, stride=1)
+        self.conv5 = ConvBlock(num_filter * 8, output_dim, stride=1, batch_norm=False)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = self.conv4(x)
+        x = self.conv5(x)
+        out = torch.nn.Sigmoid()(x)
+        return out
+
+    def normal_weight_init(self, mean=0.0, std=0.02):
+        for m in self.children():
+            if isinstance(m, ConvBlock):
+                torch.nn.init.normal(m.conv.weight, mean, std)
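A shape check for both networks (not part of the diff; CPU, random weights). For a 256x256 input the generator returns a same-sized tanh image, while the discriminator returns a 30x30 patch of sigmoid scores rather than a single scalar, because its last two ConvBlocks use stride 1:

```python
import torch
from torch.autograd import Variable

from model import Generator, Discriminator

G = Generator(input_dim=1, num_filter=64, output_dim=1)
D = Discriminator(input_dim=1, num_filter=64, output_dim=1)
G.normal_weight_init(mean=0.0, std=0.02)
D.normal_weight_init(mean=0.0, std=0.02)

x = Variable(torch.randn(1, 1, 256, 256))
y = G(x)  # -> (1, 1, 256, 256), values in [-1, 1]
d = D(y)  # -> (1, 1, 30, 30), per-patch real/fake probabilities
print(y.size(), d.size())
```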
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..37ba7f9
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,164 @@
+import torch
+from torch.autograd import Variable
+import numpy as np
+import matplotlib.pyplot as plt
+import os
+import imageio
+import random
+
+
+# For logger
+def to_np(x):
+    return x.data.cpu().numpy()
+
+
+def to_var(x):
+    if torch.cuda.is_available():
+        x = x.cuda()
+    return Variable(x)
+
+
+# De-normalization
+def denorm(x):
+    out = (x + 1) / 2
+    return out.clamp(0, 1)
+
+
+# Plot losses
+def plot_loss(avg_losses, num_epochs, save=False, save_dir='results/', show=False):
+    fig, ax = plt.subplots()
+    ax.set_xlim(0, num_epochs)
+    temp = 0.0
+    for i in range(len(avg_losses)):
+        temp = max(np.max(avg_losses[i]), temp)
+    ax.set_ylim(0, temp * 1.1)
+    plt.xlabel('# of Epochs')
+    plt.ylabel('Loss values')
+
+    plt.plot(avg_losses[0], label='D_A')
+    plt.plot(avg_losses[1], label='D_B')
+    plt.plot(avg_losses[2], label='G_A')
+    plt.plot(avg_losses[3], label='G_B')
+    plt.plot(avg_losses[4], label='cycle_A')
+    plt.plot(avg_losses[5], label='cycle_B')
+    plt.legend()
+
+    # save figure
+    if save:
+        if not os.path.exists(save_dir):
+            os.mkdir(save_dir)
+        save_fn = save_dir + 'Loss_values_epoch_{:d}'.format(num_epochs) + '.png'
+        plt.savefig(save_fn)
+
+    if show:
+        plt.show()
+    else:
+        plt.close()
+
+
+def plot_train_result(real_image, gen_image, recon_image, epoch, save=False, save_dir='results/', show=False, fig_size=(5, 5)):
+    fig, axes = plt.subplots(2, 3, figsize=fig_size)
+
+    imgs = [to_np(real_image[0]), to_np(gen_image[0]), to_np(recon_image[0]),
+            to_np(real_image[1]), to_np(gen_image[1]), to_np(recon_image[1])]
+    for ax, img in zip(axes.flatten(), imgs):
+        ax.axis('off')
+        ax.set_adjustable('box-forced')
+        # Scale to 0-255
+        img = img.squeeze()
+        if img.ndim == 3:
+            # color image: CxHxW -> HxWxC
+            img = (((img - img.min()) * 255) / (img.max() - img.min())).transpose(1, 2, 0).astype(np.uint8)
+            ax.imshow(img, cmap=None, aspect='equal')
+        else:
+            # gray image: HxW
+            img = (((img - img.min()) * 255) / (img.max() - img.min())).astype(np.uint8)
+            ax.imshow(img, cmap='gray', aspect='equal')
+    plt.subplots_adjust(wspace=0, hspace=0)
+
+    title = 'Epoch {0}'.format(epoch + 1)
+    fig.text(0.5, 0.04, title, ha='center')
+
+    # save figure
+    if save:
+        if not os.path.exists(save_dir):
+            os.mkdir(save_dir)
+
+        save_fn = save_dir + 'Result_epoch_{:d}'.format(epoch + 1) + '.png'
+        plt.savefig(save_fn)
+
+    if show:
+        plt.show()
+    else:
+        plt.close()
+
+
+def plot_test_result(real_image, gen_image, recon_image, index, save=False, save_dir='results/', show=False):
+    fig_size = (real_image.size(2) * 3 / 100, real_image.size(3) / 100)
+    fig, axes = plt.subplots(1, 3, figsize=fig_size)
+
+    imgs = [to_np(real_image), to_np(gen_image), to_np(recon_image)]
+    for ax, img in zip(axes.flatten(), imgs):
+        ax.axis('off')
+        ax.set_adjustable('box-forced')
+        # Scale to 0-255
+        img = img.squeeze()
+        img = (((img - img.min()) * 255) / (img.max() - img.min())).transpose(1, 2, 0).astype(np.uint8)
+        ax.imshow(img, cmap=None, aspect='equal')
+    plt.subplots_adjust(wspace=0, hspace=0)
+
+    # save figure
+    if save:
+        if not os.path.exists(save_dir):
+            os.mkdir(save_dir)
+
+        save_fn = save_dir + 'Test_result_{:d}'.format(index + 1) + '.png'
+        fig.subplots_adjust(bottom=0)
+        fig.subplots_adjust(top=1)
+        fig.subplots_adjust(right=1)
+        fig.subplots_adjust(left=0)
+        plt.savefig(save_fn)
+
+    if show:
+        plt.show()
+    else:
+        plt.close()
+
+
+# Make gif
+def make_gif(dataset, num_epochs, save_dir='results/'):
+    gen_image_plots = []
+    for epoch in range(num_epochs):
+        # plot for generating gif
+        save_fn = save_dir + 'Result_epoch_{:d}'.format(epoch + 1) + '.png'
+        gen_image_plots.append(imageio.imread(save_fn))
+
+    imageio.mimsave(save_dir + dataset + '_epochs_{:d}'.format(num_epochs) + '.gif', gen_image_plots, fps=5)
+
+
+class ImagePool():
+    def __init__(self, pool_size):
+        self.pool_size = pool_size
+        if self.pool_size > 0:
+            self.num_imgs = 0
+            self.images = []
+
+    def query(self, images):
+        if self.pool_size == 0:
+            return images
+        return_images = []
+        for image in images.data:
+            image = torch.unsqueeze(image, 0)
+            if self.num_imgs < self.pool_size:
+                self.num_imgs = self.num_imgs + 1
+                self.images.append(image)
+                return_images.append(image)
+            else:
+                p = random.uniform(0, 1)
+                if p > 0.5:
+                    # with probability 0.5, return an old fake and store the new one
+                    random_id = random.randint(0, self.pool_size - 1)
+                    tmp = self.images[random_id].clone()
+                    self.images[random_id] = image
+                    return_images.append(tmp)
+                else:
+                    return_images.append(image)
+        return_images = Variable(torch.cat(return_images, 0))
+        return return_images
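`ImagePool` is defined here but never instantiated in `DualGAN_train.py` above; it implements the history-of-fakes trick from CycleGAN-style training. A usage sketch with an illustrative pool size and a stand-in generator output:

```python
# Mix freshly generated fakes with a buffer of past fakes before feeding
# them to the discriminator, which tends to stabilize adversarial training.
import torch
from torch.autograd import Variable

from utils import ImagePool

pool = ImagePool(pool_size=50)
fake_B = Variable(torch.randn(1, 1, 256, 256))  # stand-in for G_A(real_A)
fake_B_for_D = pool.query(fake_B)               # same shape, possibly an older sample
print(fake_B_for_D.size())
```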