# test_artXrealXGAN_SIDDXnoise_flow.py
# Forked from bernardohenz/synt_noise_GANs.
#########################################################
# Imports
#########################################################
######################################
# Misc
######################################
import os
import argparse
# Single verbosity selector shared by the TensorFlow environment variable and
# the TensorFlow Python logger below: '0' maps to DEBUG, '1' to INFO,
# '2' to WARN and '3' to ERROR.
DESIRED_LOG_LEVEL = '3'
# Exported before tensorflow is imported further down (presumably so the
# native backend reads it at load time -- confirm against the TF docs).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
import warnings
# Silence FutureWarning messages for the rest of the run.
warnings.filterwarnings('ignore',category=FutureWarning)
import tensorflow.compat.v1 as tfv1
# Translate the string selector into the matching tfv1.logging level.
# NOTE: .get() returns None for a selector outside '0'-'3'.
tfv1.logging.set_verbosity({
'0': tfv1.logging.DEBUG,
'1': tfv1.logging.INFO,
'2': tfv1.logging.WARN,
'3': tfv1.logging.ERROR
}.get(DESIRED_LOG_LEVEL))
######################################
# Keras
######################################
import keras
from keras.models import Model
######################################
# Utils
######################################
from util.extendable_datagen_several_classes import ImageDataGenerator, random_transform,standardize
from util.data_aug_transforms import random90rot,export_img,random_crop,center_crop,augment_contrast,augment_brightness,compute_fft2, srgb_to_linear
import numpy as np
import matplotlib.pyplot as plt
# We'll use matplotlib for graphics.
import matplotlib.patheffects as PathEffects
import matplotlib
# We import sklearn.
import sklearn
RS = 20150104
from sklearn.manifold import TSNE
from sklearn.datasets import load_digits
from sklearn.preprocessing import scale
# We'll hack a bit with the t-SNE code in sklearn 0.15.2.
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.manifold.t_sne import (_joint_probabilities,
_kl_divergence)
from sklearn.metrics import accuracy_score
from util.conf_matrix import plot_confusion_matrix
from model_creation import create_model
#########################################################
# General Run Parameters
#########################################################
######################################
# Change according to your needs
######################################
# Command-line interface: where the test patches live and which ISO level to
# evaluate.
parser = argparse.ArgumentParser(description='Test of artificial vs real vs GAN noise.')
parser.add_argument('--test_dir', type=str, default='sample_imgs',
                    help='Folder containing image patches corrupted by clean, noisy and GAN noise.')
# The help text previously described an unrelated option; it now documents
# what the value is actually used for in this script.
parser.add_argument('--selected_iso', type=int, default=800,
                    help='ISO level to evaluate; selects the SIDD_<ISO> sub-folder '
                         '(ISO zero-padded to 4 digits) and the trained weights file.')
args = parser.parse_args()
SELECTED_ISO = args.selected_iso

# Expected layout: <test_dir>/SIDD_<ISO, 4 digits>/{clean,noisy,GAN}
iso_folder = os.path.join(args.test_dir, 'SIDD_{0:04d}'.format(SELECTED_ISO))
test_clean_dir = os.path.join(iso_folder, 'clean')
test_noisy_dir = os.path.join(iso_folder, 'noisy')
test_GAN_dir = os.path.join(iso_folder, 'GAN')
# Dataset constants shared by the generator, the classifier and the plots.
# Spatial size (two dimensions) of the patches fed to the network.
image_size = (128, 128)
# Display names of the noise classes.
# NOTE(review): presumably listed in the same order as the label indices
# produced by the data generator -- confirm before reordering.
classes = [
    'Gaussian',
    'Poisson',
    'GaussPois',
    'GaussMosaic',
    'NoiseFlow',
    'GAN_SIDD',
    'Natural',
]
# One classifier output per entry in `classes`.
num_classes = len(classes)
# Build the test-time data generator (no model is loaded at this point).
datagenTest = ImageDataGenerator()
# Crop size stored in the generator config; presumably read by the
# random_crop stage below -- confirm in util.extendable_datagen_several_classes.
datagenTest.config['random_crop_size'] = image_size
# Processing stages registered on the generator: crop, standardize, FFT.
# NOTE(review): assumed to run in the listed order -- confirm.
datagenTest.set_pipeline([random_crop,standardize,compute_fft2])
# Flow over the clean-patch folder; the noisy and GAN folders for the same
# ISO are registered on the flow object right after.
flow_test = datagenTest.flow_from_directory(test_clean_dir,batch_size=50,color_mode='rgbfft',target_size=image_size)
flow_test.setCurrentISO(SELECTED_ISO, test_noisy_dir)
flow_test.setGANdir(test_GAN_dir)
# Overrides the batch_size=50 passed to flow_from_directory above.
flow_test.batch_size = 3  # previously: batchsizes_for_isos[str(ISO_LEVEL)]
# Collect a fixed-size evaluation set by pulling small batches from the flow.
total_batch_size = 500
samples_per_batch = flow_test.batch_size
# Only whole batches are collected. Size the buffers to the number of samples
# actually written; otherwise the trailing rows would stay all-zero and the
# later argmax would silently score them as class 0.
num_batches = total_batch_size // samples_per_batch
num_samples = num_batches * samples_per_batch
# The last axis holds 6 channels (presumably RGB plus its FFT for the
# 'rgbfft' colour mode -- confirm against the generator).
x = np.zeros((num_samples, image_size[0], image_size[1], 6))
y_true = np.zeros((num_samples, num_classes))
iter_flow = iter(flow_test)
for i in range(num_batches):
    start = i * samples_per_batch
    stop = start + samples_per_batch
    # Restart the flow when the running sample count is an exact multiple of
    # the number of files (always true on the first pass, where start == 0).
    # NOTE(review): if the file count is not a multiple of the batch size this
    # condition can be skipped at an epoch boundary -- confirm intended.
    if start % len(flow_test.filenames) == 0:
        flow_test.on_epoch_end()
        iter_flow = iter(flow_test)
    x_cur, y_cur = next(iter_flow)
    x[start:stop] = x_cur
    y_true[start:stop] = y_cur
# Rebuild the classifier and restore the weights saved for the selected ISO.
weights_path = 'trained_models/{}_SIDD_several_classes_weights.h5'.format(SELECTED_ISO)
model = create_model(image_size, num_classes=num_classes)
model.load_weights(weights_path)

# Reduce the per-class rows to a single class index per sample.
y_pred = np.argmax(model.predict(x), axis=1)
y_true = np.argmax(y_true, axis=1)

# Confusion matrix (normalize=False), saved as a PDF under a relative path.
plot_confusion_matrix(
    y_true,
    y_pred,
    classes=classes,
    normalize=False,
    title='Noise Type Classifier',
)
plt.savefig('conf_matrix_{}_sidd.pdf'.format(SELECTED_ISO))
# Truncated copy of the classifier that stops at the fourth-from-last layer,
# whose output is used as the feature vector of each sample.
feature_layer = model.layers[-4]
newModel = Model(inputs=model.input, outputs=feature_layer.output)
x_feat_all = newModel.predict(x)

# Project the features with t-SNE, using a fixed seed derived from RS.
tsne = TSNE(random_state=RS+2)
x_proj = tsne.fit_transform(x_feat_all)

from util.visualization import scatter

# Upper bound on the number of points drawn; slicing simply returns all
# points when fewer are available.
nb_points = 2000
# NOTE(review): nothing later in this script reads the flow again, so this
# assignment appears to have no effect here -- kept for parity.
flow_test.batch_size = nb_points

shown_points = x_proj[:nb_points]
shown_labels = y_true[:nb_points]

# Same scatter plot saved twice: first with the helper's default legend
# setting, then with the legend explicitly disabled.
plot_variants = (
    ('{}_SIDD_w_legend_palette_muted.pdf', {}),
    ('{}_SIDD_no_legend_palette_muted.pdf', {'legend': False}),
)
for filename_template, extra_kwargs in plot_variants:
    scatter(shown_points, shown_labels, labels=classes, chosen_palette='muted', **extra_kwargs)
    plt.savefig(filename_template.format(SELECTED_ISO))