forked from ryanxingql/mfqev2.0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_extract_TrainingSet_NP.py
346 lines (263 loc) · 14.1 KB
/
main_extract_TrainingSet_NP.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
"""Extract training set.
Randomly select frame to patch.
Patches are stored in several npys.
Each npy contains several batches.
So there are n x batch_size patches in each npy.
Return: a few npy with shape (n x width_patch x width_height x 1), dtype=np.float32 \in [0,1]."""
import os, glob, gc, h5py
import numpy as np
import random, math
def y_import(video_path, height_frame, width_frame, nfs, startfrm, bar=True, opt_clear=True):
    """Import the Y channel of consecutive frames from a raw yuv video.

    Each frame is read as height_frame * width_frame Y bytes followed by two
    further planes of (height_frame // 2) * (width_frame // 2) bytes each,
    which are skipped.

    Args:
        video_path: path of the yuv file.
        height_frame: frame height in pixels.
        width_frame: frame width in pixels.
        nfs: number of frames to read.
        startfrm: index of the first frame to read, counted from 0.
        bar: if True, print a progress counter after each frame.
        opt_clear: if True, overwrite the progress line once reading is done.

    Returns:
        np.ndarray of shape (nfs, height_frame, width_frame), dtype=np.uint8.

    Raises:
        EOFError: if the file ends before a requested Y plane is complete.
    """
    # Size of one whole frame, used only to locate startfrm.
    blk_size = int(height_frame * width_frame * 3 / 2)
    y_size = height_frame * width_frame
    uv_size = 2 * (height_frame // 2) * (width_frame // 2)

    # Pre-allocating the result gives every frame its own storage. Stacking
    # views of one reused buffer would let frame 1 overwrite frame 0.
    Y = np.zeros((nfs, height_frame, width_frame), dtype=np.uint8)  # 0-255

    with open(video_path, 'rb') as fp:
        # target at startfrm
        fp.seek(blk_size * startfrm, 0)

        for ite_frame in range(nfs):
            # One bulk read per plane instead of one read per pixel.
            y_bytes = fp.read(y_size)
            if len(y_bytes) != y_size:
                raise EOFError(
                    "%s ended while reading frame %d (got %d of %d Y bytes)"
                    % (video_path, startfrm + ite_frame, len(y_bytes), y_size)
                )
            Y[ite_frame] = np.frombuffer(y_bytes, dtype=np.uint8).reshape(height_frame, width_frame)

            # Skip the two planes that follow the Y plane.
            fp.seek(uv_size, 1)

            if bar:
                print("\r%4d | %4d" % (ite_frame + 1, nfs), end="", flush=True)

    if opt_clear:
        print("\r ", end="\r")

    return Y
def _shuffle_and_save(stack, save_path, dataset_name):
    """Shuffle `stack` in place along its first axis and write it to an HDF5 file."""
    # A fresh generator with a fixed seed produces the same permutation for
    # every stack of the same length, so the pre/cmp/sub/raw patches stay
    # aligned. random.shuffle must not be used here: on a multi-dimensional
    # array its element swap works on views and duplicates patches.
    np.random.RandomState(100).shuffle(stack)
    with h5py.File(save_path, "w") as f:
        f.create_dataset(dataset_name, data=stack)


def func_PatchFrame(info_patch, num_patch, ite_npy, mode):
    """Patch and store four stacks (pre, cmp, sub, raw) with a same index.

    The patches inside the four stacks are shuffled with one common
    permutation before saving, so entry k of every stack belongs to the same
    patch position of the same frame.

    Args:
        info_patch: sequence of nine items, in this order:
            order_FirstFrame, order_FirstPatch: first frame (index into the
                *_list_list lists) and first patch inside that frame to use;
            order_LastFrame, order_LastPatch: last frame and last patch
                inside that frame to use (both inclusive);
            VideoIndex_list_list, MidIndex_list_list, PreIndex_list_list,
                SubIndex_list_list: per selected frame, the index of its
                video in list_CmpVideo and the frame numbers of the frame
                itself, its previous PQF and its subsequent PQF;
            dir_save_stack: output directory.
        num_patch: number of patches stored in each stack.
        ite_npy: index used in the output file names.
        mode: tag used in the output file names (the caller passes "tra" or
            "val").

    Reads the module-level settings height_patch, width_patch,
    num_patch_height, num_patch_width, list_CmpVideo and dir_raw.
    Video file names are parsed as <name>_<width>x<height>_<...>.yuv.

    Writes <dir_save_stack>/stack_<mode>_<pre|cmp|sub|raw>_<ite_npy>.hdf5,
    each with one float32 dataset of shape
    (num_patch, height_patch, width_patch, 1) scaled by 1 / 255.
    """
    (order_FirstFrame, order_FirstPatch, order_LastFrame, order_LastPatch,
     VideoIndex_list_list, MidIndex_list_list, PreIndex_list_list, SubIndex_list_list,
     dir_save_stack) = info_patch[:]

    ### Init stack
    stack_names = ("pre", "cmp", "sub", "raw")
    stacks = {
        name: np.zeros((num_patch, height_patch, width_patch, 1), dtype=np.float32)
        for name in stack_names
    }

    ### Extract patches
    cal_patch_total = 0
    num_frame_total = order_LastFrame - order_FirstFrame + 1

    for ite_frame, order_frame in enumerate(range(order_FirstFrame, order_LastFrame + 1)):

        print("\rframe %d | %d" % (ite_frame + 1, num_frame_total), end="")

        ### Extract basic information
        index_video = VideoIndex_list_list[order_frame]
        index_Mid = MidIndex_list_list[order_frame]
        index_Pre = PreIndex_list_list[order_frame]
        index_Sub = SubIndex_list_list[order_frame]

        cmp_path = list_CmpVideo[index_video]
        # Accept both "/" and "\\" separators (glob mixes them on Windows).
        cmp_name = os.path.basename(cmp_path.replace("\\", "/")).split(".")[0]
        raw_name = cmp_name + ".yuv"
        # join() inserts the separator when dir_raw has no trailing one.
        raw_path = os.path.join(dir_raw, raw_name)

        dims_str = raw_name.split("_")[1]
        width_frame = int(dims_str.split("x")[0])
        height_frame = int(dims_str.split("x")[1])

        ### Cal step
        # With a single patch per row/column there is no stride to compute.
        if num_patch_height > 1:
            step_height = int((height_frame - height_patch) / (num_patch_height - 1))
        else:
            step_height = 0
        if num_patch_width > 1:
            step_width = int((width_frame - width_patch) / (num_patch_width - 1))
        else:
            step_width = 0

        ### Load frames
        # The raw frame and the compressed frame share the frame number;
        # pre and sub are the neighbouring PQFs of the compressed video.
        frame_sources = {
            "raw": (raw_path, index_Mid),
            "cmp": (cmp_path, index_Mid),
            "pre": (cmp_path, index_Pre),
            "sub": (cmp_path, index_Sub),
        }
        frames = {
            name: y_import(path, height_frame, width_frame, 1, index, bar=False, opt_clear=False)[0]
            for name, (path, index) in frame_sources.items()
        }

        ### Patch
        is_first_frame = order_frame == order_FirstFrame
        is_last_frame = order_frame == order_LastFrame

        for ite_patch_height in range(num_patch_height):
            start_height = ite_patch_height * step_height
            rows = slice(start_height, start_height + height_patch)

            for ite_patch_width in range(num_patch_width):
                # Position of this patch inside the frame, in scan order.
                order_patch = ite_patch_height * num_patch_width + ite_patch_width

                # The first and last frame may be only partly patched.
                if is_first_frame and order_patch < order_FirstPatch:
                    continue
                if is_last_frame and order_patch > order_LastPatch:
                    continue

                start_width = ite_patch_width * step_width
                cols = slice(start_width, start_width + width_patch)

                for name in stack_names:
                    stacks[name][cal_patch_total, :, :, 0] = frames[name][rows, cols] / 255.0

                cal_patch_total += 1

    ### Shuffle and save
    for ite_stack, name in enumerate(stack_names):
        prefix = "\n" if ite_stack == 0 else "\r"
        print(prefix + "saving %d/4..." % (ite_stack + 1), end="")

        # pop() drops the dict's reference so the stack can be freed right
        # after it has been written.
        stack = stacks.pop(name)
        save_path = dir_save_stack + "/stack_" + mode + "_" + name + "_" + str(ite_npy) + ".hdf5"
        _shuffle_and_save(stack, save_path, "stack_" + name)
        del stack
        gc.collect()

    print("\r ", end="\r")  # clear bar
def _split_batches_per_npy(num_batch):
    """Split num_batch batches into chunks of at most max_batch_PerNpy batches."""
    num_npy_full = num_batch // max_batch_PerNpy
    num_batch_PerNpy_list = [max_batch_PerNpy] * num_npy_full
    remainder = num_batch - max_batch_PerNpy * num_npy_full
    if remainder > 0:
        num_batch_PerNpy_list.append(remainder)
    return num_batch_PerNpy_list


def _stack_npy_group(num_batch_PerNpy_list, num_batch_before, frame_lists, dir_save_stack, mode):
    """Patch and store every npy of one group.

    num_batch_before is the number of batches already consumed by earlier
    groups; frame_lists is the tuple (VideoIndex_list_list,
    MidIndex_list_list, PreIndex_list_list, SubIndex_list_list).
    """
    num_npy = len(num_batch_PerNpy_list)
    for ite_npy, num_batch in enumerate(num_batch_PerNpy_list):
        print("stacking %s npy %d / %d..." % (mode, ite_npy + 1, num_npy))

        # Cal the position of the first patch and the last patch of this npy.
        # *_patch_cal count patches from 1; order_* are 0-based.
        num_batch_done = num_batch_before + sum(num_batch_PerNpy_list[0: ite_npy])
        first_patch_cal = num_batch_done * batch_size + 1
        order_FirstFrame = math.ceil(first_patch_cal / num_patch_PerFrame) - 1
        order_FirstPatch = first_patch_cal - order_FirstFrame * num_patch_PerFrame - 1

        last_patch_cal = (num_batch_done + num_batch) * batch_size
        order_LastFrame = math.ceil(last_patch_cal / num_patch_PerFrame) - 1
        order_LastPatch = last_patch_cal - order_LastFrame * num_patch_PerFrame - 1

        # patch
        info_patch = (order_FirstFrame, order_FirstPatch, order_LastFrame, order_LastPatch) \
            + tuple(frame_lists) + (dir_save_stack,)
        func_PatchFrame(info_patch, num_patch=num_batch * batch_size, ite_npy=ite_npy, mode=mode)


def main_extract_TrainingSet():
    """Extract the training set.

    For every QP in QP_list:
      1. In each compressed video, find the pairs of neighbouring PQFs that
         have at least one non-PQF between them, keep at most
         max_NonPQF_OneVideo randomly chosen pairs, and randomly select one
         non-PQF between each kept pair.
      2. Shuffle all selected frames and keep as many as are needed for a
         whole number of batches.
      3. Split the batches into training and validation stacks according to
         ratio_training, at most max_batch_PerNpy batches per stack, and
         let func_PatchFrame patch and store them.

    Reads the module-level settings QP_list, dir_save_stack_pre,
    list_CmpVideo, num_CmpVideo, dir_PQFLabel, PQFLabel_sub,
    max_NonPQF_OneVideo, num_patch_PerFrame, batch_size, ratio_training and
    max_batch_PerNpy.
    """
    for QP in QP_list:

        ### Init dir_save_stack for this QP
        dir_save_stack = dir_save_stack_pre + str(QP)
        os.makedirs(dir_save_stack, exist_ok=True)

        ### List all randomly selected non-PQFs with their pre/sub PQFs
        # One record per selected frame: (video index, mid, pre, sub).
        records = []

        for ite_CmpVideo in range(num_CmpVideo):  # video by video

            # File name without directory and extension. Both "/" and "\\"
            # are treated as separators (glob mixes them on Windows).
            cmp_path = list_CmpVideo[ite_CmpVideo]
            cmp_name = os.path.basename(cmp_path.replace("\\", "/")).split(".")[0]

            # load PQF label
            PQFLabel_path = dir_PQFLabel + "/PQFLabel_" + cmp_name + PQFLabel_sub
            with h5py.File(PQFLabel_path, 'r') as f:
                PQF_label = f['PQF_label'][:]

            # locate PQFs
            PQFIndex_list = [i for i, label in enumerate(PQF_label) if label == 1]

            # select inconsistent pre and sub PQFs
            pairs = [
                (pre, sub)
                for pre, sub in zip(PQFIndex_list, PQFIndex_list[1:])
                if sub - pre > 1
            ]

            # randomly select maximum allowable pairs
            random.seed(666)
            random.shuffle(pairs)
            pairs = pairs[0: max_NonPQF_OneVideo]

            # randomly select one non-PQF from each pair
            for pre, sub in pairs:
                mid = pre + random.randint(1, sub - pre - 1)
                records.append((ite_CmpVideo, mid, pre, sub))

        cal_frame = len(records)
        num_patch_available = cal_frame * num_patch_PerFrame
        print("Available frames: %d - patches: %d" % (cal_frame, num_patch_available))

        ### Shuffle the numbering of all frames
        random.seed(888)
        random.shuffle(records)

        ### Cut down the num of frames
        max_patch_total = (num_patch_available // batch_size) * batch_size
        # may need one more frame to patch
        max_frame_total = math.ceil(max_patch_total / num_patch_PerFrame)
        records = records[0: max_frame_total]

        if records:
            VideoIndex_list_list, MidIndex_list_list, PreIndex_list_list, SubIndex_list_list = \
                [list(column) for column in zip(*records)]
        else:
            VideoIndex_list_list, MidIndex_list_list, PreIndex_list_list, SubIndex_list_list = \
                [], [], [], []
        frame_lists = (VideoIndex_list_list, MidIndex_list_list, PreIndex_list_list, SubIndex_list_list)

        ### Cal num of batch for each npy, including training and validation
        num_patch_val = int(int((1 - ratio_training) * max_patch_total) / batch_size) * batch_size
        # max_patch_total and num_patch_val are both multiples of batch_size
        num_patch_tra = max_patch_total - num_patch_val

        num_batch_PerNpy_list_tra = _split_batches_per_npy(int(num_patch_tra / batch_size))
        num_batch_PerNpy_list_val = _split_batches_per_npy(int(num_patch_val / batch_size))

        ### Patch and stack
        # some frames may be partly patched.
        _stack_npy_group(num_batch_PerNpy_list_tra, 0, frame_lists, dir_save_stack, "tra")
        # validation patches follow directly after all training patches
        _stack_npy_group(num_batch_PerNpy_list_val, sum(num_batch_PerNpy_list_tra),
                         frame_lists, dir_save_stack, "val")
if __name__ == '__main__':

    ### Settings
    # Patch grid per frame and patch size in pixels.
    num_patch_width = 26
    num_patch_height = 16
    height_patch = 64
    width_patch = 64

    dir_database = "G:/SCI/Database/"
    # Ends with a separator so that a file name can be appended directly.
    dir_raw = dir_database + "/train_108/raw/"
    dir_cmp = dir_database + "/train_108/HM16.5_LDP/QP37"
    dir_PQFLabel = "G:/SCI/MFQEv2.0/Database/PQF_label/ground_truth/train_108/QP37"
    # The QP is appended to this prefix to form the output directory.
    dir_save_stack_pre = "G:/SCI/MFQEv2.0/Database"
    PQFLabel_sub = "_MaxNfs_300.hdf5"

    QP_list = [37]
    batch_size = 64
    max_batch_PerNpy = 14000
    # Share of the patches stored as training stacks; the rest become
    # validation stacks. With 1.0 no validation stack is written.
    ratio_training = 1.0
    max_NonPQF_OneVideo = 20

    ### List all cmp video
    list_CmpVideo = glob.glob(dir_cmp + "/*.yuv")
    num_CmpVideo = len(list_CmpVideo)
    print(dir_cmp + "/*.yuv")

    num_patch_PerFrame = num_patch_width * num_patch_height

    main_extract_TrainingSet()