forked from keras-team/keras
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request keras-team#5 from fchollet/master
update
- Loading branch information
Showing
60 changed files
with
3,210 additions
and
488 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
'''The example demonstrates how to write custom layers for Keras. | ||
We build a custom activation layer called 'Antirectifier', | ||
which modifies the shape of the tensor that passes through it. | ||
We need to specify two methods: `output_shape` and `get_output`. | ||
Note that the same result can also be achieved via a Lambda layer. | ||
Because our custom layer is written with primitives from the Keras | ||
backend (`K`), our code can run both on TensorFlow and Theano. | ||
''' | ||
|
||
from __future__ import print_function | ||
import numpy as np | ||
from keras.models import Sequential | ||
from keras.layers.core import Dense, Dropout, Layer, Activation | ||
from keras.datasets import mnist | ||
from keras import backend as K | ||
from keras.utils import np_utils | ||
|
||
|
||
class Antirectifier(Layer):
    '''Sample-wise L2 normalization followed by the concatenation of the
    positive part of the input with its negative part.

    The output contains samples twice as large as the input samples, and
    the layer can be used as a drop-in replacement for a ReLU.

    # Input shape
        2D tensor of shape (samples, n)

    # Output shape
        2D tensor of shape (samples, 2*n)

    # Theoretical justification
        A ReLU applied to activations roughly centered around 0 discards
        half of the incoming signal, which is inefficient. Antirectifier
        returns all-positive outputs like ReLU without discarding any
        data; on MNIST this lets a network with half the parameters reach
        comparable classification accuracy to a ReLU-based equivalent.
    '''
    @property
    def output_shape(self):
        # Feature dimension doubles; only 2D inputs are supported.
        in_shape = list(self.input_shape)
        assert len(in_shape) == 2  # only valid for 2D tensors
        return (in_shape[0], in_shape[1] * 2)

    def get_output(self, train):
        inp = self.get_input(train)
        # Center each sample, then L2-normalize it.
        centered = inp - K.mean(inp, axis=1, keepdims=True)
        normalized = K.l2_normalize(centered, axis=1)
        # Concatenate the rectified positive and negative parts along
        # the feature axis.
        return K.concatenate([K.relu(normalized), K.relu(-normalized)],
                             axis=1)
|
||
# global parameters
batch_size = 128
nb_classes = 10
nb_epoch = 40

# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# flatten the 28x28 images into 784-dim float vectors scaled to [0, 1]
X_train = X_train.reshape(60000, 784).astype('float32') / 255
X_test = X_test.reshape(10000, 784).astype('float32') / 255
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

# convert class vectors to one-hot binary class matrices
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)

# build the model: two Dense -> Antirectifier -> Dropout stages,
# followed by a softmax classifier
model = Sequential()
model.add(Dense(256, input_shape=(784,)))
model.add(Antirectifier())
model.add(Dropout(0.1))
model.add(Dense(256))
model.add(Antirectifier())
model.add(Dropout(0.1))
model.add(Dense(10))
model.add(Activation('softmax'))

# compile the model
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

# train the model
model.fit(X_train, Y_train,
          batch_size=batch_size, nb_epoch=nb_epoch,
          show_accuracy=True, verbose=1,
          validation_data=(X_test, Y_test))

# next, compare with an equivalent network
# with 2x bigger Dense layers and ReLU
Oops, something went wrong.