Commit f205c8a

Ronnachai Jaroensri committed: Training Code
1 parent 1a921e2

7 files changed: +389 −4 lines

configs/configspec.conf (+14)

@@ -6,13 +6,27 @@ exp_dir = string
 image_height = integer
 continue_train = boolean
 num_epochs = integer(default=100)
+# Preprocessing
+poisson_noise_n = float(default=None)
+gauss_noise_n = float(default=None)
 # IO
 dataset_dir = string(default=None)
 checkpoint_dir = string(default=None)
 logs_dir = string(default=None)
 restore_dir = string(default=None)
 save_freq = integer(default=6250)
 ckpt_to_keep = integer(default=5)
+# Loss
+l1_loss_weight = float(default=1.0)
+weight_decay = float(default=5e-4)
+texture_loss_weight = float(default=1.0)
+shape_loss_weight = float(default=1.0)
+# Learning
+decay_steps = integer(default=3000)
+batch_size = integer(default=8)
+learning_rate = float(default=0.0002)
+lr_decay = float(default=0.97)
+beta1 = float(default=0.9)
 
 [architecture]
 # TODO: Use options for network_arch instead.
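The `integer(default=...)` / `float(default=None)` syntax is configobj's configspec format, which also matches the `%(exp_name)s` interpolation in the experiment config below. A minimal sketch of validating a config against this spec, assuming the repo loads configs with configobj; the `[training]` section name is an assumption (the training code reads these keys from a `train_config` mapping), not confirmed by this diff:

    # Minimal sketch, assuming configobj is the config library in use.
    from configobj import ConfigObj
    from validate import Validator

    config = ConfigObj('configs/o3f_hmhm2_bg_qnoise_mix4_nl_n_t_ds3.conf',
                       configspec='configs/configspec.conf')
    if config.validate(Validator(), preserve_errors=True) is not True:
        raise ValueError('config does not match configspec.conf')

    train_config = config['training']     # assumed section name
    print(train_config['batch_size'])     # coerced to int; default 8 unless overridden
    print(train_config['learning_rate'])  # coerced to float; default 0.0002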

configs/o3f_hmhm2_bg_qnoise_mix4_nl_n_t_ds3.conf (+7)

@@ -8,11 +8,18 @@ exp_dir = data/training/%(exp_name)s
 image_height = 384
 continue_train = True
 # IO
+dataset_dir = /path/to/your/dataset
 checkpoint_dir = %(exp_dir)s/checkpoint
 logs_dir = %(exp_dir)s/logs
 test_dir = %(exp_dir)s/test
 save_freq = 1000
 ckpt_to_keep = 1000
+# Preprocessing
+poisson_noise_n = 0.3
+# Learning
+batch_size = 4
+lr_decay = 1.0
+learning_rate = 0.0001
 
 [architecture]
 network_arch = ynet_3frames
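The `%(...)s` references are ConfigParser-style interpolation, resolved by configobj at read time: `checkpoint_dir` expands through `exp_dir`, which expands through `exp_name`. A quick sketch, assuming configobj's default interpolation (values illustrative):

    from configobj import ConfigObj

    # ConfigObj accepts a list of config lines, handy for a quick check.
    config = ConfigObj([
        'exp_name = o3f_hmhm2_bg_qnoise_mix4_nl_n_t_ds3',
        'exp_dir = data/training/%(exp_name)s',
        'checkpoint_dir = %(exp_dir)s/checkpoint',
    ])
    print(config['checkpoint_dir'])
    # -> data/training/o3f_hmhm2_bg_qnoise_mix4_nl_n_t_ds3/checkpoint

`dataset_dir = /path/to/your/dataset` is a placeholder: point it at the directory containing the `train.tfrecords` produced by the converter below, since the training graph reads `os.path.join(train_config["dataset_dir"], 'train.tfrecords')`.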

convert_3frames_data_to_tfrecords.py (new file, +101)

import argparse
import os
import glob
import sys
import numpy as np
from tqdm import tqdm
import cv2
import tensorflow as tf
import json

FLAGS = None


def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def convert_dataset(data_dir, out_name, color=False):
    # Open a TFRecordWriter with ZLIB compression.
    filename = os.path.join(out_name)
    writeOpts = tf.python_io.TFRecordOptions(
        tf.python_io.TFRecordCompressionType.ZLIB)
    writer = tf.python_io.TFRecordWriter(filename, options=writeOpts)

    # Load each sample (frames A, B, C and the amplified ground truth)
    # and write it to the TFRecord.
    for f in tqdm(glob.glob(os.path.join(data_dir, 'frameA', '*.png'))):
        f = os.path.basename(f)
        image_a_path = os.path.join(data_dir, 'frameA', f)
        image_b_path = os.path.join(data_dir, 'frameB', f)
        image_c_path = os.path.join(data_dir, 'frameC', f)
        flow_path = os.path.join(data_dir, 'amplified', f)
        f, _ = os.path.splitext(f)
        meta_path = os.path.join(data_dir, 'meta', f + '.json')

        if color:
            flag = cv2.IMREAD_COLOR
        else:
            flag = cv2.IMREAD_GRAYSCALE
        image_a = cv2.imread(image_a_path, flags=flag).astype('uint8')
        image_b = cv2.imread(image_b_path, flags=flag).astype('uint8')
        image_c = cv2.imread(image_c_path, flags=flag).astype('uint8')
        flow = cv2.imread(flow_path, flags=flag).astype('uint8')

        if color:
            # OpenCV reads BGR; store RGB.
            image_a = cv2.cvtColor(image_a, code=cv2.COLOR_BGR2RGB)
            image_b = cv2.cvtColor(image_b, code=cv2.COLOR_BGR2RGB)
            image_c = cv2.cvtColor(image_c, code=cv2.COLOR_BGR2RGB)
            flow = cv2.cvtColor(flow, code=cv2.COLOR_BGR2RGB)

        amplification_factor = json.load(open(meta_path))['amplification_factor']

        # Pixels are stored as raw uint8 bytes; normalization to [-1, 1]
        # happens at read time in data_loader.py.
        image_a_raw = image_a.tostring()
        image_b_raw = image_b.tostring()
        image_c_raw = image_c.tostring()
        flow_raw = flow.tostring()

        example = tf.train.Example(features=tf.train.Features(feature={
            'frameA': _bytes_feature(image_a_raw),
            'frameB': _bytes_feature(image_b_raw),
            'frameC': _bytes_feature(image_c_raw),
            'amplified': _bytes_feature(flow_raw),
            'amplification_factor': _float_feature(amplification_factor),
        }))
        writer.write(example.SerializeToString())
    writer.close()


def main():
    # Convert the train split into .tfrecords format.
    convert_dataset(os.path.join(FLAGS.data_dir, 'train'),
                    os.path.join(FLAGS.out, 'train.tfrecords'),
                    FLAGS.color)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_dir',
        type=str,
        required=True,
        help='Directory that includes all .png files in the dataset'
    )
    parser.add_argument(
        '--out',
        type=str,
        required=True,
        help='Directory for output .tfrecords files'
    )
    parser.add_argument('--color', action='store_true',
                        help='Whether to store image as color.')
    FLAGS = parser.parse_args()

    # Verify arguments are valid.
    if not os.path.isdir(FLAGS.data_dir):
        raise ValueError('data_dir must exist and be a directory')
    if not os.path.isdir(FLAGS.out):
        raise ValueError('out must exist and be a directory')
    main()
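The layout expected under `--data_dir` is `train/frameA`, `train/frameB`, `train/frameC`, `train/amplified`, and `train/meta` (one `.json` per frame triplet). Because the records are ZLIB-compressed, they must be read back with matching options; a quick sanity-check sketch (TF 1.x, output path illustrative):

    import tensorflow as tf

    # Read back one example from the generated file; the options must match
    # the ZLIB compression used by the converter above.
    opts = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
    for record in tf.python_io.tf_record_iterator('out/train.tfrecords',
                                                  options=opts):
        example = tf.train.Example.FromString(record)
        feats = example.features.feature
        print(feats['amplification_factor'].float_list.value[0])
        print(len(feats['frameA'].bytes_list.value[0]))  # H * W * channels bytes
        break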

data_loader.py (new file, +64)

import tensorflow as tf


def read_and_decode(filename_queue, im_size=(512, 512, 1)):
    # Records are ZLIB-compressed; reader options must match the writer's.
    readOpts = tf.python_io.TFRecordOptions(
        tf.python_io.TFRecordCompressionType.ZLIB)
    reader = tf.TFRecordReader(options=readOpts)
    _, single_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        single_example,
        features={
            'frameA': tf.FixedLenFeature([], tf.string),
            'frameB': tf.FixedLenFeature([], tf.string),
            'amplified': tf.FixedLenFeature([], tf.string),
            'amplification_factor': tf.FixedLenFeature([], tf.float32),
        })
    frameA = tf.decode_raw(features['frameA'], tf.uint8)
    frameB = tf.decode_raw(features['frameB'], tf.uint8)
    frameAmp = tf.decode_raw(features['amplified'], tf.uint8)
    amplification_factor = tf.cast(features['amplification_factor'], tf.float32)

    frameA = tf.reshape(frameA, im_size)
    frameB = tf.reshape(frameB, im_size)
    frameAmp = tf.reshape(frameAmp, im_size)

    # Normalize from [0, 255] to [-1, +1].
    frameA = tf.to_float(frameA) / 127.5 - 1.0
    frameB = tf.to_float(frameB) / 127.5 - 1.0
    frameAmp = tf.to_float(frameAmp) / 127.5 - 1.0

    return frameA, frameB, frameAmp, amplification_factor


def read_and_decode_3frames(filename_queue, im_size=(512, 512, 1)):
    # Same as read_and_decode, but also decodes the third frame C, which
    # the training code uses for the texture/shape consistency losses.
    readOpts = tf.python_io.TFRecordOptions(
        tf.python_io.TFRecordCompressionType.ZLIB)
    reader = tf.TFRecordReader(options=readOpts)
    _, single_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        single_example,
        features={
            'frameA': tf.FixedLenFeature([], tf.string),
            'frameB': tf.FixedLenFeature([], tf.string),
            'frameC': tf.FixedLenFeature([], tf.string),
            'amplified': tf.FixedLenFeature([], tf.string),
            'amplification_factor': tf.FixedLenFeature([], tf.float32),
        })
    frameA = tf.decode_raw(features['frameA'], tf.uint8)
    frameB = tf.decode_raw(features['frameB'], tf.uint8)
    frameC = tf.decode_raw(features['frameC'], tf.uint8)
    frameAmp = tf.decode_raw(features['amplified'], tf.uint8)
    amplification_factor = tf.cast(features['amplification_factor'], tf.float32)

    frameA = tf.reshape(frameA, im_size)
    frameB = tf.reshape(frameB, im_size)
    frameC = tf.reshape(frameC, im_size)
    frameAmp = tf.reshape(frameAmp, im_size)

    # Normalize from [0, 255] to [-1, +1].
    frameA = tf.to_float(frameA) / 127.5 - 1.0
    frameB = tf.to_float(frameB) / 127.5 - 1.0
    frameC = tf.to_float(frameC) / 127.5 - 1.0
    frameAmp = tf.to_float(frameAmp) / 127.5 - 1.0

    return frameA, frameB, frameC, frameAmp, amplification_factor
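Both readers map uint8 pixels to [-1, +1] via `x / 127.5 - 1.0`, so anything that saves network output as an image must invert that mapping first. A small sketch (the helper name is illustrative, not from the repo):

    import numpy as np

    def denormalize_to_uint8(frame):
        # Inverse of the loader's x / 127.5 - 1.0: [-1, 1] -> [0, 255].
        return np.clip((frame + 1.0) * 127.5, 0.0, 255.0).astype('uint8')

    # e.g. cv2.imwrite('amplified.png', denormalize_to_uint8(output[0]))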

magnet.py (+152 −4)

@@ -4,16 +4,20 @@
 import tensorflow as tf
 import numpy as np
 import cv2
+import time
 
 from glob import glob
-from scipy.signal import lfilter, firwin, butter
+from scipy.signal import firwin, butter
 from functools import partial
 from tqdm import tqdm, trange
 from subprocess import call
 
+from modules import L1_loss
 from modules import res_encoder, res_decoder, res_manipulator
 from modules import residual_block, conv2d
 from utils import load_train_data, mkdir, imread, save_images
+from preprocessor import preprocess_image, preproc_color
+from data_loader import read_and_decode_3frames
 
 # Change here if you use ffmpeg.
 DEFAULT_VIDEO_CONVERTER = 'avconv'
@@ -509,12 +513,156 @@ def run_temporal(self,
 
     # Training code.
     def _build_training_graph(self, train_config):
-        raise NotImplementedError()
+        self.global_step = tf.Variable(0, trainable=False)
+        filename_queue = tf.train.string_input_producer(
+            [os.path.join(train_config["dataset_dir"], 'train.tfrecords')],
+            num_epochs=train_config["num_epochs"])
+        frameA, frameB, frameC, frameAmp, amplification_factor = \
+            read_and_decode_3frames(filename_queue,
+                                    (train_config["image_height"],
+                                     train_config["image_width"],
+                                     self.n_channels))
+        min_after_dequeue = 1000
+        num_threads = 16
+        capacity = min_after_dequeue + \
+            (num_threads + 2) * train_config["batch_size"]
+
+        frameA, frameB, frameC, frameAmp, amplification_factor = \
+            tf.train.shuffle_batch(
+                [frameA, frameB, frameC, frameAmp, amplification_factor],
+                batch_size=train_config["batch_size"],
+                capacity=capacity,
+                num_threads=num_threads,
+                min_after_dequeue=min_after_dequeue)
+
+        frameA = preprocess_image(frameA, train_config)
+        frameB = preprocess_image(frameB, train_config)
+        frameC = preprocess_image(frameC, train_config)
+        self.loss_function = partial(self._loss_function,
+                                     train_config=train_config)
+        self.output = self.image_transformer(frameA,
+                                             frameB,
+                                             amplification_factor,
+                                             [train_config["image_height"],
+                                              train_config["image_width"]],
+                                             self.arch_config, True, False)
+        self.reg_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+        if self.reg_loss and train_config["weight_decay"] > 0.0:
+            print("Adding Regularization Weights.")
+            self.loss = self.loss_function(self.output, frameAmp) + \
+                train_config["weight_decay"] * tf.add_n(self.reg_loss)
+        else:
+            print("No Regularization Weights.")
+            self.loss = self.loss_function(self.output, frameAmp)
+        # Add texture and shape consistency losses on the third frame.
+        # TODO: Hardcoding the network name scope here.
+        with tf.variable_scope('ynet_3frames/encoder', reuse=True):
+            texture_c, shape_c = self._encoder(frameC)
+        self.loss = self.loss + \
+            train_config["texture_loss_weight"] * L1_loss(texture_c, self.texture_a) + \
+            train_config["shape_loss_weight"] * L1_loss(shape_c, self.shape_b)
+
+        self.loss_sum = tf.summary.scalar('train_loss', self.loss)
+        self.image_sum = tf.summary.image('train_B_OUT',
+                                          tf.concat([frameB, self.output],
+                                                    axis=2),
+                                          max_outputs=2)
+        if self.n_channels == 3:
+            self.image_comp_sum = tf.summary.image('train_GT_OUT',
+                                                   frameAmp - self.output,
+                                                   max_outputs=2)
+            self.image_orig_comp_sum = tf.summary.image('train_ORIG_OUT',
+                                                        frameA - self.output,
+                                                        max_outputs=2)
+        else:
+            self.image_comp_sum = tf.summary.image(
+                'train_GT_OUT',
+                tf.concat([frameAmp, self.output, frameAmp], axis=3),
+                max_outputs=2)
+            self.image_orig_comp_sum = tf.summary.image(
+                'train_ORIG_OUT',
+                tf.concat([frameA, self.output, frameA], axis=3),
+                max_outputs=2)
+        self.saver = tf.train.Saver(max_to_keep=train_config["ckpt_to_keep"])
 
     # Loss function
     def _loss_function(self, a, b, train_config):
-        raise NotImplementedError()
+        # Use train_config to implement more advanced losses.
+        with tf.variable_scope("loss_function"):
+            return L1_loss(a, b) * train_config["l1_loss_weight"]
 
     def train(self, train_config):
-        raise NotImplementedError()
+        # Define the training graph.
+        self._build_training_graph(train_config)
+
+        self.lr = tf.train.exponential_decay(train_config["learning_rate"],
+                                             self.global_step,
+                                             train_config["decay_steps"],
+                                             train_config["lr_decay"],
+                                             staircase=True)
+        self.optim_op = tf.train.AdamOptimizer(self.lr,
+                                               beta1=train_config["beta1"]) \
+            .minimize(self.loss,
+                      var_list=tf.trainable_variables(),
+                      global_step=self.global_step)
+
+        ginit_op = tf.global_variables_initializer()
+        linit_op = tf.local_variables_initializer()
+        self.sess.run([ginit_op, linit_op])
+
+        self.writer = tf.summary.FileWriter(train_config["logs_dir"],
+                                            self.sess.graph)
+        coord = tf.train.Coordinator()
+        threads = tf.train.start_queue_runners(sess=self.sess, coord=coord)
+
+        start_time = time.time()
+        for v in tf.trainable_variables():
+            print(v)
+        if train_config["continue_train"] and \
+                self.load(train_config["checkpoint_dir"]):
+            print('[*] Load Success')
+        elif train_config["restore_dir"] and \
+                self.load(train_config["restore_dir"],
+                          tf.train.Saver(var_list=tf.trainable_variables())):
+            self.sess.run(self.global_step.assign(0))
+            print('[*] Restore success')
+        else:
+            print('Training from scratch.')
+        try:
+            while not coord.should_stop():
+                _, loss_sum_str = self.sess.run([self.optim_op, self.loss_sum])
+                global_step = self.sess.run(self.global_step)
+                self.writer.add_summary(loss_sum_str, global_step)
+
+                if global_step % 100 == 0:
+                    # Write image summaries.
+                    img_sum_str, img_comp_str, img_orig_str = \
+                        self.sess.run([self.image_sum,
+                                       self.image_comp_sum,
+                                       self.image_orig_comp_sum])
+                    self.writer.add_summary(img_sum_str, global_step)
+                    self.writer.add_summary(img_comp_str, global_step)
+                    self.writer.add_summary(img_orig_str, global_step)
+
+                elapsed_time = time.time() - start_time
+                print("Steps: %2d time: %4.4f (%4.4f steps/sec)" % (
+                    global_step, elapsed_time,
+                    float(global_step) / elapsed_time))
+
+                if np.mod(global_step, train_config["save_freq"]) == 2:
+                    self.save(train_config["checkpoint_dir"], global_step)
+
+        except tf.errors.OutOfRangeError:
+            print('Done Training.')
+        finally:
+            coord.request_stop()
+            coord.join(threads)
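train() uses a staircase schedule: the effective rate is `learning_rate * lr_decay ** floor(global_step / decay_steps)`. A quick worked check with the configspec defaults (note the experiment config above sets `lr_decay = 1.0`, which makes the rate constant):

    # Worked check of tf.train.exponential_decay with staircase=True,
    # using the configspec defaults.
    def staircase_lr(step, base_lr=2e-4, decay=0.97, decay_steps=3000):
        return base_lr * decay ** (step // decay_steps)

    print(staircase_lr(0))      # 0.0002
    print(staircase_lr(3000))   # 0.000194
    print(staircase_lr(30000))  # ~0.0001475 (0.0002 * 0.97**10)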
