-
Notifications
You must be signed in to change notification settings - Fork 85
/
run.py
72 lines (57 loc) · 2.14 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
""" Siamese implementation using Tensorflow with MNIST example.
This siamese network embeds a 28x28 image (a point in 784D)
into a point in 2D.
By Youngwook Paul Kwon (young at berkeley.edu)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from builtins import input
#import system things
from tensorflow.examples.tutorials.mnist import input_data # for data
import tensorflow as tf
import numpy as np
import os
#import helpers
import inference
import visualize
# Load MNIST with integer class labels (one_hot=False) and open an
# interactive session that becomes the default for .run()/.eval() calls.
mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
sess = tf.InteractiveSession()

# Build the siamese graph, a plain SGD step (lr 0.01) on its loss, and a
# saver covering the graph's variables.
siamese = inference.siamese()
optimizer = tf.train.GradientDescentOptimizer(0.01)
train_step = optimizer.minimize(siamese.loss)
saver = tf.train.Saver()

# NOTE(review): tf.initialize_all_variables is deprecated in later TF 1.x
# releases in favour of tf.global_variables_initializer; kept as-is so the
# script keeps working on the TF version it was written against.
tf.initialize_all_variables().run()
# if you just want to load a previously trainmodel?
load = False
model_ckpt = './model.meta'
if os.path.isfile(model_ckpt):
input_var = None
while input_var not in ['yes', 'no']:
input_var = input("We found model files. Do you want to load it and continue training [yes/no]?")
if input_var == 'yes':
load = True
# start training
NUM_STEPS = 50000   # total optimisation steps
BATCH_SIZE = 128    # images on each side of a siamese pair batch
LOG_EVERY = 10      # print the loss every LOG_EVERY steps
SAVE_EVERY = 1000   # write a checkpoint every SAVE_EVERY steps (not at step 0)

if load:
    saver.restore(sess, './model')

for step in range(NUM_STEPS):
    # Two independent batches are paired element-wise; a pair's target is
    # 1.0 when both images have the same integer label and 0.0 otherwise.
    batch_x1, batch_y1 = mnist.train.next_batch(BATCH_SIZE)
    batch_x2, batch_y2 = mnist.train.next_batch(BATCH_SIZE)
    batch_y = (batch_y1 == batch_y2).astype('float')

    _, loss_v = sess.run([train_step, siamese.loss], feed_dict={
        siamese.x1: batch_x1,
        siamese.x2: batch_x2,
        siamese.y_: batch_y})

    if np.isnan(loss_v):
        print('Model diverged with loss = NaN')
        # Non-zero exit status so shells/callers see the failure; the bare
        # quit() helper exits with 0 and only exists when `site` is loaded.
        raise SystemExit(1)

    if step % LOG_EVERY == 0:
        print('step %d: loss %.3f' % (step, loss_v))

    if step % SAVE_EVERY == 0 and step > 0:
        saver.save(sess, './model')

# The periodic save need not fire on the final step (with the defaults the
# last one is at step 49000), so persist the fully trained weights here.
saver.save(sess, './model')
# Embed every MNIST test image through the x1 -> o1 branch of the network.
test_images = mnist.test.images
embed = siamese.o1.eval(feed_dict={siamese.x1: test_images})

# NOTE(review): ndarray.tofile with the default sep='' writes the raw bytes
# of the array (no header, no text) despite the .txt extension. Read it back
# with np.fromfile('embed.txt', dtype=<embed's dtype>) and reshape to
# (-1, <embedding width>); the module docstring says the width is 2, so
# confirm that against inference.siamese.
embed.tofile('embed.txt')

# visualize result: images as 28x28 grids alongside their integer labels
x_test = test_images.reshape([-1, 28, 28])
y_test = mnist.test.labels
visualize.visualize(embed, x_test, y_test)