diff --git a/main.py b/main.py index 71bf636..f7ae871 100644 --- a/main.py +++ b/main.py @@ -257,6 +257,10 @@ def train_nn(sess, epochs, data_folder, image_shape, batch_size, training_image_ print("Actual learning rate:", LEARNING_RATE, ", Actual keep prob:", KEEP_PROB) + best_validation_accuracy = 0 + + saver = tf.train.Saver() + for epoch in range(epochs): for batch in tqdm(range(batches_per_epoch)): X_batch , y_batch = next(training_batch_generator) @@ -277,18 +281,23 @@ def train_nn(sess, epochs, data_folder, image_shape, batch_size, training_image_ training_loss_metrics.append(training_loss) training_accuracy_metrics.append(training_accuracy) + if validation_accuracy > best_validation_accuracy: + best_validation_accuracy = validation_accuracy + saver.save(sess, 'checkpoints/checkpoint') + print( "Epoch %d:" % (epoch + 1), "Training loss: %.4f, accuracy: %.2f" % (training_loss, training_accuracy), "Validation loss: %.4f, accuracy: %.2f" % (validation_loss, validation_accuracy) ) - # if epoch > 0 and (training_accuracy_metrics[-1] < training_accuracy_metrics[-2] - 0.02): - # print("Early stopping!!!! latest/prev accuracy: %.3f:%.3f" % (training_accuracy_metrics[-1], training_accuracy_metrics[-2])) - # break + if validation_accuracy < best_validation_accuracy: + ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/checkpoint')) + if ckpt and ckpt.model_checkpoint_path: + saver.restore(sess, ckpt.model_checkpoint_path) + - if epoch % 10 == 0 and epoch > 0: - save_model(sess) + print("Best validation accuracy", best_validation_accuracy) save_model(sess, training_loss_metrics, validation_loss_metrics, training_accuracy_metrics, validation_accuracy_metrics)