-
Notifications
You must be signed in to change notification settings - Fork 3
/
evaluate.py
71 lines (61 loc) · 2.96 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse
import os
import pathlib
import random

import numpy as np
import torch
import tqdm

import helpers
from flow_utils import *
from metrics.lpips.loss import PerceptualLoss
def _seed_everything(seed, use_cuda):
    """Seed the Python, NumPy and PyTorch RNGs (and every CUDA device when used)."""
    print("Random Seed: ", seed)
    random.seed(seed)
    # Previously unseeded: any NumPy randomness made runs non-reproducible.
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed_all(seed)


def _save_results(log_dir, gts, all_metrics, all_samples):
    """Print mean +/- standard error per metric and write all arrays to *log_dir*.

    Files written (all via ``np.savez_compressed``):
      * ``gts.npz``              -- ground-truth batches concatenated on axis 1
                                    (key ``samples``);
      * ``<metric>.npz``         -- uint8 samples concatenated on axis 0 and
                                    transposed with axes (1, 0, 4, 2, 3)
                                    (key ``samples``);
      * ``results_<metric>.npz`` -- the raw metric values (key ``arr_0``).
    """
    # parents=True so a nested --log_dir whose parent is missing still works.
    pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)
    np.savez_compressed(os.path.join(log_dir, 'gts.npz'),
                        samples=np.concatenate(gts, axis=1))
    for name, values in all_metrics.items():
        if not values:
            # eval_step never reported this metric; np.concatenate([]) would
            # raise, so report it and move on to the metrics that do exist.
            print(name, 'no values collected; skipping')
            continue
        # Mean and standard error of the mean over every collected value.
        print(name, np.mean(values), '+/-', np.std(values) / np.sqrt(len(values)))
        stacked = np.concatenate(all_samples[name], axis=0)
        # NOTE(review): the original author flagged that the sample shape had
        # changed ("check this"); verify this axis permutation still matches.
        np.savez_compressed(os.path.join(log_dir, f'{name}.npz'),
                            samples=np.transpose(stacked, (1, 0, 4, 2, 3)))
        np.savez_compressed(os.path.join(log_dir, f'results_{name}.npz'), values)


def main(opt):
    """Evaluate a checkpoint on the test set and save metrics and samples.

    For every test batch this calls ``helpers.eval_step`` and accumulates,
    per metric name ('psnr', 'ssim', 'lpips'), the metric values and the
    generated samples it returns.  At the end it prints mean +/- standard
    error for each metric and writes ground truths, samples and raw values
    as ``.npz`` files into ``opt.log_dir``.

    Args:
        opt: Parsed command-line namespace.  Read here: ``device`` (bool,
            run on ``cuda:0`` when true), ``seed``, ``n_eval`` and
            ``log_dir``.  The whole namespace is also passed to the
            ``helpers`` functions, which may read further fields.
    """
    use_cuda = bool(opt.device)
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    _seed_everything(opt.seed, use_cuda)

    models = helpers.load_model_from_checkpoint(opt, device)
    print(opt)
    _, test_loader = helpers.get_loaders(opt, True)
    # Follow the --device flag instead of hard-coding use_gpu=True, which
    # would request CUDA for LPIPS even on a CPU-only run.
    lpips_model = PerceptualLoss('lpips_weights', use_gpu=use_cuda)

    metric_names = ('psnr', 'ssim', 'lpips')
    all_metrics = {name: [] for name in metric_names}
    all_samples = {name: [] for name in metric_names}
    gts = []
    # NOTE(review): consider torch.no_grad() here if eval_step does not
    # already disable autograd.
    for test_x in tqdm.tqdm(test_loader, total=len(test_loader)):
        # Keep only the first n_eval entries along dim 0 (the CLI sets
        # n_eval = n_past + n_future); slice before moving to the device so
        # discarded entries are never transferred.
        test_x = test_x[:opt.n_eval].to(device)
        samples, metrics = helpers.eval_step(test_x, models, opt, device, lpips_model)
        for name in samples:
            all_metrics[name].extend(metrics[name].cpu().detach().numpy().tolist())
            # NOTE(review): scaling by 255 assumes samples lie in [0, 1].
            all_samples[name].append(
                (samples[name].cpu().detach().numpy() * 255).astype('uint8'))
        gts.append(test_x.cpu().detach().numpy())

    _save_results(opt.log_dir, gts, all_metrics, all_samples)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Command-line options as (flag, add_argument keyword arguments), listed
    # in the order they are registered with the parser.
    cli_options = (
        ('--batch_size', dict(type=int, default=20, help='batch size')),
        ('--data_root', dict(default='mnist_data', help='root directory for data')),
        ('--model_path', dict(default='', help='path to model')),
        ('--log_dir', dict(default='', help='directory to save generations to')),
        ('--seed', dict(type=int, default=1, help='manual seed')),
        ('--n_past', dict(type=int, default=5, help='number of frames to condition on')),
        ('--n_future', dict(type=int, default=20, help='number of frames to predict')),
        ('--data_threads', dict(type=int, default=0, help='number of data loading threads')),
        ('--nsample', dict(type=int, default=100, help='number of samples')),
        ('--device', dict(action='store_true', help='if true, use gpu')),
    )
    for flag, settings in cli_options:
        parser.add_argument(flag, **settings)

    opt = parser.parse_args()
    # Total frames evaluated per sequence: conditioning frames + predicted frames.
    opt.n_eval = opt.n_past + opt.n_future
    main(opt)