From a4edc352ed2345e8301e5068836b62d64c3360dc Mon Sep 17 00:00:00 2001 From: tks10 Date: Thu, 6 Dec 2018 19:00:08 +0900 Subject: [PATCH] feat: saving model --- util/repoter.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/util/repoter.py b/util/repoter.py index 42db8e3..db8d95d 100644 --- a/util/repoter.py +++ b/util/repoter.py @@ -10,9 +10,11 @@ class Reporter: IMAGE_DIR = "image" LEARNING_DIR = "learning" INFO_DIR = "info" + MODEL_DIR = "model" PARAMETER = "parameter.txt" IMAGE_PREFIX = "epoch_" IMAGE_EXTENSION = ".png" + MODEL_NAME = "model.ckpt" def __init__(self, result_dir=None, parser=None): if result_dir is None: @@ -24,6 +26,7 @@ def __init__(self, result_dir=None, parser=None): self._image_test_dir = os.path.join(self._image_dir, "test") self._learning_dir = os.path.join(self._result_dir, self.LEARNING_DIR) self._info_dir = os.path.join(self._result_dir, self.INFO_DIR) + self._model_dir = os.path.join(self._result_dir, self.MODEL_DIR) self._parameter = os.path.join(self._info_dir, self.PARAMETER) self.create_dirs() @@ -108,6 +111,9 @@ def get_imageset(image_in_np, image_out_np, image_tc_np, palette, index_void=Non image_result = Reporter.concat_images(image_in_pil, image_concated, None, "RGB") return image_result + def store_model(self, saver, sess): + saver.save(sess, os.path.join(self._model_dir, self.MODEL_NAME)) + class MatPlotManager: def __init__(self, root_dir):