Skip to content

Commit f31a1cd

Browse files
committed
add manifold learning
1 parent 3b44023 commit f31a1cd

File tree

6 files changed

+37
-11
lines changed

6 files changed

+37
-11
lines changed

Diff for: examples/mnist_cnn.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from keras.layers import Dense, Dropout, Flatten
1313
from keras.layers import Conv2D, MaxPooling2D
1414
from keras import backend as K
15+
from geomstats.hypersphere import Hypersphere
1516

1617
batch_size = 128
1718
num_classes = 10
@@ -47,8 +48,9 @@
4748
model = Sequential()
4849
model.add(Conv2D(32, kernel_size=(3, 3),
4950
activation='relu',
50-
input_shape=input_shape))
51-
model.add(Conv2D(64, (3, 3), activation='relu'))
51+
input_shape=input_shape,
52+
kernel_manifold=Hypersphere(dimension=32*3*3)))
53+
model.add(Conv2D(64, (3, 3), activation='relu', kernel_manifold=Hypersphere(dimension=64*3*3)))
5254
model.add(MaxPooling2D(pool_size=(2, 2)))
5355
model.add(Dropout(0.25))
5456
model.add(Flatten())
@@ -57,7 +59,7 @@
5759
model.add(Dense(num_classes, activation='softmax'))
5860

5961
model.compile(loss=keras.losses.categorical_crossentropy,
60-
optimizer=keras.optimizers.Adadelta(),
62+
optimizer=keras.optimizers.SGD(),
6163
metrics=['accuracy'])
6264

6365
model.fit(x_train, y_train,

Diff for: keras/backend/tensorflow_backend.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def to_dense(tensor):
356356
name_scope = tf.name_scope
357357

358358

359-
def variable(value, dtype=None, name=None, constraint=None):
359+
def variable(value, dtype=None, name=None, constraint=None, manifold=None):
360360
"""Instantiates a variable and returns it.
361361
362362
# Arguments
@@ -406,6 +406,8 @@ def variable(value, dtype=None, name=None, constraint=None):
406406
v.constraint = constraint
407407
except AttributeError:
408408
v._constraint = constraint
409+
410+
v._manifold = manifold
409411
return v
410412

411413

Diff for: keras/engine/base_layer.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,8 @@ def add_weight(self,
220220
initializer=None,
221221
regularizer=None,
222222
trainable=True,
223-
constraint=None):
223+
constraint=None,
224+
manifold=None):
224225
"""Adds a weight variable to the layer.
225226
226227
# Arguments
@@ -243,7 +244,8 @@ def add_weight(self,
243244
weight = K.variable(initializer(shape),
244245
dtype=dtype,
245246
name=name,
246-
constraint=constraint)
247+
constraint=constraint,
248+
manifold=manifold)
247249
if regularizer is not None:
248250
with K.name_scope('weight_regularizer'):
249251
self.add_loss(regularizer(weight))

Diff for: keras/layers/convolutional.py

+6
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def __init__(self, rank,
9999
bias_regularizer=None,
100100
activity_regularizer=None,
101101
kernel_constraint=None,
102+
kernel_manifold=None,
102103
bias_constraint=None,
103104
**kwargs):
104105
super(_Conv, self).__init__(**kwargs)
@@ -114,6 +115,7 @@ def __init__(self, rank,
114115
self.kernel_initializer = initializers.get(kernel_initializer)
115116
self.bias_initializer = initializers.get(bias_initializer)
116117
self.kernel_regularizer = regularizers.get(kernel_regularizer)
118+
self.kernel_manifold = kernel_manifold
117119
self.bias_regularizer = regularizers.get(bias_regularizer)
118120
self.activity_regularizer = regularizers.get(activity_regularizer)
119121
self.kernel_constraint = constraints.get(kernel_constraint)
@@ -134,6 +136,7 @@ def build(self, input_shape):
134136
self.kernel = self.add_weight(shape=kernel_shape,
135137
initializer=self.kernel_initializer,
136138
name='kernel',
139+
manifold=self.kernel_manifold,
137140
regularizer=self.kernel_regularizer,
138141
constraint=self.kernel_constraint)
139142
if self.use_bias:
@@ -443,6 +446,7 @@ def __init__(self, filters,
443446
bias_regularizer=None,
444447
activity_regularizer=None,
445448
kernel_constraint=None,
449+
kernel_manifold=None,
446450
bias_constraint=None,
447451
**kwargs):
448452
super(Conv2D, self).__init__(
@@ -461,9 +465,11 @@ def __init__(self, filters,
461465
bias_regularizer=bias_regularizer,
462466
activity_regularizer=activity_regularizer,
463467
kernel_constraint=kernel_constraint,
468+
kernel_manifold=kernel_manifold,
464469
bias_constraint=bias_constraint,
465470
**kwargs)
466471
self.input_spec = InputSpec(ndim=4)
472+
print(kernel_manifold)
467473

468474
def get_config(self):
469475
config = super(Conv2D, self).get_config()

Diff for: keras/layers/core.py

+3
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,7 @@ def __init__(self, units,
835835
bias_regularizer=None,
836836
activity_regularizer=None,
837837
kernel_constraint=None,
838+
kernel_manifold=None,
838839
bias_constraint=None,
839840
**kwargs):
840841
if 'input_shape' not in kwargs and 'input_dim' in kwargs:
@@ -849,6 +850,7 @@ def __init__(self, units,
849850
self.bias_regularizer = regularizers.get(bias_regularizer)
850851
self.activity_regularizer = regularizers.get(activity_regularizer)
851852
self.kernel_constraint = constraints.get(kernel_constraint)
853+
self.kernel_manifold = kernel_manifold
852854
self.bias_constraint = constraints.get(bias_constraint)
853855
self.input_spec = InputSpec(min_ndim=2)
854856
self.supports_masking = True
@@ -860,6 +862,7 @@ def build(self, input_shape):
860862
self.kernel = self.add_weight(shape=(input_dim, self.units),
861863
initializer=self.kernel_initializer,
862864
name='kernel',
865+
manifold=self.kernel_manifold,
863866
regularizer=self.kernel_regularizer,
864867
constraint=self.kernel_constraint)
865868
if self.use_bias:

Diff for: keras/optimizers.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
"""Built-in optimizer classes.
2-
"""
1+
32
from __future__ import absolute_import
43
from __future__ import division
54
from __future__ import print_function
@@ -178,15 +177,26 @@ def get_updates(self, loss, params):
178177
lr *= (1. / (1. + self.decay * K.cast(self.iterations,
179178
K.dtype(self.decay))))
180179
# momentum
180+
# TODO(johmathe): Add back nesterov.
181181
shapes = [K.int_shape(p) for p in params]
182182
moments = [K.zeros(shape) for shape in shapes]
183183
self.weights = [self.iterations] + moments
184184
for p, g, m in zip(params, grads, moments):
185-
v = self.momentum * m - lr * g # velocity
185+
v = self.momentum * m - lr * g
186186
self.updates.append(K.update(m, v))
187187

188-
if self.nesterov:
189-
new_p = p + self.momentum * v - lr * g
188+
# Do the gradient descent on the manifold if present.
189+
if getattr(p, 'manifold', None) is not None:
190+
print('MANIF')
191+
shape = K.shape(v)
192+
manifold = p.manifold
193+
v_shaped = K.reshape(v, (manifold.dimension, -1))
194+
p_shaped = K.reshape(p, (manifold.dimension, -1))
195+
tangent_v = manifold.projection_to_tangent_space(
196+
vector=v_shaped, base_point=p_shaped)
197+
destination = manifold.metric.exp(
198+
base_point=p_shaped, vector=tangent_v)
199+
new_p = K.reshape(destination, shape, name='new_p')
190200
else:
191201
new_p = p + v
192202

@@ -253,6 +263,7 @@ def get_updates(self, loss, params):
253263

254264
for p, g, a in zip(params, grads, accumulators):
255265
# update accumulator
266+
256267
new_a = self.rho * a + (1. - self.rho) * K.square(g)
257268
self.updates.append(K.update(a, new_a))
258269
new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)

0 commit comments

Comments
 (0)