-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy patheval.py
126 lines (116 loc) · 4.28 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import sys
import torch
import torch.nn as nn
import torch.nn.functional as f
import os
from PIL import Image
import numpy as np
from tqdm import tqdm
import torchvision.transforms as T
import cv2
import time
import argparse
from scipy.stats import gaussian_kde
from sklearn.neighbors import KernelDensity
from skimage.restoration import denoise_wavelet
from xgboost import XGBRegressor
# Command-line interface. Both directories are mandatory: without them the
# script would only fail much later (path concatenation with None) with an
# error that does not point at the missing flag.
parser = argparse.ArgumentParser()
parser.add_argument('-o', '--output', required=True, help='Directory of output images')
parser.add_argument('-i', '--input', required=True, help='Directory of input images')
args = parser.parse_args()
path_to_test = args.input
path_to_result = args.output

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print('GPU Found !')
print('Loading trained models ...')
print()
class Model(nn.Module):
    """Patch-wise transformer encoder followed by a small convolutional head.

    The input image batch is cut into non-overlapping 10x10 patches, each
    patch is flattened (channels * 10 * 10 values) and the sequence of patches
    is passed through a 2-layer transformer encoder. The sequence is then
    folded back into an image, refined by two 3x3 convolutions and squashed
    with a sigmoid.

    The encoder width is fixed at 300 = 3 * 10 * 10, so inputs must have
    3 channels, and height and width must be multiples of 10.
    """

    def __init__(self):
        super().__init__()
        # NOTE: `layer` is assigned to `self`, so its parameters are registered
        # on this module in addition to the copies held by `transformer`.
        # The checkpoint loaded elsewhere presumably contains these keys; keep
        # the attribute names unchanged so `load_state_dict` still matches.
        self.layer = nn.TransformerEncoderLayer(
            d_model=300, nhead=10, dim_feedforward=512, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(self.layer, num_layers=2)
        self.extra = nn.Sequential(
            nn.Conv2d(3, 128, (3, 3), padding='same'),
            nn.GELU(),
            nn.Conv2d(128, 3, (3, 3), padding='same'),
        )

    def forward(self, x):
        """Map a (B, 3, H, W) batch to a (B, 3, H, W) batch with values in (0, 1).

        The output geometry is taken from the input tensor rather than being
        hard-coded, so any H and W divisible by the patch size are accepted.
        """
        _, channels, height, width = x.shape
        x = self.img_to_patch(x, 10)
        x = self.transformer(x)
        x = self.patch_to_img(x, 10, channels, height, width)
        x = self.extra(x)
        return torch.sigmoid(x)

    def img_to_patch(self, x, patch_size):
        """Split (B, C, H, W) into a patch sequence of shape (B, N, C * p * p).

        N = (H // p) * (W // p); patches are ordered row by row. H and W must
        be divisible by `patch_size`, otherwise the reshape raises.
        """
        B, C, H, W = x.shape
        x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
        # Bring the two patch-grid axes to the front of the per-sample layout:
        # (B, H/p, W/p, C, p, p).
        x = x.permute(0, 2, 4, 1, 3, 5)
        x = x.flatten(1, 2)  # (B, N, C, p, p)
        x = x.flatten(2, 4)  # (B, N, C * p * p)
        return x

    def patch_to_img(self, x, patch_size, C, H, W):
        """Inverse of `img_to_patch`: fold (B, N, C * p * p) back to (B, C, H, W)."""
        x = x.reshape(-1, H // patch_size, W // patch_size, C, patch_size, patch_size)
        # Interleave grid and in-patch axes again: (B, C, H/p, p, W/p, p).
        x = x.permute(0, 3, 1, 4, 2, 5)
        x = x.reshape(-1, C, H, W)
        return x
# Build the network on the selected device and restore its trained weights.
model = Model().to(device)
checkpoint = torch.load(
    "models/transformer_conv_transform_new_input.pt",
    map_location=device,
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference only: switches the module out of training mode

# Gradient-boosted regressor used during preprocessing.
# NOTE(review): the constructor hyper-parameters are presumably superseded by
# whatever the saved model file contains - confirm against the training code.
model_xgb = XGBRegressor(n_estimators=200, max_depth=10, learning_rate=0.2)
model_xgb.load_model('models/my_xgb_model.model')

# Each status line is followed by an empty line.
for status in ('Loaded trained models ...', 'Importing images ...'):
    print(status)
    print()
# Load every .png in the input directory (sorted by file name) into one
# float tensor `low` of shape (N, 3, H, W) scaled to [0, 1].
low = []
for file_name in sorted(os.listdir(path_to_test)):
    if not file_name.endswith('.png'):
        continue
    # The context manager releases the file handle as soon as the pixels are
    # copied. convert('RGB') guarantees exactly three channels, which the
    # per-channel indexing further down relies on (RGBA or greyscale PNGs
    # would otherwise misalign or crash).
    with Image.open(os.path.join(path_to_test, file_name)) as picture:
        low.append(torch.from_numpy(np.array(picture.convert('RGB'))))
    print(f'Count : {len(low)}', end='\r')

if not low:
    # torch.stack on an empty list fails with an unhelpful message.
    sys.exit(f'No .png images found in {path_to_test}')

# (N, H, W, 3) uint8 -> (N, 3, H, W) float in [0, 1].
low = torch.stack(low).permute(0, 3, 1, 2) / 255
def hist_from_quant(quant, c=0.75):
    """Estimate a density curve on [0, 1] from a set of sample values.

    A Gaussian kernel density estimate is fitted to `quant` and evaluated on
    255 evenly spaced points between 0 and 1.

    Args:
        quant: 1-D array-like of sample values (needs at least two distinct
            values, otherwise the kernel estimate is degenerate).
        c: Multiplier applied to the standard deviation of `quant`; the
            product is passed to `gaussian_kde` as its scalar bandwidth
            factor.

    Returns:
        numpy array of shape (255,) with the (unnormalised) density values.
    """
    quant = np.asarray(quant)  # also accept plain lists/tuples
    grid = np.linspace(0, 1, 255)
    kde = gaussian_kde(quant, bw_method=c * quant.std())
    return kde(grid)
def convert(img, hist):
    """Remap the intensities of `img` so its histogram follows `hist`.

    Histogram matching via cumulative distributions: the 256-bin cumulative
    histogram of `img` (values expected in [0, 1]) is mapped through the
    inverse of the cumulative target histogram, and the resulting lookup
    table is applied to every pixel by linear interpolation.

    Args:
        img: Array-like of intensities in [0, 1], any shape.
        hist: 1-D array-like of non-negative target histogram/density values
            sampled on an even grid over [0, 1]. It is normalised here, so it
            does not have to sum to 1. Any length is accepted (the callers in
            this file pass 255 points).

    Returns:
        numpy array with the same shape as `img` holding the remapped values.
    """
    img = np.asarray(img)

    input_hist, _ = np.histogram(img, bins=256, range=(0, 1))
    input_cumsum = np.cumsum(input_hist / np.sum(input_hist))

    desired_hist = np.asarray(hist, dtype=float)
    desired_cumsum = np.cumsum(desired_hist / np.sum(desired_hist))

    # Intensity level associated with each target sample; sized from `hist`
    # instead of a fixed 255 so other resolutions do not raise.
    target_levels = np.linspace(0, 1, len(desired_cumsum))

    # Lookup table: for each of the 256 source levels, the target level whose
    # cumulative probability equals the source's cumulative probability.
    mapping_func = np.interp(input_cumsum, desired_cumsum, target_levels)
    return np.interp(img, np.linspace(0, 1, 256), mapping_func)
# ---- Stage 1: histogram-based preprocessing --------------------------------
print('Preprocessing Images ...')
start = time.time()

# One 256-bin histogram per colour channel, in image-major order, so rows
# 3*k, 3*k+1 and 3*k+2 belong to the three channels of image k.
# NOTE: the 400x600 frame size is fixed here, as in the original script.
channel_hists = np.array([
    np.histogram(channel, bins=256, range=(0, 1))[0]
    for channel in low.reshape(-1, 400, 600)
])
out = model_xgb.predict(channel_hists)

# For every channel: turn the regressor output into a target histogram and
# match the channel's intensities to it.
new_input = []
for idx, image in enumerate(low):
    new_input.append([
        convert(image[ch], hist_from_quant(out[3 * idx + ch]))
        for ch in range(3)
    ])
new_input = torch.from_numpy(np.array(new_input)).float()
print(f'Time taken : {round(time.time() - start, 3)}')
print()

# ---- Stage 2: run the network and save the results -------------------------
print('Generating new images ...')
start = time.time()

# cv2.imwrite only returns False when the target directory is missing, so
# make sure it exists before writing anything.
os.makedirs(path_to_result, exist_ok=True)

with torch.no_grad():
    # Bring the output back to host memory: Tensor.numpy() is not available
    # for CUDA tensors.
    result = model(new_input.to(device)).cpu()

for idx, image in enumerate(result):
    # (3, H, W) in [0, 1] -> (H, W, 3) in [0, 255]; OpenCV expects BGR order.
    rgb = np.ascontiguousarray(image.permute(1, 2, 0).numpy() * 255)
    target = os.path.join(path_to_result, f'result{idx}.png')
    if not cv2.imwrite(target, cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)):
        raise OSError(f'Could not write {target}')

print(f'Time taken : {round(time.time() - start, 3)}')
print(f'Generated images saved to {path_to_result}')