Commit 41c52d6

Committed Apr 1, 2016
Spatial Transformer model
Shorten the STN summary in the README, relink the data files, add license headers, update the AUTHORS file, and note the TensorFlow version.
1 parent d51fdd2 commit 41c52d6

File tree

7 files changed: +625, -0 lines


AUTHORS

+1
@@ -7,3 +7,4 @@
 # The email address is not required for organizations.

 Google Inc.
+David Dao <daviddao@broad.mit.edu>

transformer/README.md

+64
# Spatial Transformer Network

The Spatial Transformer Network [1] allows the spatial manipulation of data within the network.

<div align="center">
<img width="600px" src="http://i.imgur.com/ExGDVul.png"><br><br>
</div>

### API

A Spatial Transformer Network implemented in TensorFlow 0.7, based on [2].

#### How to use

<div align="center">
<img src="http://i.imgur.com/gfqLV3f.png"><br><br>
</div>

```python
transformer(U, theta, downsample_factor=1)
```

#### Parameters

U : float
    The output of a convolutional net; should have the shape
    [num_batch, height, width, num_channels].
theta : float
    The output of the localisation network; should have the shape
    [num_batch, 6].
downsample_factor : float
    A value of 1 keeps the original size of the image;
    values larger than 1 downsample the image and values
    below 1 upsample it. Example: with height = 100,
    width = 200 and downsample_factor = 2, the output
    image will be 50 x 100.

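A minimal end-to-end sketch of calling the layer (the 40x40 single-channel shapes here are illustrative, not required by the API; it assumes `spatial_transformer.py` is importable):

```python
import numpy as np
import tensorflow as tf
from spatial_transformer import transformer

# A batch of feature maps: [num_batch, height, width, num_channels].
U = tf.placeholder(tf.float32, [None, 40, 40, 1])

# One flattened 2x3 affine matrix per example: [num_batch, 6].
theta = tf.placeholder(tf.float32, [None, 6])

# downsample_factor=1 keeps the input resolution.
V = transformer(U, theta, downsample_factor=1)

sess = tf.Session()
sess.run(tf.initialize_all_variables())

# Feed three random images and the identity transform for each.
images = np.random.rand(3, 40, 40, 1).astype('float32')
identity = np.tile([1., 0., 0., 0., 1., 0.], (3, 1)).astype('float32')
out = sess.run(V, feed_dict={U: images, theta: identity})
```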
#### Notes

To initialize the network to the identity transform, initialize ``theta`` to:

```python
identity = np.array([[1., 0., 0.],
                     [0., 1., 0.]])
identity = identity.flatten()
theta = tf.Variable(initial_value=identity)
```

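Other affine transforms follow the same pattern. For instance, `example.py` in this commit zooms into the image by initializing ``theta`` to a 0.5-scale transform:

```python
# Scale both axes by 0.5 to zoom into the center of the image.
zoom = np.array([[0.5, 0., 0.],
                 [0., 0.5, 0.]])
zoom = zoom.astype('float32').flatten()
theta = tf.Variable(initial_value=zoom)
```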
#### Experiments

<div align="center">
<img width="600px" src="http://i.imgur.com/HtCBYk2.png"><br><br>
</div>

We used the cluttered MNIST dataset. The left column shows the input images; the right column shows the parts of each image the STN attends to.

All experiments were run with TensorFlow 0.7.

### References

[1] Jaderberg, Max, et al. "Spatial Transformer Networks." arXiv preprint arXiv:1506.02025 (2015).

[2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py

transformer/cluttered_mnist.py

+172
```python
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from spatial_transformer import transformer
from scipy import ndimage
import numpy as np
import matplotlib.pyplot as plt
from tf_utils import conv2d, linear, weight_variable, bias_variable, dense_to_one_hot

# %% Load data
mnist_cluttered = np.load('./data/mnist_sequence1_sample_5distortions5x5.npz')

X_train = mnist_cluttered['X_train']
y_train = mnist_cluttered['y_train']
X_valid = mnist_cluttered['X_valid']
y_valid = mnist_cluttered['y_valid']
X_test = mnist_cluttered['X_test']
y_test = mnist_cluttered['y_test']

# %% Turn the dense labels into a one-hot representation
Y_train = dense_to_one_hot(y_train, n_classes=10)
Y_valid = dense_to_one_hot(y_valid, n_classes=10)
Y_test = dense_to_one_hot(y_test, n_classes=10)

# %% Graph representation of our network

# %% Placeholders for 40x40 resolution
x = tf.placeholder(tf.float32, [None, 1600])
y = tf.placeholder(tf.float32, [None, 10])

# %% Since x is currently [batch, height*width], we need to reshape to a
# 4-D tensor to use it in a convolutional graph. If one component of
# `shape` is the special value -1, the size of that dimension is
# computed so that the total size remains constant. Since we haven't
# defined the batch dimension's shape yet, we use -1 to denote this
# dimension should not change size.
x_tensor = tf.reshape(x, [-1, 40, 40, 1])

# %% We'll set up the two-layer localisation network to figure out the
# parameters for an affine transformation of the input.
# %% Create variables for the fully connected layers
W_fc_loc1 = weight_variable([1600, 20])
b_fc_loc1 = bias_variable([20])

W_fc_loc2 = weight_variable([20, 6])
# Use the identity transformation as the starting point
initial = np.array([[1., 0., 0.], [0., 1., 0.]])
initial = initial.astype('float32')
initial = initial.flatten()
b_fc_loc2 = tf.Variable(initial_value=initial, name='b_fc_loc2')

# %% Define the two-layer localisation network
h_fc_loc1 = tf.nn.tanh(tf.matmul(x, W_fc_loc1) + b_fc_loc1)
# %% We can add dropout for regularizing and to reduce overfitting like so:
keep_prob = tf.placeholder(tf.float32)
h_fc_loc1_drop = tf.nn.dropout(h_fc_loc1, keep_prob)
# %% Second layer
h_fc_loc2 = tf.nn.tanh(tf.matmul(h_fc_loc1_drop, W_fc_loc2) + b_fc_loc2)

# %% We'll create a spatial transformer module to identify discriminative patches
h_trans = transformer(x_tensor, h_fc_loc2, downsample_factor=1)

# %% We'll set up the first convolutional layer
# Weight matrix is [height x width x input_channels x output_channels]
filter_size = 3
n_filters_1 = 16
W_conv1 = weight_variable([filter_size, filter_size, 1, n_filters_1])

# %% Bias is [output_channels]
b_conv1 = bias_variable([n_filters_1])

# %% Now we can build a graph which does the first layer of convolution:
# we define our stride as batch x height x width x channels;
# instead of pooling, we use strides of 2 and more layers
# with smaller filters.
h_conv1 = tf.nn.relu(
    tf.nn.conv2d(input=h_trans,
                 filter=W_conv1,
                 strides=[1, 2, 2, 1],
                 padding='SAME') +
    b_conv1)

# %% And just like the first layer, add additional layers to create
# a deep net
n_filters_2 = 16
W_conv2 = weight_variable([filter_size, filter_size, n_filters_1, n_filters_2])
b_conv2 = bias_variable([n_filters_2])
h_conv2 = tf.nn.relu(
    tf.nn.conv2d(input=h_conv1,
                 filter=W_conv2,
                 strides=[1, 2, 2, 1],
                 padding='SAME') +
    b_conv2)

# %% We'll now reshape so we can connect to a fully-connected layer:
h_conv2_flat = tf.reshape(h_conv2, [-1, 10 * 10 * n_filters_2])

# %% Create a fully-connected layer:
n_fc = 1024
W_fc1 = weight_variable([10 * 10 * n_filters_2, n_fc])
b_fc1 = bias_variable([n_fc])
h_fc1 = tf.nn.relu(tf.matmul(h_conv2_flat, W_fc1) + b_fc1)

h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# %% And finally our softmax layer:
W_fc2 = weight_variable([n_fc, 10])
b_fc2 = bias_variable([10])
y_pred = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

# %% Define loss/eval/training functions
cross_entropy = -tf.reduce_sum(y * tf.log(y_pred))
opt = tf.train.AdamOptimizer()
optimizer = opt.minimize(cross_entropy)
grads = opt.compute_gradients(cross_entropy, [b_fc_loc2])

# %% Monitor accuracy
correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))

# %% We now create a new session to actually initialize the variables:
sess = tf.Session()
sess.run(tf.initialize_all_variables())


# %% We'll now train in minibatches and report accuracy and loss:
iter_per_epoch = 100
n_epochs = 500
train_size = 10000

indices = np.linspace(0, train_size - 1, iter_per_epoch)
indices = indices.astype('int')

for epoch_i in range(n_epochs):
    for iter_i in range(iter_per_epoch - 1):
        batch_xs = X_train[indices[iter_i]:indices[iter_i + 1]]
        batch_ys = Y_train[indices[iter_i]:indices[iter_i + 1]]

        if iter_i % 10 == 0:
            loss = sess.run(cross_entropy,
                            feed_dict={
                                x: batch_xs,
                                y: batch_ys,
                                keep_prob: 1.0
                            })
            print('Iteration: ' + str(iter_i) + ' Loss: ' + str(loss))

        sess.run(optimizer, feed_dict={
            x: batch_xs, y: batch_ys, keep_prob: 0.8})

    print('Accuracy: ' + str(sess.run(accuracy,
                                      feed_dict={
                                          x: X_valid,
                                          y: Y_valid,
                                          keep_prob: 1.0
                                      })))
    # theta = sess.run(h_fc_loc2, feed_dict={
    #     x: batch_xs, keep_prob: 1.0})
    # print(theta[0])
```
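The commented-out lines at the end show how to inspect the transform the localisation network has learned. A short sketch of using them inside the epoch loop (`batch_xs` and `keep_prob` come from the script; the `reshape` display is an addition for readability):

```python
# Fetch the predicted transform parameters for the current batch.
theta = sess.run(h_fc_loc2, feed_dict={x: batch_xs, keep_prob: 1.0})

# The six values per example are the flattened 2x3 affine matrix
# [[a, b, tx], [c, d, ty]] that the transformer applies.
print(theta[0].reshape(2, 3))
```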

transformer/data/README.md

+20
### How to get the data

#### Cluttered MNIST

The cluttered MNIST dataset can be downloaded from [1] or generated via [2].

Settings used for `cluttered_mnist.py`:

```python
ORG_SHP = [28, 28]
OUT_SHP = [40, 40]
NUM_DISTORTIONS = 8
dist_size = (5, 5)
```
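The archive is a standard NumPy `.npz` file. A minimal sanity-check sketch; the key names are taken from `cluttered_mnist.py`, and the 1600-value flattening follows from `OUT_SHP = [40, 40]`:

```python
import numpy as np

data = np.load('./data/mnist_sequence1_sample_5distortions5x5.npz')

# cluttered_mnist.py expects these six arrays.
for key in ['X_train', 'y_train', 'X_valid', 'y_valid', 'X_test', 'y_test']:
    print(key, data[key].shape)

# Each 40x40 image is stored flattened to 1600 values.
assert data['X_train'].shape[1] == 40 * 40
```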
[1] https://github.com/daviddao/spatial-transformer-tensorflow

[2] https://github.com/skaae/recurrent-spatial-transformer-code/blob/master/MNIST_SEQUENCE/create_mnist_sequence.py

transformer/example.py

+58
```python
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from spatial_transformer import transformer
from scipy import ndimage
import numpy as np
import matplotlib.pyplot as plt
from tf_utils import conv2d, linear, weight_variable, bias_variable

# %% Create a batch of three images (1600 x 1200)
# %% Image retrieved from
# https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg
im = ndimage.imread('cat.jpg')
im = im / 255.
im = im.reshape(1, 1200, 1600, 3)
im = im.astype('float32')

# %% Simulate batch
batch = np.append(im, im, axis=0)
batch = np.append(batch, im, axis=0)
num_batch = 3

x = tf.placeholder(tf.float32, [None, 1200, 1600, 3])

# %% Create localisation network and convolutional layer
with tf.variable_scope('spatial_transformer_0'):

    # %% Create a fully-connected layer with 6 output nodes
    n_fc = 6
    W_fc1 = tf.Variable(tf.zeros([1200 * 1600 * 3, n_fc]), name='W_fc1')

    # %% Zoom into the image
    initial = np.array([[0.5, 0., 0.], [0., 0.5, 0.]])
    initial = initial.astype('float32')
    initial = initial.flatten()

    b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
    h_fc1 = tf.matmul(tf.zeros([num_batch, 1200 * 1600 * 3]), W_fc1) + b_fc1
    h_trans = transformer(x, h_fc1, downsample_factor=2)

# %% Run session
sess = tf.Session()
sess.run(tf.initialize_all_variables())
y = sess.run(h_trans, feed_dict={x: batch})

# plt.imshow(y[0])
```
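To look at the result, the commented-out `plt.imshow` line can be extended. A hedged sketch (the output filename is illustrative); with `downsample_factor=2`, the transformed batch `y` has shape `[3, 600, 800, 3]`:

```python
# Show the transformed (zoomed, downsampled) first image in the batch.
plt.imshow(y[0])
plt.axis('off')
plt.show()

# Or write it to disk instead of displaying it.
plt.imsave('cat_transformed.png', y[0])
```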
