-
Notifications
You must be signed in to change notification settings - Fork 4
/
test_base.py
58 lines (45 loc) · 2.11 KB
/
test_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from tqdm import tqdm
from models import GeneratorResNet
from datasets import *
# Command-line options. They are parsed at import time because the
# ``__main__`` block reads the module-level ``args``.
parser = argparse.ArgumentParser(
    description='Run the pretrained A->B and B->A generators over the test '
                'split and save the translated images.',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--workers', type=int, default=16,
                    help='number of DataLoader worker processes')
parser.add_argument('--task', type=str, default='day2dusk',
                    help='task name; used as the sub-directory under '
                         '--dataset_dir, --result_dir and --save_dir')
parser.add_argument('--dataset_dir', type=str, default='./data/',
                    help='root directory containing one dataset folder per task')
parser.add_argument('--result_dir', type=str, default='./results/base/',
                    help='root directory for outputs, written to '
                         '<result_dir>/<task>/A2B and <result_dir>/<task>/B2A')
parser.add_argument('--save_dir', type=str, default='./pretrained_models/',
                    help='root directory of pretrained weights, read from '
                         '<save_dir>/<task>/G_A2B.pth and G_B2A.pth')
parser.add_argument('--gpu', type=str, default='0',
                    help='value assigned to CUDA_VISIBLE_DEVICES')
args = parser.parse_args()
# Restrict which GPUs PyTorch can see. This must happen before the first CUDA
# call; PyTorch initializes CUDA lazily, so setting it after `import torch`
# still takes effect.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
def _load_generator(weights_path, device):
    """Build a ``GeneratorResNet`` and load pretrained weights for inference.

    Args:
        weights_path: path to a file whose ``torch.load`` result is a
            state_dict for ``GeneratorResNet``.
        device: ``torch.device`` the network and its weights are placed on.

    Returns:
        The generator on *device*, switched to eval mode.
    """
    net = GeneratorResNet().to(device)
    # map_location remaps tensors saved on any device (e.g. a GPU index that
    # is not visible under the current CUDA_VISIBLE_DEVICES, or a CUDA
    # checkpoint opened on a CPU-only machine) onto *device*.
    state_dict = torch.load(weights_path, map_location=device)
    net.load_state_dict(state_dict)
    net.eval()
    return net


def _denormalize(tensor):
    """Map a tensor from [-1, 1] back to [0, 1] for ``save_image``.

    This is the inverse of ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))``.
    It assumes the generator output range matches that [-1, 1] input
    normalization; confirm against ``GeneratorResNet``.
    """
    return 0.5 * (tensor + 1.0)


if __name__ == '__main__':
    # Fall back to CPU when CUDA is unavailable instead of crashing on .cuda().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    weights_dir = os.path.join(args.save_dir, args.task)
    netG_A2B = _load_generator(os.path.join(weights_dir, 'G_A2B.pth'), device)
    netG_B2A = _load_generator(os.path.join(weights_dir, 'G_B2A.pth'), device)

    # ToTensor yields [0, 1]; Normalize(0.5, 0.5) shifts that to [-1, 1].
    # _denormalize undoes the shift on the generated images.
    transforms_ = [transforms.ToTensor(),
                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    dataloader = DataLoader(
        ImageDataset(os.path.join(args.dataset_dir, args.task),
                     transforms_=transforms_, mode='test'),
        batch_size=1, shuffle=False, num_workers=args.workers)

    out_dir_A2B = os.path.join(args.result_dir, args.task, 'A2B')
    out_dir_B2A = os.path.join(args.result_dir, args.task, 'B2A')
    os.makedirs(out_dir_A2B, exist_ok=True)
    os.makedirs(out_dir_B2A, exist_ok=True)

    # Inference only: no autograd graph is needed anywhere in the loop.
    with torch.no_grad():
        # Iterate the DataLoader directly (not via enumerate) so tqdm can read
        # len(dataloader) and show a real progress bar.
        for batch in tqdm(dataloader):
            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)
            # batch_size=1, so index 0 is the only sample in the batch.
            filename_A = batch['filename_A'][0]
            filename_B = batch['filename_B'][0]

            fake_B = netG_A2B(real_A)
            fake_A = netG_B2A(real_B)

            save_image(_denormalize(fake_B), os.path.join(out_dir_A2B, filename_A))
            save_image(_denormalize(fake_A), os.path.join(out_dir_B2A, filename_B))