Skip to content

Commit

Permalink
Run non-optimization updates on predict for stateful RNN's
Browse files Browse the repository at this point in the history
  • Loading branch information
mynameisfiber committed Dec 9, 2015
1 parent 81787dd commit 40c48dd
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ keras/datasets/temp/*
docs/site/*
docs/theme/*
tags
Keras.egg-info

# test-related
.coverage
Expand Down
5 changes: 5 additions & 0 deletions keras/layers/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ def set_weights(self, weights):
self.layers[i].set_weights(weights[:nb_param])
weights = weights[nb_param:]

def reset_states(self):
for layer in self.layers:
if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
layer.reset_states()

def get_config(self):
return {"name": self.__class__.__name__,
"layers": [layer.get_config() for layer in self.layers]}
Expand Down
2 changes: 1 addition & 1 deletion keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def compile(self, optimizer, loss,

self._train = K.function(train_ins, [train_loss], updates=updates)
self._train_with_acc = K.function(train_ins, [train_loss, train_accuracy], updates=updates)
self._predict = K.function(predict_ins, [self.y_test])
self._predict = K.function(predict_ins, [self.y_test], updates=self.updates)
self._test = K.function(test_ins, [test_loss])
self._test_with_acc = K.function(test_ins, [test_loss, test_accuracy])

Expand Down

0 comments on commit 40c48dd

Please sign in to comment.