Skip to content

Commit

Permalink
Codes added.
Browse files Browse the repository at this point in the history
  • Loading branch information
togheppi committed Sep 14, 2017
1 parent ec08abf commit 07bc8c0
Show file tree
Hide file tree
Showing 6 changed files with 774 additions and 0 deletions.
79 changes: 79 additions & 0 deletions DualGAN_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import torch
from torchvision import transforms
from torch.autograd import Variable
from dataset import DatasetFromFolder
from model import Generator
import utils
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=False, default='horse2zebra', help='input dataset')
parser.add_argument('--batch_size', type=int, default=1, help='test batch size')
parser.add_argument('--ngf', type=int, default=32)
parser.add_argument('--num_resnet', type=int, default=9, help='number of resnet blocks in generator')
parser.add_argument('--input_size', type=int, default=256, help='input size')
params = parser.parse_args()
print(params)

# Directories for loading data and saving results
data_dir = '../Data/' + params.dataset + '/'
save_dir = params.dataset + '_test_results/'
model_dir = params.dataset + '_model/'

if not os.path.exists(save_dir):
os.mkdir(save_dir)
if not os.path.exists(model_dir):
os.mkdir(model_dir)

# Data pre-processing
transform = transforms.Compose([transforms.Scale(params.input_size),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

# Test data
test_data_A = DatasetFromFolder(data_dir, subfolder='testA', transform=transform)
test_data_loader_A = torch.utils.data.DataLoader(dataset=test_data_A,
batch_size=params.batch_size,
shuffle=False)
test_data_B = DatasetFromFolder(data_dir, subfolder='testB', transform=transform)
test_data_loader_B = torch.utils.data.DataLoader(dataset=test_data_B,
batch_size=params.batch_size,
shuffle=False)

# Load model
G_A = Generator(3, params.ngf, 3, params.num_resnet)
G_B = Generator(3, params.ngf, 3, params.num_resnet)
G_A.cuda()
G_B.cuda()
G_A.load_state_dict(torch.load(model_dir + 'generator_A_param.pkl'))
G_B.load_state_dict(torch.load(model_dir + 'generator_B_param.pkl'))

# Test
for i, real_A in enumerate(test_data_loader_A):

# input image data
real_A = Variable(real_A.cuda())

# A -> B -> A
fake_B = G_A(real_A)
recon_A = G_B(fake_B)

# Show result for test data
utils.plot_test_result(real_A, fake_B, recon_A, i, save=True, save_dir=save_dir + 'AtoB/')

print('%d images are generated.' % (i + 1))

for i, real_B in enumerate(test_data_loader_B):

# input image data
real_B = Variable(real_B.cuda())

# B -> A -> B
fake_A = G_B(real_B)
recon_B = G_A(fake_A)

# Show result for test data
utils.plot_test_result(real_B, fake_A, recon_B, i, save=True, save_dir=save_dir + 'BtoA/')

print('%d images are generated.' % (i + 1))
285 changes: 285 additions & 0 deletions DualGAN_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
import torch
from torchvision import transforms
from torch.autograd import Variable
from dataset import DatasetFromFolder
from model import Generator, Discriminator
import utils
import argparse
import os, itertools
from logger import Logger
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=False, default='sketch-photo', help='input dataset')
parser.add_argument('--batch_size', type=int, default=1, help='train batch size')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--input_size', type=int, default=256, help='input size')
parser.add_argument('--num_channel', type=int, default=1, help='number of channels for input image')
parser.add_argument('--fliplr', type=bool, default=True, help='random fliplr True of False')
parser.add_argument('--num_epochs', type=int, default=45, help='number of train epochs')
parser.add_argument('--num_iter_G', type=int, default=2, help='number of iterations for training generator')
parser.add_argument('--lrG', type=float, default=0.00005, help='learning rate for generator, default=0.0002')
parser.add_argument('--lrD', type=float, default=0.00005, help='learning rate for discriminator, default=0.0002')
parser.add_argument('--decay', type=float, default=0.1, help='weight decay for RMSProp optimizer')
parser.add_argument('--lambdaA', type=float, default=20, help='lambdaA for L1 loss')
parser.add_argument('--lambdaB', type=float, default=20, help='lambdaB for L1 loss')
params = parser.parse_args()
print(params)

# Directories for loading data and saving results
data_dir = '../Data/' + params.dataset + '/'
save_dir = params.dataset + '_results/'
model_dir = params.dataset + '_model/'

if not os.path.exists(save_dir):
os.mkdir(save_dir)
if not os.path.exists(model_dir):
os.mkdir(model_dir)

# Data pre-processing
transform = transforms.Compose([transforms.Scale(params.input_size),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

# Train data
train_data_A = DatasetFromFolder(data_dir, subfolder='train/A', transform=transform, fliplr=params.fliplr, is_color=False)
train_data_loader_A = torch.utils.data.DataLoader(dataset=train_data_A,
batch_size=params.batch_size,
shuffle=True)
train_data_B = DatasetFromFolder(data_dir, subfolder='train/B', transform=transform, fliplr=params.fliplr, is_color=False)
train_data_loader_B = torch.utils.data.DataLoader(dataset=train_data_B,
batch_size=params.batch_size,
shuffle=True)

# Test data
test_data_A = DatasetFromFolder(data_dir, subfolder='val/A', transform=transform, is_color=False)
test_data_loader_A = torch.utils.data.DataLoader(dataset=test_data_A,
batch_size=params.batch_size,
shuffle=False)
test_data_B = DatasetFromFolder(data_dir, subfolder='val/B', transform=transform, is_color=False)
test_data_loader_B = torch.utils.data.DataLoader(dataset=test_data_B,
batch_size=params.batch_size,
shuffle=False)

# Get specific test images
test_real_A_data = test_data_A.__getitem__(0).unsqueeze(0) # Convert to 4d tensor (BxNxHxW)
test_real_B_data = test_data_B.__getitem__(0).unsqueeze(0)

# Models
G_A = Generator(params.num_channel, params.ngf, params.num_channel)
G_B = Generator(params.num_channel, params.ngf, params.num_channel)
D_A = Discriminator(params.num_channel, params.ndf, 1)
D_B = Discriminator(params.num_channel, params.ndf, 1)
G_A.normal_weight_init(mean=0.0, std=0.02)
G_B.normal_weight_init(mean=0.0, std=0.02)
D_A.normal_weight_init(mean=0.0, std=0.02)
D_B.normal_weight_init(mean=0.0, std=0.02)
G_A.cuda()
G_B.cuda()
D_A.cuda()
D_B.cuda()


# Set the logger
D_A_log_dir = save_dir + 'D_A_logs'
D_B_log_dir = save_dir + 'D_B_logs'
if not os.path.exists(D_A_log_dir):
os.mkdir(D_A_log_dir)
D_A_logger = Logger(D_A_log_dir)
if not os.path.exists(D_B_log_dir):
os.mkdir(D_B_log_dir)
D_B_logger = Logger(D_B_log_dir)

G_A_log_dir = save_dir + 'G_A_logs'
G_B_log_dir = save_dir + 'G_B_logs'
if not os.path.exists(G_A_log_dir):
os.mkdir(G_A_log_dir)
G_A_logger = Logger(G_A_log_dir)
if not os.path.exists(G_B_log_dir):
os.mkdir(G_B_log_dir)
G_B_logger = Logger(G_B_log_dir)

L1_A_log_dir = save_dir + 'L1_A_logs'
L1_B_log_dir = save_dir + 'L1_B_logs'
if not os.path.exists(L1_A_log_dir):
os.mkdir(L1_A_log_dir)
L1_A_logger = Logger(L1_A_log_dir)
if not os.path.exists(L1_B_log_dir):
os.mkdir(L1_B_log_dir)
L1_B_logger = Logger(L1_B_log_dir)

img_log_dir = save_dir + 'img_logs'
if not os.path.exists(img_log_dir):
os.mkdir(img_log_dir)
img_logger = Logger(img_log_dir)


# Loss function
BCE_loss = torch.nn.BCELoss().cuda()
L1_loss = torch.nn.L1Loss().cuda()

# optimizers
G_optimizer = torch.optim.RMSprop(itertools.chain(G_A.parameters(), G_B.parameters()), lr=params.lrG, weight_decay=params.decay)
D_A_optimizer = torch.optim.RMSprop(D_A.parameters(), lr=params.lrD, weight_decay=params.decay)
D_B_optimizer = torch.optim.RMSprop(D_B.parameters(), lr=params.lrD, weight_decay=params.decay)

# Training GAN
D_A_avg_losses = []
D_B_avg_losses = []
G_A_avg_losses = []
G_B_avg_losses = []
L1_A_avg_losses = []
L1_B_avg_losses = []

step = 0
for epoch in range(params.num_epochs):
D_A_losses = []
D_B_losses = []
G_A_losses = []
G_B_losses = []
L1_A_losses = []
L1_B_losses = []

# training
for i, (real_A, real_B) in enumerate(zip(train_data_loader_A, train_data_loader_B)):

# input image data
real_A = Variable(real_A.cuda())
real_B = Variable(real_B.cuda())
for _ in range(params.num_iter_G):
# Train generator G
# A -> B
fake_B = G_A(real_A)
D_B_fake_decision = D_B(fake_B)
G_A_loss = BCE_loss(D_B_fake_decision, Variable(torch.ones(D_B_fake_decision.size()).cuda()))

# forward L1 loss
recon_A = G_B(fake_B)
L1_A_loss = L1_loss(recon_A, real_A) * params.lambdaA

# B -> A
fake_A = G_B(real_B)
D_A_fake_decision = D_A(fake_A)
G_B_loss = BCE_loss(D_A_fake_decision, Variable(torch.ones(D_A_fake_decision.size()).cuda()))

# backward L1 loss
recon_B = G_A(fake_A)
L1_B_loss = L1_loss(recon_B, real_B) * params.lambdaB

# Back propagation
G_loss = G_A_loss + G_B_loss + L1_A_loss + L1_B_loss
G_optimizer.zero_grad()
G_loss.backward(retain_graph=True)
G_optimizer.step()

# Train discriminator D_A
D_A_real_decision = D_A(real_A)
D_A_real_loss = BCE_loss(D_A_real_decision, Variable(torch.ones(D_A_real_decision.size()).cuda()))
D_A_fake_decision = D_A(fake_A)
D_A_fake_loss = BCE_loss(D_A_fake_decision, Variable(torch.zeros(D_A_fake_decision.size()).cuda()))

# Back propagation
D_A_loss = D_A_real_loss + D_A_fake_loss
D_A_optimizer.zero_grad()
D_A_loss.backward()
D_A_optimizer.step()

# Train discriminator D_B
D_B_real_decision = D_B(real_B)
D_B_real_loss = BCE_loss(D_B_real_decision, Variable(torch.ones(D_B_real_decision.size()).cuda()))
D_B_fake_decision = D_B(fake_B)
D_B_fake_loss = BCE_loss(D_B_fake_decision, Variable(torch.zeros(D_B_fake_decision.size()).cuda()))

# Back propagation
D_B_loss = D_B_real_loss + D_B_fake_loss
D_B_optimizer.zero_grad()
D_B_loss.backward()
D_B_optimizer.step()

# loss values
D_A_losses.append(D_A_loss.data[0])
D_B_losses.append(D_B_loss.data[0])
G_A_losses.append(G_A_loss.data[0])
G_B_losses.append(G_B_loss.data[0])
L1_A_losses.append(L1_A_loss.data[0])
L1_B_losses.append(L1_B_loss.data[0])

print('Epoch [%d/%d], Step [%d/%d], D_A_loss: %.4f, D_B_loss: %.4f, G_A_loss: %.4f, G_B_loss: %.4f'
% (epoch+1, params.num_epochs, i+1, len(train_data_loader_A), D_A_loss.data[0], D_B_loss.data[0], G_A_loss.data[0], G_B_loss.data[0]))

# ============ TensorBoard logging ============#
D_A_logger.scalar_summary('losses', D_A_loss.data[0], step + 1)
D_B_logger.scalar_summary('losses', D_B_loss.data[0], step + 1)
G_A_logger.scalar_summary('losses', G_A_loss.data[0], step + 1)
G_B_logger.scalar_summary('losses', G_B_loss.data[0], step + 1)
L1_A_logger.scalar_summary('losses', L1_A_loss.data[0], step + 1)
L1_B_logger.scalar_summary('losses', L1_B_loss.data[0], step + 1)
step += 1

D_A_avg_loss = torch.mean(torch.FloatTensor(D_A_losses))
D_B_avg_loss = torch.mean(torch.FloatTensor(D_B_losses))
G_A_avg_loss = torch.mean(torch.FloatTensor(G_A_losses))
G_B_avg_loss = torch.mean(torch.FloatTensor(G_B_losses))
L1_A_avg_loss = torch.mean(torch.FloatTensor(L1_A_losses))
L1_B_avg_loss = torch.mean(torch.FloatTensor(L1_B_losses))

# avg loss values for plot
D_A_avg_losses.append(D_A_avg_loss)
D_B_avg_losses.append(D_B_avg_loss)
G_A_avg_losses.append(G_A_avg_loss)
G_B_avg_losses.append(G_B_avg_loss)
L1_A_avg_losses.append(L1_A_avg_loss)
L1_B_avg_losses.append(L1_B_avg_loss)

# Show result for test image
test_real_A = Variable(test_real_A_data.cuda())
test_fake_B = G_A(test_real_A)
test_recon_A = G_B(test_fake_B)

test_real_B = Variable(test_real_B_data.cuda())
test_fake_A = G_B(test_real_B)
test_recon_B = G_A(test_fake_A)

utils.plot_train_result([test_real_A, test_real_B], [test_fake_B, test_fake_A], [test_recon_A, test_recon_B],
epoch, save=True, save_dir=save_dir)

# log the images
result_AtoB = np.concatenate((utils.to_np(test_real_A), utils.to_np(test_fake_B), utils.to_np(test_recon_A)), axis=3)
result_BtoA = np.concatenate((utils.to_np(test_real_B), utils.to_np(test_fake_A), utils.to_np(test_recon_B)), axis=3)

if list(result_AtoB.shape)[1] == 1:
result_AtoB = result_AtoB.squeeze(axis=1) # for gray images, convert to BxHxW
else:
result_AtoB = result_AtoB.transpose(0, 2, 3, 1) # for color image, convert to BxHxWxC
if list(result_BtoA.shape)[1] == 1:
result_BtoA = result_BtoA.squeeze(axis=1)
else:
result_BtoA = result_BtoA.transpose(0, 2, 3, 1)

info = {
'result_AtoB': result_AtoB,
'result_BtoA': result_BtoA
}

for tag, images in info.items():
img_logger.image_summary(tag, images, epoch + 1)


# Plot average losses
avg_losses = []
avg_losses.append(D_A_avg_losses)
avg_losses.append(D_B_avg_losses)
avg_losses.append(G_A_avg_losses)
avg_losses.append(G_B_avg_losses)
avg_losses.append(L1_A_avg_losses)
avg_losses.append(L1_B_avg_losses)
utils.plot_loss(avg_losses, params.num_epochs, save=True, save_dir=save_dir)

# Make gif
utils.make_gif(params.dataset, params.num_epochs, save_dir=save_dir)
# Save trained parameters of model
torch.save(G_A.state_dict(), model_dir + 'generator_A_param.pkl')
torch.save(G_B.state_dict(), model_dir + 'generator_B_param.pkl')
torch.save(D_A.state_dict(), model_dir + 'discriminator_A_param.pkl')
torch.save(D_B.state_dict(), model_dir + 'discriminator_B_param.pkl')
Loading

0 comments on commit 07bc8c0

Please sign in to comment.