|
4 | 4 | import tensorflow as tf
|
5 | 5 | import numpy as np
|
6 | 6 | import cv2
|
| 7 | +import time |
7 | 8 |
|
8 | 9 | from glob import glob
|
9 |
| -from scipy.signal import lfilter, firwin, butter |
| 10 | +from scipy.signal import firwin, butter |
10 | 11 | from functools import partial
|
11 | 12 | from tqdm import tqdm, trange
|
12 | 13 | from subprocess import call
|
13 | 14 |
|
| 15 | +from modules import L1_loss |
14 | 16 | from modules import res_encoder, res_decoder, res_manipulator
|
15 | 17 | from modules import residual_block, conv2d
|
16 | 18 | from utils import load_train_data, mkdir, imread, save_images
|
| 19 | +from preprocessor import preprocess_image, preproc_color |
| 20 | +from data_loader import read_and_decode_3frames |
17 | 21 |
|
18 | 22 | # Change here if you use ffmpeg.
|
19 | 23 | DEFAULT_VIDEO_CONVERTER = 'avconv'
|
@@ -509,12 +513,156 @@ def run_temporal(self,
|
509 | 513 |
|
510 | 514 | # Training code.
|
def _build_training_graph(self, train_config):
    """Assemble the TensorFlow graph used for training.

    Reads 3-frame examples from ``train.tfrecords`` inside
    ``train_config["dataset_dir"]``, batches them through a shuffling
    queue, runs the image transformer on frames A and B, and builds the
    total loss plus the summaries and the checkpoint saver.

    Sets on ``self``: ``global_step``, ``loss_function``, ``output``,
    ``reg_loss``, ``loss``, ``loss_sum``, ``image_sum``,
    ``image_comp_sum``, ``image_orig_comp_sum`` and ``saver``.

    Args:
        train_config: mapping read for the keys ``dataset_dir``,
            ``num_epochs``, ``image_height``, ``image_width``,
            ``batch_size``, ``weight_decay``, ``texture_loss_weight``,
            ``shape_loss_weight`` and ``ckpt_to_keep``. It is also
            forwarded to ``preprocess_image`` and ``_loss_function``.
    """
    height = train_config["image_height"]
    width = train_config["image_width"]
    batch_size = train_config["batch_size"]

    self.global_step = tf.Variable(0, trainable=False)

    # Input pipeline: a single TFRecord file, repeated for num_epochs.
    record_path = os.path.join(train_config["dataset_dir"],
                               'train.tfrecords')
    filename_queue = tf.train.string_input_producer(
        [record_path], num_epochs=train_config["num_epochs"])
    example = read_and_decode_3frames(filename_queue,
                                      (height, width, self.n_channels))

    # Queue sizing for the shuffling batcher.
    min_after_dequeue = 1000
    num_threads = 16
    capacity = min_after_dequeue + (num_threads + 2) * batch_size

    frameA, frameB, frameC, frameAmp, amplification_factor = \
        tf.train.shuffle_batch(list(example),
                               batch_size=batch_size,
                               capacity=capacity,
                               num_threads=num_threads,
                               min_after_dequeue=min_after_dequeue)

    # Only the three input frames go through preprocess_image; the
    # amplified target frame is used as read.
    frameA = preprocess_image(frameA, train_config)
    frameB = preprocess_image(frameB, train_config)
    frameC = preprocess_image(frameC, train_config)

    self.loss_function = partial(self._loss_function,
                                 train_config=train_config)
    self.output = self.image_transformer(frameA,
                                         frameB,
                                         amplification_factor,
                                         [height, width],
                                         self.arch_config, True, False)

    # Reconstruction loss, with optional weight decay over whatever
    # regularization losses have been registered in the graph.
    self.reg_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if self.reg_loss and train_config["weight_decay"] > 0.0:
        print("Adding Regularization Weights.")
        reconstruction = self.loss_function(self.output, frameAmp)
        decay_term = train_config["weight_decay"] * \
            tf.add_n(self.reg_loss)
        total_loss = reconstruction + decay_term
    else:
        print("No Regularization Weights.")
        total_loss = self.loss_function(self.output, frameAmp)

    # Extra regularization: encode frame C with the (reused) encoder and
    # tie its texture to frame A's and its shape to frame B's.
    # TODO: Hardcoding the network name scope here.
    with tf.variable_scope('ynet_3frames/encoder', reuse=True):
        texture_c, shape_c = self._encoder(frameC)
    total_loss = total_loss + \
        train_config["texture_loss_weight"] * \
        L1_loss(texture_c, self.texture_a)
    total_loss = total_loss + \
        train_config["shape_loss_weight"] * \
        L1_loss(shape_c, self.shape_b)
    self.loss = total_loss

    # Summaries.
    self.loss_sum = tf.summary.scalar('train_loss', self.loss)
    # Input frame B and the network output joined along axis 2
    # (presumably the width axis of NHWC batches -- TODO confirm).
    side_by_side = tf.concat([frameB, self.output], axis=2)
    self.image_sum = tf.summary.image('train_B_OUT',
                                      side_by_side,
                                      max_outputs=2)

    if self.n_channels == 3:
        # Difference images against the target and against frame A.
        gt_view = frameAmp - self.output
        orig_view = frameA - self.output
    else:
        # Stack (reference, output, reference) along axis 3
        # (presumably the channel axis -- TODO confirm).
        gt_view = tf.concat([frameAmp, self.output, frameAmp], axis=3)
        orig_view = tf.concat([frameA, self.output, frameA], axis=3)
    self.image_comp_sum = tf.summary.image('train_GT_OUT',
                                           gt_view,
                                           max_outputs=2)
    self.image_orig_comp_sum = tf.summary.image('train_ORIG_OUT',
                                                orig_view,
                                                max_outputs=2)

    self.saver = tf.train.Saver(max_to_keep=train_config["ckpt_to_keep"])
513 | 595 |
|
514 | 596 | # Loss function
|
def _loss_function(self, a, b, train_config):
    """Return the L1 loss between ``a`` and ``b`` scaled by a config weight.

    Args:
        a: first operand passed to ``L1_loss``.
        b: second operand passed to ``L1_loss``.
        train_config: mapping providing ``"l1_loss_weight"``.

    Returns:
        ``L1_loss(a, b) * train_config["l1_loss_weight"]``, built inside
        the ``"loss_function"`` variable scope.
    """
    # train_config is the hook for implementing more advanced losses.
    l1_weight = train_config["l1_loss_weight"]
    with tf.variable_scope("loss_function"):
        weighted_l1 = L1_loss(a, b) * l1_weight
    return weighted_l1
517 | 601 |
|
def train(self, train_config):
    """Build the training graph and run the optimisation loop.

    The loop runs until the input queue is exhausted
    (``tf.errors.OutOfRangeError``) or the coordinator requests a stop.
    A scalar loss summary is written every step; image summaries and a
    progress line are emitted every 100 global steps.

    Args:
        train_config: mapping read here for the keys ``learning_rate``,
            ``decay_steps``, ``lr_decay``, ``beta1``, ``logs_dir``,
            ``continue_train``, ``checkpoint_dir``, ``restore_dir`` and
            ``save_freq``. It is also forwarded to
            ``_build_training_graph``.
    """
    # Define training graphs
    self._build_training_graph(train_config)

    self.lr = tf.train.exponential_decay(train_config["learning_rate"],
                                         self.global_step,
                                         train_config["decay_steps"],
                                         train_config["lr_decay"],
                                         staircase=True)
    self.optim_op = tf.train.AdamOptimizer(self.lr,
                                           beta1=train_config["beta1"]) \
        .minimize(self.loss,
                  var_list=tf.trainable_variables(),
                  global_step=self.global_step)

    ginit_op = tf.global_variables_initializer()
    linit_op = tf.local_variables_initializer()
    self.sess.run([ginit_op, linit_op])

    self.writer = tf.summary.FileWriter(train_config["logs_dir"],
                                        self.sess.graph)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=self.sess, coord=coord)

    start_time = time.time()
    for v in tf.trainable_variables():
        print(v)
    if train_config["continue_train"] and \
            self.load(train_config["checkpoint_dir"]):
        print('[*] Load Success')
    elif train_config["restore_dir"] and \
            self.load(train_config["restore_dir"],
                      tf.train.Saver(var_list=tf.trainable_variables())):
        # Weights come from another run; restart the step counter (and
        # with it the learning-rate schedule) from zero.
        self.sess.run(self.global_step.assign(0))
        print('[*] Restore success')
    else:
        print('Training from scratch.')
    try:
        # Step counter at the start of this run. It is non-zero when
        # resuming from a checkpoint, so throughput must be measured
        # relative to it rather than to the absolute global step.
        initial_step = self.sess.run(self.global_step)
        while not coord.should_stop():
            _, loss_sum_str = self.sess.run([self.optim_op, self.loss_sum])
            # Fetched in a separate run so the value is the step counter
            # after the optimiser update.
            global_step = self.sess.run(self.global_step)
            self.writer.add_summary(loss_sum_str, global_step)

            if global_step % 100 == 0:
                # Write image summary (a separate session run from the
                # optimisation step above).
                img_sum_str, img_comp_str, img_orig_str = \
                    self.sess.run([self.image_sum,
                                   self.image_comp_sum,
                                   self.image_orig_comp_sum])
                self.writer.add_summary(img_sum_str, global_step)
                self.writer.add_summary(img_comp_str, global_step)
                self.writer.add_summary(img_orig_str, global_step)

                elapsed_time = time.time() - start_time
                steps_this_run = global_step - initial_step
                print("Steps: %2d time: %4.4f (%4.4f steps/sec)" % (
                    global_step, elapsed_time,
                    float(steps_this_run) / max(elapsed_time, 1e-8)))

            # NOTE: a checkpoint is written when step % save_freq == 2,
            # so a save_freq of 1 or 2 never triggers a save.
            if np.mod(global_step, train_config["save_freq"]) == 2:
                self.save(train_config["checkpoint_dir"], global_step)

    except tf.errors.OutOfRangeError:
        print('Done Training.')
    finally:
        coord.request_stop()
        try:
            coord.join(threads)
        finally:
            # Make sure buffered summaries reach disk even if the
            # process exits right after training.
            self.writer.flush()
520 | 668 |
|
0 commit comments