diff --git a/main.py b/main.py index 3e52b0f..254b1af 100644 --- a/main.py +++ b/main.py @@ -273,6 +273,10 @@ def train_nn(sess, epochs, data_folder, image_shape, batch_size, training_image_ training_loss_metrics.append(training_loss) training_accuracy_metrics.append(training_accuracy) + if epoch > 0 and (training_accuracy_metrics[-1] < training_accuracy_metrics[-2] - 0.01): + print("Early stopping!!!!") + break + print( "Epoch %d:" % (epoch + 1), "Training loss: %.4f, accuracy: %.2f" % (training_loss, training_accuracy),