-
Notifications
You must be signed in to change notification settings - Fork 5
/
train.py
35 lines (29 loc) · 974 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from options import DebugOptions, TrainOptions
from data import create_dataset
from model import create_model
from torch.backends import cudnn
import torch
# --- Run configuration -------------------------------------------------------
# `opt` carries every setting this script reads: gpu_ids, niter, niter_decay,
# print_every, visual_every and save_every.
# DebugOptions (imported above) is the alternative option set; it is left
# commented out here — presumably a lighter configuration for debugging.
#opt = DebugOptions()
opt = TrainOptions()

# --- Device / backend setup --------------------------------------------------
# NOTE(review): the commented-out line below would need `import os`, which is
# not among the imports visible in this file.
#os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu_ids[0]) # test single GPU first
# Make the first listed GPU the current CUDA device; any further entries in
# opt.gpu_ids are not used by this script itself.
torch.cuda.set_device(opt.gpu_ids[0])
# Enable cuDNN and let it benchmark convolution algorithms to pick the fastest.
# NOTE(review): benchmark mode mainly pays off when input sizes stay fixed
# between iterations — confirm that holds for this dataset.
cudnn.enabled = True
cudnn.benchmark = True

# --- Data --------------------------------------------------------------------
# `loader` is iterated once per epoch by the training loop below.
loader = create_dataset(opt)
dataset_size = len(loader)
# NOTE(review): the message says "images", but len(loader) is whatever
# create_dataset's return value reports (it may be a batch count) — confirm.
print('#training images = %d' % dataset_size)

# --- Model -------------------------------------------------------------------
# `net` is the object the training loop drives via set_input,
# optimize_parameters, log_loss, log_visual, save_networks and
# update_learning_rate.
net = create_model(opt)
# Main training loop: epochs are numbered 1 .. niter + niter_decay (inclusive).
# NOTE(review): the checkpoint and learning-rate calls are placed at epoch
# level (after the inner loop) — confirm this matches the intended schedule.
total_epochs = opt.niter + opt.niter_decay
for epoch in range(1, total_epochs + 1):
    print('Begin epoch %d' % epoch)

    for step, batch in enumerate(loader):
        # One optimisation step on the current batch.
        net.set_input(batch)
        net.optimize_parameters()

        # Periodic loss logging and visualisation, keyed on the in-epoch step
        # index; step 0 of every epoch always triggers both.
        if not step % opt.print_every:
            net.log_loss(epoch, step)
        if not step % opt.visual_every:
            net.log_visual(epoch, step)

    # End of epoch: save under the 'latest' label every time, and additionally
    # under the epoch number every `save_every` epochs.
    net.save_networks('latest')
    if not epoch % opt.save_every:
        net.save_networks(epoch)

    # Advance the learning-rate schedule once per epoch.
    net.update_learning_rate()