train.py
assert __name__ == '__main__', 'This file cannot be imported.'
import argparse
parser = argparse.ArgumentParser(description='AdaIN Training Script')
# required arguments
parser.add_argument('-cd', '--content-dir', type=str, metavar='<dir>', required=True, help='Directory with content images')
parser.add_argument('-sd', '--style-dir', type=str, metavar='<dir>', required=True, help='Directory with style images')
# optional arguments for training
parser.add_argument('--continual', type=str, metavar='<.pth>', default=None, help='Checkpoint file to save and load for continual (resumable) training, default=disabled')
parser.add_argument('--save-dir', type=str, metavar='<dir>', default='./experiments', help='Directory to save trained models, default=./experiments')
parser.add_argument('--log-dir', type=str, metavar='<dir>', default='./logs', help='Directory to save logs, default=./logs')
parser.add_argument('--log-image-every', type=int, metavar='<int>', default=100, help='Interval (in iterations) for logging generated images; non-positive disables, default=100')
parser.add_argument('--save-interval', type=int, metavar='<int>', default=10000, help='Interval (in iterations) for saving the model, default=10000')
parser.add_argument('--include-encoder', action='store_true', help='Include the encoder when saving the model')
parser.add_argument('--cuda', action='store_true', help='Option for using GPU if available')
parser.add_argument('--n-threads', type=int, metavar='<int>', default=2, help='Number of threads used for dataloader, default=2')
# hyper-parameters
parser.add_argument('--learning-rate', type=float, metavar='<float>', default=1e-4, help='Learning rate, default=1e-4')
parser.add_argument('--learning-rate-decay', type=float, metavar='<float>', default=5e-5, help='Learning rate decay, default=5e-5')
parser.add_argument('--max-iter', type=int, metavar='<int>', default=160000, help='Maximum number of iterations, default=160000')
parser.add_argument('--batch-size', type=int, metavar='<int>', default=8, help='Size of the batch, default=8')
parser.add_argument('--style-weight', type=float, metavar='<float>', default=10.0, help='Weight of style loss, default=10.0')
parser.add_argument('--content-weight', type=float, metavar='<float>', default=1.0, help='Weight of content loss, default=1.0')
args = parser.parse_args()
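# example invocation (directory paths below are placeholders):
#   python train.py --content-dir ./data/content --style-dir ./data/style --cuda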
import os
import torch
import torch.utils.data as data
from dataloader import ImageFolderDataset, InfiniteSampler, train_transform
from network import AdaIN, save_AdaIn
from pathlib import Path
from PIL import Image, ImageFile
from tensorboardX import SummaryWriter
from tqdm import tqdm
from utils import learning_rate_decay
# work around PIL errors on oversized or truncated image files
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
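# MAX_IMAGE_PIXELS=None lifts PIL's decompression-bomb limit so very large images load,
# and LOAD_TRUNCATED_IMAGES makes PIL tolerate partially corrupted files instead of raising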
# use gpu if possible
device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
# directory for trained models
save_dir = Path(args.save_dir)
save_dir.mkdir(exist_ok=True, parents=True)
# directory for logs
log_dir = Path(args.log_dir)
log_dir.mkdir(exist_ok=True, parents=True)
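# logs are written as TensorBoard event files; inspect them with: tensorboard --logdir ./logs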
# content dataset
content_dataset = ImageFolderDataset(args.content_dir, train_transform((512, 512), 256))
content_iter = iter(data.DataLoader(content_dataset, batch_size=args.batch_size, sampler=InfiniteSampler(len(content_dataset)), num_workers=args.n_threads))
# style dataset
style_dataset = ImageFolderDataset(args.style_dir, train_transform((512, 512), 256))
style_iter = iter(data.DataLoader(style_dataset, batch_size=args.batch_size, sampler=InfiniteSampler(len(style_dataset)), num_workers=args.n_threads))
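# InfiniteSampler (from dataloader.py) is assumed to yield indices endlessly,
# so next() on content_iter/style_iter never raises StopIteration during training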
# AdaIN model
model = AdaIN()
optimizer = torch.optim.Adam(model.decoder.parameters(), lr=args.learning_rate)
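# only the decoder parameters are optimized; the encoder (a pre-trained VGG
# feature extractor in the original AdaIN paper) stays fixed throughout training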
# continual training
initial_iter = 0
if args.continual:
    if os.path.exists(args.continual):
        state_dict = torch.load(args.continual)
        initial_iter = state_dict['iter']
        model.encoder.load_state_dict(state_dict['encoder'])
        model.decoder.load_state_dict(state_dict['decoder'])
        optimizer.load_state_dict(state_dict['optimizer'])
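# the saved 'iter' counter lets the loop below resume exactly where the previous run stopped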
# log writer
writer = SummaryWriter(log_dir=str(log_dir))
# training loop: run from initial_iter up to max-iter
model.to(device)
for i in tqdm(range(initial_iter, args.max_iter)):
    # adjust learning rate
    lr = learning_rate_decay(args.learning_rate, args.learning_rate_decay, i)
    for group in optimizer.param_groups:
        group['lr'] = lr
    # get images
    content_images = next(content_iter).to(device)
    style_images = next(style_iter).to(device)
    # calculate loss
    g, loss_content, loss_style = model(content_images, style_images)
    loss_content = args.content_weight * loss_content
    loss_style = args.style_weight * loss_style
    loss = loss_content + loss_style
    # optimize the network
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # write logs
    writer.add_scalar('Loss/Loss', loss.item(), i + 1)
    writer.add_scalar('Loss/Loss_content', loss_content.item(), i + 1)
    writer.add_scalar('Loss/Loss_style', loss_style.item(), i + 1)
    if args.log_image_every > 0 and ((i + 1) % args.log_image_every == 0 or i == 0 or (i + 1) == args.max_iter):
        writer.add_image('Image/Content', content_images[0], i + 1)
        writer.add_image('Image/Style', style_images[0], i + 1)
        writer.add_image('Image/Generated', g[0], i + 1)
    # save model
    if (i + 1) % args.save_interval == 0 or (i + 1) == args.max_iter:
        save_AdaIn(model, os.path.join(save_dir, 'iter_{}.pth'.format(i + 1)), include_encoder=args.include_encoder)
        # continual training: also refresh the resume checkpoint, with all
        # weights moved to the CPU so it can be reloaded on any device
        if args.continual:
            encoder_dict = model.encoder.state_dict()
            for key in encoder_dict.keys():
                encoder_dict[key] = encoder_dict[key].cpu()
            decoder_dict = model.decoder.state_dict()
            for key in decoder_dict.keys():
                decoder_dict[key] = decoder_dict[key].cpu()
            optimizer_dict = optimizer.state_dict()
            torch.save({
                'iter': i + 1,
                'encoder': encoder_dict,
                'decoder': decoder_dict,
                'optimizer': optimizer_dict
            }, args.continual)
writer.close()
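# to resume an interrupted run, pass the same checkpoint file (paths are placeholders):
#   python train.py --content-dir ./data/content --style-dir ./data/style --continual ./continual.pth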