forked from mshunshin/SegNetCMR
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
127 lines (80 loc) · 4.24 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!data/anaconda510/bin/python
import os
import tensorflow as tf
import tfmodel
# ---------------------------------------------------------------------------
# Directory / file name components
# ---------------------------------------------------------------------------
DATA_NAME = "Data"
OUTPUT_NAME = "Output"
TRAIN_SOURCE = "Train"
TEST_SOURCE = "Test"
RUN_NAME = "SELU_Run03"
CHECKPOINT_FN = "model.ckpt"

# Every path below is resolved against the directory the script is started in.
WORKING_DIR = os.getcwd()

# Input data lives in <cwd>/<DATA_NAME>/<TRAIN_SOURCE|TEST_SOURCE>.
TRAIN_DATA_DIR, TEST_DATA_DIR = (
    os.path.join(WORKING_DIR, DATA_NAME, split)
    for split in (TRAIN_SOURCE, TEST_SOURCE)
)

# Everything produced by a run goes under <cwd>/<OUTPUT_NAME>/<RUN_NAME>.
ROOT_LOG_DIR = os.path.join(WORKING_DIR, OUTPUT_NAME)
LOG_DIR = os.path.join(ROOT_LOG_DIR, RUN_NAME)

# Checkpoint path prefix; echoed when the module is loaded.
CHECKPOINT_FL = os.path.join(LOG_DIR, CHECKPOINT_FN)
print(CHECKPOINT_FL)

# One summary-writer directory per split, inside the run directory.
TRAIN_WRITER_DIR, TEST_WRITER_DIR = (
    os.path.join(LOG_DIR, split)
    for split in (TRAIN_SOURCE, TEST_SOURCE)
)

# ---------------------------------------------------------------------------
# Training hyper-parameters and logging cadence
# ---------------------------------------------------------------------------
NUM_EPOCHS = 10  # defined here; not referenced by the visible training loop
MAX_STEP = 2500  # the loop stops once global_step reaches this value
BATCH_SIZE = 6
LEARNING_RATE = 1e-4
SAVE_RESULTS_INTERVAL = 5  # steps between summary write-outs
SAVE_CHECKPOINT_INTERVAL = 100  # steps between periodic checkpoints
def main():
    """Build the model graph and run the training loop.

    The graph (placeholders, inference, loss, optimiser step, accuracy and
    merged summaries) is built with the ``tfmodel`` helpers.  A session is
    then prepared via ``tf.train.SessionManager``: it is given ``LOG_DIR`` as
    the checkpoint directory and ``init`` as the fallback initialiser, so a
    previous run can be resumed.

    Loop behaviour, driven by the module-level constants:

    * one optimiser step per iteration on a batch from the training set;
    * when ``(global_step + 1)`` is a multiple of ``SAVE_RESULTS_INTERVAL``,
      summaries are written for that training step and for one batch from
      the test set (no optimiser step is run on the test batch);
    * when ``global_step`` is a multiple of ``SAVE_CHECKPOINT_INTERVAL``,
      a checkpoint is saved;
    * the loop ends once ``global_step`` reaches ``MAX_STEP``.

    Exceptions raised inside the loop are printed rather than re-raised
    (best-effort behaviour), and a final checkpoint is always written on the
    way out, including on ``KeyboardInterrupt``.
    """
    train_data = tfmodel.GetData(TRAIN_DATA_DIR)
    test_data = tfmodel.GetData(TEST_DATA_DIR)

    g = tf.Graph()
    with g.as_default():
        images, labels = tfmodel.placeholder_inputs(batch_size=BATCH_SIZE)
        logits, _softmax_logits = tfmodel.inference(images, class_inc_bg=2)

        # Per the original author's note: defined in helper.py and used to
        # show images in TensorBoard -- TODO confirm against tfmodel.
        img = tfmodel.add_output_images(images=images, logits=logits, labels=labels)
        print(f"what is the img outtt type{type(img)}")

        loss = tfmodel.loss_calc(logits=logits, labels=labels)
        global_step = tf.Variable(0, name='global_step', trainable=False)
        # Use the module-level constant so changing LEARNING_RATE takes effect.
        train_op = tfmodel.training(loss=loss, learning_rate=LEARNING_RATE, global_step=global_step)
        accuracy = tfmodel.evaluation(logits=logits, labels=labels)

        summary = tf.summary.merge_all()
        init = tf.global_variables_initializer()
        saver = tf.train.Saver(tf.global_variables())
        sm = tf.train.SessionManager(graph=g)

        with sm.prepare_session("", init_op=init, saver=saver, checkpoint_dir=LOG_DIR) as sess:
            sess.run(tf.local_variables_initializer())

            train_writer = tf.summary.FileWriter(TRAIN_WRITER_DIR, sess.graph)
            test_writer = tf.summary.FileWriter(TEST_WRITER_DIR)

            global_step_value, = sess.run([global_step])
            print("Last trained iteration was: ", global_step_value)

            try:
                while True:
                    if global_step_value >= MAX_STEP:
                        print(f"Reached MAX_STEP: {MAX_STEP} at step: {global_step_value}")
                        break

                    # Optimiser steps must only ever see training data; the
                    # test set is reserved for the evaluation pass below.
                    images_batch, labels_batch = train_data.next_batch(BATCH_SIZE)
                    feed_dict = {images: images_batch, labels: labels_batch}

                    if (global_step_value + 1) % SAVE_RESULTS_INTERVAL == 0:
                        # Training step that also records summaries.
                        _, loss_value, accuracy_value, global_step_value, summary_str = sess.run(
                            [train_op, loss, accuracy, global_step, summary], feed_dict=feed_dict)
                        train_writer.add_summary(summary_str, global_step=global_step_value)
                        print(f"TRAIN Step: {global_step_value}\tLoss: {loss_value}\tAccuracy: {accuracy_value}")

                        # Evaluation pass on held-out data: train_op is not
                        # in the fetch list, so no weights are updated here.
                        images_batch, labels_batch = test_data.next_batch(BATCH_SIZE)
                        feed_dict = {images: images_batch, labels: labels_batch}
                        loss_value, accuracy_value, global_step_value, summary_str = sess.run(
                            [loss, accuracy, global_step, summary], feed_dict=feed_dict)
                        test_writer.add_summary(summary_str, global_step=global_step_value)
                        print(f"TEST Step: {global_step_value}\tLoss: {loss_value}\tAccuracy: {accuracy_value}")
                    else:
                        _, loss_value, accuracy_value, global_step_value = sess.run(
                            [train_op, loss, accuracy, global_step], feed_dict=feed_dict)
                        print(f"TRAIN Step: {global_step_value}\tLoss: {loss_value}\tAccuracy: {accuracy_value}")

                    if global_step_value % SAVE_CHECKPOINT_INTERVAL == 0:
                        saver.save(sess, CHECKPOINT_FL, global_step=global_step_value)
                        print("Checkpoint Saved")
            except Exception as e:
                # Deliberately best-effort: report the failure, then fall
                # through to the final checkpoint instead of crashing.
                print('Exception')
                print(e)
            finally:
                # Runs on normal exit, on a reported exception and on
                # KeyboardInterrupt, so progress is saved in every case.
                train_writer.flush()
                test_writer.flush()
                saver.save(sess, CHECKPOINT_FL, global_step=global_step_value)
                print("Checkpoint Saved")
                train_writer.close()
                test_writer.close()

            print("Stopping")


if __name__ == '__main__':
    main()