diff --git a/keras/engine/training.py b/keras/engine/training.py
index 69faaa41e54e..1458ccb4f384 100644
--- a/keras/engine/training.py
+++ b/keras/engine/training.py
@@ -760,7 +760,7 @@ def _make_predict_function(self):
 
     def _fit_loop(self, f, ins, out_labels=[], batch_size=32,
                   nb_epoch=100, verbose=1, callbacks=[],
                   val_f=None, val_ins=None, shuffle=True,
-                  callback_metrics=[]):
+                  callback_metrics=[], initial_epoch=0):
         '''Abstract fit function for f(ins).
         Assume that f returns a list, labeled by out_labels.
@@ -780,6 +780,8 @@ def _fit_loop(self, f, ins, out_labels=[], batch_size=32,
                 passed to the callbacks. They should be the
                 concatenation of list the display names of the outputs of
                 `f` and the list of display names of the outputs of `f_val`.
+            initial_epoch: epoch at which to start training
+                (useful for resuming a previous training run)
 
         # Returns
             `History` object.
@@ -820,7 +822,7 @@ def _fit_loop(self, f, ins, out_labels=[], batch_size=32,
             callback_model.stop_training = False
         self.validation_data = val_ins
 
-        for epoch in range(nb_epoch):
+        for epoch in range(initial_epoch, nb_epoch):
             callbacks.on_epoch_begin(epoch)
             if shuffle == 'batch':
                 index_array = batch_shuffle(index_array, batch_size)
@@ -1007,7 +1009,7 @@ def _standardize_user_data(self, x, y,
 
     def fit(self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[],
             validation_split=0., validation_data=None, shuffle=True,
-            class_weight=None, sample_weight=None):
+            class_weight=None, sample_weight=None, initial_epoch=0):
         '''Trains the model for a fixed number of epochs (iterations on a dataset).
 
         # Arguments
@@ -1044,6 +1046,8 @@ def fit(self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[],
                 with shape (samples, sequence_length),
                 to apply a different weight to every timestep of every sample.
                 In this case you should make sure to specify sample_weight_mode="temporal" in compile().
+            initial_epoch: epoch at which to start training
+                (useful for resuming a previous training run)
 
 
         # Returns
@@ -1127,7 +1131,8 @@ def fit(self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[],
                               batch_size=batch_size, nb_epoch=nb_epoch,
                               verbose=verbose, callbacks=callbacks,
                               val_f=val_f, val_ins=val_ins, shuffle=shuffle,
-                              callback_metrics=callback_metrics)
+                              callback_metrics=callback_metrics,
+                              initial_epoch=initial_epoch)
 
     def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None):
         '''Returns the loss value and metrics values for the model
@@ -1303,7 +1308,8 @@ def predict_on_batch(self, x):
     def fit_generator(self, generator, samples_per_epoch, nb_epoch,
                       verbose=1, callbacks=[],
                       validation_data=None, nb_val_samples=None,
-                      class_weight={}, max_q_size=10, nb_worker=1, pickle_safe=False):
+                      class_weight={}, max_q_size=10, nb_worker=1, pickle_safe=False,
+                      initial_epoch=0):
         '''Fits the model on data generated batch-by-batch by
         a Python generator.
         The generator is run in parallel to the model, for efficiency.
@@ -1339,6 +1345,8 @@ def fit_generator(self, generator, samples_per_epoch, nb_epoch,
                 this implementation relies on multiprocessing, you should not pass
                 non picklable arguments to the generator as they can't be passed
                 easily to children processes.
+            initial_epoch: epoch at which to start training
+                (useful for resuming a previous training run)
 
         # Returns
             A `History` object.
@@ -1361,7 +1369,7 @@ def generate_arrays_from_file(path):
         ```
         '''
         wait_time = 0.01  # in seconds
-        epoch = 0
+        epoch = initial_epoch
         do_validation = bool(validation_data)
 
         self._make_train_function()
diff --git a/tests/keras/engine/test_training.py b/tests/keras/engine/test_training.py
index 9280b19cf7d8..f529b53ca39d 100644
--- a/tests/keras/engine/test_training.py
+++ b/tests/keras/engine/test_training.py
@@ -8,6 +8,7 @@
 from keras.models import Sequential
 from keras import backend as K
 from keras.utils.test_utils import keras_test
+from keras.callbacks import LambdaCallback
 
 
 @keras_test
@@ -146,6 +147,28 @@ def test_model_methods():
                     [output_a_np, output_b_np])
     assert len(out) == 4
 
+    # test starting from non-zero initial epoch
+    trained_epochs = []
+
+    def on_epoch_begin(epoch, logs):
+        trained_epochs.append(epoch)
+    tracker_cb = LambdaCallback(on_epoch_begin=on_epoch_begin)
+    out = model.fit([input_a_np, input_b_np],
+                    [output_a_np, output_b_np], nb_epoch=5, batch_size=4,
+                    initial_epoch=2, callbacks=[tracker_cb])
+    assert trained_epochs == [2, 3, 4]
+
+    # test starting from non-zero initial epoch for generator too
+    trained_epochs = []
+
+    def gen_data(batch_sz):
+        while True:
+            yield ([np.random.random((batch_sz, 3)), np.random.random((batch_sz, 3))],
+                   [np.random.random((batch_sz, 4)), np.random.random((batch_sz, 3))])
+    out = model.fit_generator(gen_data(4), samples_per_epoch=10, nb_epoch=5,
+                              initial_epoch=2, callbacks=[tracker_cb])
+    assert trained_epochs == [2, 3, 4]
+
     # test with a custom metric function
     mse = lambda y_true, y_pred: K.mean(K.pow(y_true - y_pred, 2))
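For context, here is a minimal sketch of how the new `initial_epoch` argument would be used to resume an interrupted run. The toy model, data shapes, and the `checkpoint.h5` filename are illustrative only, not part of this PR:

```python
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

# Toy model and data, purely for illustration.
model = Sequential([Dense(1, input_dim=3)])
model.compile(optimizer='sgd', loss='mse')
x = np.random.random((32, 3))
y = np.random.random((32, 1))

# First session: train epochs 0-4, then checkpoint the weights.
model.fit(x, y, nb_epoch=5)
model.save_weights('checkpoint.h5')

# Later session: reload and resume. Because _fit_loop now iterates over
# range(initial_epoch, nb_epoch), only epochs 5-9 run, and epoch-aware
# callbacks (e.g. LearningRateScheduler, ModelCheckpoint) see the true
# epoch indices rather than restarting from 0.
model.load_weights('checkpoint.h5')
model.fit(x, y, nb_epoch=10, initial_epoch=5)
```

The same argument works with `fit_generator`, whose internal `epoch` counter now starts at `initial_epoch` instead of 0, as the tests above verify.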