Skip to content

Commit

Permalink
Added checkpoint to store best model
Browse files Browse the repository at this point in the history
Subodh Malgonde authored and Subodh Malgonde committed Dec 21, 2017

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent f35cba8 commit 20e5149
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
@@ -257,6 +257,10 @@ def train_nn(sess, epochs, data_folder, image_shape, batch_size, training_image_

print("Actual learning rate:", LEARNING_RATE, ", Actual keep prob:", KEEP_PROB)

best_validation_accuracy = 0

saver = tf.train.Saver()

for epoch in range(epochs):
for batch in tqdm(range(batches_per_epoch)):
X_batch , y_batch = next(training_batch_generator)
@@ -277,18 +281,23 @@ def train_nn(sess, epochs, data_folder, image_shape, batch_size, training_image_
training_loss_metrics.append(training_loss)
training_accuracy_metrics.append(training_accuracy)

if validation_accuracy > best_validation_accuracy:
best_validation_accuracy = validation_accuracy
saver.save(sess, 'checkpoints/checkpoint')

print(
"Epoch %d:" % (epoch + 1),
"Training loss: %.4f, accuracy: %.2f" % (training_loss, training_accuracy),
"Validation loss: %.4f, accuracy: %.2f" % (validation_loss, validation_accuracy)
)

# if epoch > 0 and (training_accuracy_metrics[-1] < training_accuracy_metrics[-2] - 0.02):
# print("Early stopping!!!! latest/prev accuracy: %.3f:%.3f" % (training_accuracy_metrics[-1], training_accuracy_metrics[-2]))
# break
if validation_accuracy < best_validation_accuracy:
ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/checkpoint'))
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)


if epoch % 10 == 0 and epoch > 0:
save_model(sess)
print("Best validation accuracy", best_validation_accuracy)

save_model(sess, training_loss_metrics, validation_loss_metrics, training_accuracy_metrics,
validation_accuracy_metrics)

0 comments on commit 20e5149

Please sign in to comment.