-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathmain_loop.py
60 lines (42 loc) · 1.83 KB
/
main_loop.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from comet_ml import Experiment
import numpy as np
import os
import math
import os.path
import tensorflow as tf
import logging
logger = logging.getLogger(__name__)
from inception_score import prefetch_inception_model
from input import train_input_fn, eval_input_fn, predict_input_fn
from utils import *
from args import *
def run_main_loop(args, train_estimator, predict_estimator):
    """Train for ``args.epochs`` epochs, generating and saving predictions periodically.

    Training runs in rounds of ``args.predict_every`` epochs (the final round
    is shortened so that exactly ``args.epochs`` epochs are trained in total).
    After each round, ``predict_estimator`` generates predictions which are
    passed to ``save_predictions`` together with the ``eval.txt`` file handle
    (opened in append mode under the suffixed result folder), the epoch at the
    start of the round, the cumulative training step count and, when
    ``args.use_comet`` is set, the Comet ML experiment (otherwise ``None``).

    Args:
        args: Run configuration. Fields read here: ``train_examples``,
            ``_batch_size``, ``epochs``, ``predict_every``, ``use_comet``,
            ``tag`` and ``result_dir``; with Comet enabled the whole namespace
            is also logged via ``vars(args)``.
        train_estimator: Object exposing ``train(input_fn=..., steps=...)``
            (presumably a ``tf.estimator.Estimator`` -- confirm with caller).
        predict_estimator: Object exposing ``predict(input_fn=...)``.

    Raises:
        ValueError: if ``args.predict_every`` is less than 1.
    """
    # Validate before any side effects (Comet experiment, eval file, model
    # prefetch). A step of 0 would otherwise fail inside range() and a
    # negative step would silently train nothing.
    if args.predict_every < 1:
        raise ValueError(
            f"args.predict_every must be >= 1, got {args.predict_every}")

    # Steps needed for one pass over the training set (last batch may be partial).
    train_steps = math.ceil(args.train_examples / args._batch_size)
    total_steps = 0

    if args.use_comet:
        experiment = Experiment(
            api_key=comet_ml_api_key,
            project_name=comet_ml_project,
            workspace=comet_ml_workspace,
        )
        experiment.log_parameters(vars(args))
        experiment.add_tags(args.tag)
        experiment.set_name(model_name(args))
    else:
        experiment = None

    # NOTE(review): presumably loads/caches the Inception model used for
    # scoring before training starts -- confirm in inception_score.
    prefetch_inception_model()

    # Computed once; it is used both for the eval file and for every
    # save_predictions call.
    result_dir = suffixed_folder(args, args.result_dir)

    with tf.gfile.Open(os.path.join(result_dir, "eval.txt"), "a") as eval_file:
        for epoch in range(0, args.epochs, args.predict_every):
            # Clamp the final round so training stops at exactly args.epochs
            # epochs even when epochs is not a multiple of predict_every.
            round_epochs = min(args.predict_every, args.epochs - epoch)
            round_steps = train_steps * round_epochs

            logger.info("Training epoch %s", epoch)
            train_estimator.train(input_fn=train_input_fn, steps=round_steps)
            total_steps += round_steps

            if experiment is not None:
                experiment.set_step(epoch)

            # Evaluation is currently disabled. To re-enable, restore:
            # eval_steps = math.ceil(args.eval_examples / args._batch_size)
            # logger.info(f"Evaluate {epoch}")
            # evaluation = predict_estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
            # logger.info(evaluation)
            # save_evaluation(args, eval_file, evaluation, epoch, total_steps)
            # if args.use_comet:
            #     experiment.log_metrics(evaluation)

            logger.info("Generate predictions %s", epoch)
            predictions = predict_estimator.predict(input_fn=predict_input_fn)

            logger.info("Save predictions")
            save_predictions(args, result_dir, eval_file, predictions,
                             epoch, total_steps, experiment)

    logger.info("Completed %s epochs", args.epochs)