
Commit

Another test fix
Subodh Malgonde authored and committed Dec 20, 2017
1 parent c5d5580 commit 6251175
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions main.py
@@ -185,6 +185,21 @@ def save_model(sess, training_loss_metrics, validation_loss_metrics,
    with open('validation_accuracy_history', 'wb') as f:
        pickle.dump(validation_accuracy_history, f)

def evaluate(image_paths, data_folder, image_shape, sess, input_image, correct_label, keep_prob, loss_op, accuracy_op, is_training):
    data_generator_function = helper.gen_batch_function(data_folder, image_shape, image_paths, augment=False)
    batch_size = 8
    data_generator = data_generator_function(batch_size)
    num_examples = int(math.floor(len(image_paths)/batch_size)*batch_size)
    total_loss = 0
    total_acc = 0
    for offset in range(0, num_examples, batch_size):
        X_batch, y_batch = next(data_generator)
        loss, accuracy = sess.run([loss_op, accuracy_op],
                                  feed_dict={input_image: X_batch, correct_label: y_batch,
                                             keep_prob: 1.0, is_training: False})
        total_loss += (loss * X_batch.shape[0])
        total_acc += (accuracy * X_batch.shape[0])
    return total_loss/num_examples, total_acc/num_examples


def train_nn(sess, epochs, data_folder, image_shape, batch_size, training_image_paths, validation_image_paths, train_op,
             cross_entropy_loss, accuracy_op, input_image, correct_label, keep_prob, learning_rate, is_training):
@@ -249,22 +264,6 @@ def train_nn(sess, epochs, data_folder, image_shape, batch_size, training_image_
tests.test_train_nn(train_nn)


def evaluate(image_paths, data_folder, image_shape, sess, input_image, correct_label, keep_prob, loss_op, accuracy_op, is_training):
    data_generator_function = helper.gen_batch_function(data_folder, image_shape, image_paths, augment=False)
    batch_size = 8
    data_generator = data_generator_function(batch_size)
    num_examples = int(math.floor(len(image_paths)/batch_size)*batch_size)
    total_loss = 0
    total_acc = 0
    for offset in range(0, num_examples, batch_size):
        X_batch, y_batch = next(data_generator)
        loss, accuracy = sess.run([loss_op, accuracy_op],
                                  feed_dict={input_image: X_batch, correct_label: y_batch,
                                             keep_prob: 1.0, is_training: False})
        total_loss += (loss * X_batch.shape[0])
        total_acc += (accuracy * X_batch.shape[0])
    return total_loss/num_examples, total_acc/num_examples


def run():
    global LEARNING_RATE
    global KEEP_PROB
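For context, here is a minimal sketch of how the relocated evaluate helper might be called once per epoch inside train_nn. The argument names (validation_image_paths, cross_entropy_loss, accuracy_op, and so on) are taken from the train_nn signature shown above; the surrounding loop and the print statement are illustrative assumptions, not part of this commit.

    # Hypothetical per-epoch validation step, reusing the names from train_nn's signature.
    for epoch in range(epochs):
        # ... run the training batches for this epoch ...
        val_loss, val_acc = evaluate(validation_image_paths, data_folder, image_shape,
                                     sess, input_image, correct_label, keep_prob,
                                     cross_entropy_loss, accuracy_op, is_training)
        print("Epoch {}: validation loss = {:.4f}, validation accuracy = {:.4f}".format(
            epoch + 1, val_loss, val_acc))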
