diff --git a/main.py b/main.py index 0a52f26..c9f560d 100644 --- a/main.py +++ b/main.py @@ -106,6 +106,8 @@ def train(parser): print("Result") print("[Test] Loss:", loss_test, "Accuracy:", accuracy_test) + sess.close() + def get_parser(): parser = argparse.ArgumentParser( diff --git a/util/repoter.py b/util/repoter.py index 4d9a323..0d732ab 100644 --- a/util/repoter.py +++ b/util/repoter.py @@ -113,7 +113,7 @@ def get_imageset(image_in_np, image_out_np, image_tc_np, palette, index_void=Non image_result = Reporter.concat_images(image_in_pil, image_concated, None, "RGB") return image_result - def store_model(self, saver, sess): + def save_model(self, saver, sess): saver.save(sess, os.path.join(self._model_dir, self.MODEL_NAME))