Add initial epoch argument to fit functions (keras-team#4429)
* Added initial_epoch argument to fit functions in trainer

* Added unit test

* PEP8 fixes
kencoken authored and fchollet committed Nov 20, 2016
1 parent 97484ec commit 06cc6d7
Showing 2 changed files with 37 additions and 6 deletions.
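
The new initial_epoch argument lets a second call to fit pick up the epoch counter where a previous run stopped, so epoch numbering in logs and callbacks stays consistent. A minimal usage sketch (the toy model and data below are hypothetical, not part of this commit; the nb_epoch/initial_epoch names follow the Keras 1.x API changed here):

    import numpy as np
    from keras.layers import Input, Dense
    from keras.models import Model

    inp = Input(shape=(3,))
    out = Dense(1)(inp)
    model = Model(input=inp, output=out)
    model.compile(optimizer='sgd', loss='mse')

    x = np.random.random((100, 3))
    y = np.random.random((100, 1))

    model.fit(x, y, nb_epoch=2)                   # first run: epochs 0-1
    model.fit(x, y, nb_epoch=5, initial_epoch=2)  # resume: epochs 2-4

Note that nb_epoch stays an absolute end index (one past the final epoch), not a count of additional epochs to run.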
20 changes: 14 additions & 6 deletions keras/engine/training.py
@@ -760,7 +760,7 @@ def _make_predict_function(self):
     def _fit_loop(self, f, ins, out_labels=[], batch_size=32,
                   nb_epoch=100, verbose=1, callbacks=[],
                   val_f=None, val_ins=None, shuffle=True,
-                  callback_metrics=[]):
+                  callback_metrics=[], initial_epoch=0):
         '''Abstract fit function for f(ins).
         Assume that f returns a list, labeled by out_labels.
@@ -780,6 +780,8 @@ def _fit_loop(self, f, ins, out_labels=[], batch_size=32,
                 passed to the callbacks. They should be the
                 concatenation of list the display names of the outputs of
                 `f` and the list of display names of the outputs of `f_val`.
+            initial_epoch: epoch at which to start training
+                (useful for resuming a previous training run)
         # Returns
             `History` object.
@@ -820,7 +822,7 @@ def _fit_loop(self, f, ins, out_labels=[], batch_size=32,
             callback_model.stop_training = False
         self.validation_data = val_ins

-        for epoch in range(nb_epoch):
+        for epoch in range(initial_epoch, nb_epoch):
             callbacks.on_epoch_begin(epoch)
             if shuffle == 'batch':
                 index_array = batch_shuffle(index_array, batch_size)
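
Because range is half-open, initial_epoch=2 with nb_epoch=5 runs exactly epochs 2, 3 and 4 — the same sequence the new unit test below asserts:

    >>> list(range(2, 5))
    [2, 3, 4]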
@@ -1007,7 +1009,7 @@ def _standardize_user_data(self, x, y,

     def fit(self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[],
             validation_split=0., validation_data=None, shuffle=True,
-            class_weight=None, sample_weight=None):
+            class_weight=None, sample_weight=None, initial_epoch=0):
         '''Trains the model for a fixed number of epochs (iterations on a dataset).
         # Arguments
@@ -1044,6 +1046,8 @@ def fit(self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[],
                 with shape (samples, sequence_length),
                 to apply a different weight to every timestep of every sample.
                 In this case you should make sure to specify sample_weight_mode="temporal" in compile().
+            initial_epoch: epoch at which to start training
+                (useful for resuming a previous training run)
         # Returns
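
A common pairing for the documented resume case is restoring saved weights first and then continuing with initial_epoch. A sketch, under the assumption that weights were saved earlier (e.g. via a ModelCheckpoint callback; 'weights.h5' is a hypothetical path):

    # Restore the state of an interrupted run, then continue from epoch 5
    # so the History/callback epoch indices line up with the earlier run.
    model.load_weights('weights.h5')
    model.fit(x, y, nb_epoch=10, initial_epoch=5)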
@@ -1127,7 +1131,8 @@ def fit(self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[],
                               batch_size=batch_size, nb_epoch=nb_epoch,
                               verbose=verbose, callbacks=callbacks,
                               val_f=val_f, val_ins=val_ins, shuffle=shuffle,
-                              callback_metrics=callback_metrics)
+                              callback_metrics=callback_metrics,
+                              initial_epoch=initial_epoch)

     def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None):
         '''Returns the loss value and metrics values for the model
@@ -1303,7 +1308,8 @@ def predict_on_batch(self, x):
     def fit_generator(self, generator, samples_per_epoch, nb_epoch,
                       verbose=1, callbacks=[],
                       validation_data=None, nb_val_samples=None,
-                      class_weight={}, max_q_size=10, nb_worker=1, pickle_safe=False):
+                      class_weight={}, max_q_size=10, nb_worker=1, pickle_safe=False,
+                      initial_epoch=0):
         '''Fits the model on data generated batch-by-batch by
         a Python generator.
         The generator is run in parallel to the model, for efficiency.
@@ -1339,6 +1345,8 @@ def fit_generator(self, generator, samples_per_epoch, nb_epoch,
                 this implementation relies on multiprocessing, you should not pass
                 non picklable arguments to the generator as they can't be passed
                 easily to children processes.
+            initial_epoch: epoch at which to start training
+                (useful for resuming a previous training run)
         # Returns
             A `History` object.
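
The generator path works the same way. A sketch mirroring the shapes used in the earlier example (gen_batches is a hypothetical helper, not part of this commit):

    def gen_batches(batch_sz):
        while True:
            yield np.random.random((batch_sz, 3)), np.random.random((batch_sz, 1))

    # Resumes at epoch 2; epochs 2, 3 and 4 are run.
    model.fit_generator(gen_batches(4), samples_per_epoch=100, nb_epoch=5,
                        initial_epoch=2)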
@@ -1361,7 +1369,7 @@ def generate_arrays_from_file(path):
         ```
         '''
         wait_time = 0.01  # in seconds
-        epoch = 0
+        epoch = initial_epoch

         do_validation = bool(validation_data)
         self._make_train_function()
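
Unlike _fit_loop, fit_generator has no range call to adjust; the epoch counter is simply seeded. The enclosing loop is not shown in this hunk, but assuming it compares epoch against nb_epoch as in the surrounding source, the effect is equivalent to this simplified sketch:

    epoch = initial_epoch
    while epoch < nb_epoch:
        # ... train on samples_per_epoch samples drawn from the generator ...
        epoch += 1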
23 changes: 23 additions & 0 deletions tests/keras/engine/test_training.py
@@ -8,6 +8,7 @@
 from keras.models import Sequential
 from keras import backend as K
 from keras.utils.test_utils import keras_test
+from keras.callbacks import LambdaCallback


 @keras_test
@@ -146,6 +147,28 @@ def test_model_methods():
                     [output_a_np, output_b_np])
     assert len(out) == 4

+    # test starting from non-zero initial epoch
+    trained_epochs = []
+
+    def on_epoch_begin(epoch, logs):
+        trained_epochs.append(epoch)
+    tracker_cb = LambdaCallback(on_epoch_begin=on_epoch_begin)
+    out = model.fit([input_a_np, input_b_np],
+                    [output_a_np, output_b_np], nb_epoch=5, batch_size=4,
+                    initial_epoch=2, callbacks=[tracker_cb])
+    assert trained_epochs == [2, 3, 4]
+
+    # test starting from non-zero initial epoch for generator too
+    trained_epochs = []
+
+    def gen_data(batch_sz):
+        while True:
+            yield ([np.random.random((batch_sz, 3)), np.random.random((batch_sz, 3))],
+                   [np.random.random((batch_sz, 4)), np.random.random((batch_sz, 3))])
+    out = model.fit_generator(gen_data(4), samples_per_epoch=10, nb_epoch=5,
+                              initial_epoch=2, callbacks=[tracker_cb])
+    assert trained_epochs == [2, 3, 4]
+
     # test with a custom metric function
     mse = lambda y_true, y_pred: K.mean(K.pow(y_true - y_pred, 2))
