
Commit 06cc6d7

kencoken authored and fchollet committed
Add initial epoch argument to fit functions (keras-team#4429)
* Added initial_epoch argument to fit functions in trainer
* Added unit test
* PEP8 fixes
1 parent 97484ec commit 06cc6d7

File tree

2 files changed: +37 -6 lines changed

keras/engine/training.py

Lines changed: 14 additions & 6 deletions
@@ -760,7 +760,7 @@ def _make_predict_function(self):
     def _fit_loop(self, f, ins, out_labels=[], batch_size=32,
                   nb_epoch=100, verbose=1, callbacks=[],
                   val_f=None, val_ins=None, shuffle=True,
-                  callback_metrics=[]):
+                  callback_metrics=[], initial_epoch=0):
         '''Abstract fit function for f(ins).
         Assume that f returns a list, labeled by out_labels.

@@ -780,6 +780,8 @@ def _fit_loop(self, f, ins, out_labels=[], batch_size=32,
                 passed to the callbacks. They should be the
                 concatenation of list the display names of the outputs of
                 `f` and the list of display names of the outputs of `f_val`.
+            initial_epoch: epoch at which to start training
+                (useful for resuming a previous training run)

         # Returns
             `History` object.

@@ -820,7 +822,7 @@ def _fit_loop(self, f, ins, out_labels=[], batch_size=32,
         callback_model.stop_training = False
         self.validation_data = val_ins

-        for epoch in range(nb_epoch):
+        for epoch in range(initial_epoch, nb_epoch):
             callbacks.on_epoch_begin(epoch)
             if shuffle == 'batch':
                 index_array = batch_shuffle(index_array, batch_size)

@@ -1007,7 +1009,7 @@ def _standardize_user_data(self, x, y,

     def fit(self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[],
             validation_split=0., validation_data=None, shuffle=True,
-            class_weight=None, sample_weight=None):
+            class_weight=None, sample_weight=None, initial_epoch=0):
         '''Trains the model for a fixed number of epochs (iterations on a dataset).

         # Arguments

@@ -1044,6 +1046,8 @@ def fit(self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[],
                 with shape (samples, sequence_length),
                 to apply a different weight to every timestep of every sample.
                 In this case you should make sure to specify sample_weight_mode="temporal" in compile().
+            initial_epoch: epoch at which to start training
+                (useful for resuming a previous training run)


         # Returns

@@ -1127,7 +1131,8 @@ def fit(self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[],
                              batch_size=batch_size, nb_epoch=nb_epoch,
                              verbose=verbose, callbacks=callbacks,
                              val_f=val_f, val_ins=val_ins, shuffle=shuffle,
-                             callback_metrics=callback_metrics)
+                             callback_metrics=callback_metrics,
+                             initial_epoch=initial_epoch)

     def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None):
         '''Returns the loss value and metrics values for the model

@@ -1303,7 +1308,8 @@ def predict_on_batch(self, x):
     def fit_generator(self, generator, samples_per_epoch, nb_epoch,
                       verbose=1, callbacks=[],
                       validation_data=None, nb_val_samples=None,
-                      class_weight={}, max_q_size=10, nb_worker=1, pickle_safe=False):
+                      class_weight={}, max_q_size=10, nb_worker=1, pickle_safe=False,
+                      initial_epoch=0):
         '''Fits the model on data generated batch-by-batch by
         a Python generator.
         The generator is run in parallel to the model, for efficiency.

@@ -1339,6 +1345,8 @@ def fit_generator(self, generator, samples_per_epoch, nb_epoch,
                 this implementation relies on multiprocessing, you should not pass
                 non picklable arguments to the generator as they can't be passed
                 easily to children processes.
+            initial_epoch: epoch at which to start training
+                (useful for resuming a previous training run)

         # Returns
             A `History` object.

@@ -1361,7 +1369,7 @@ def generate_arrays_from_file(path):
         ```
         '''
         wait_time = 0.01  # in seconds
-        epoch = 0
+        epoch = initial_epoch

         do_validation = bool(validation_data)
         self._make_train_function()

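As a rough usage sketch (not part of this commit; the model, layer sizes, and data below are made up for illustration), the new `initial_epoch` argument lets a second `fit` call pick up the epoch counter where an earlier run stopped. Note that `nb_epoch` remains the end epoch, not a count of additional epochs:

```python
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

# Hypothetical example, not part of this commit: a tiny model and random data.
model = Sequential([Dense(4, input_dim=3)])
model.compile(optimizer='sgd', loss='mse')

x_train = np.random.random((20, 3))
y_train = np.random.random((20, 4))

# First run covers epochs 0-4.
model.fit(x_train, y_train, batch_size=4, nb_epoch=5)

# Resuming with initial_epoch=5 runs epochs 5-9, so callbacks and logs
# continue from epoch 5 instead of restarting at 0.
model.fit(x_train, y_train, batch_size=4, nb_epoch=10, initial_epoch=5)
```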
tests/keras/engine/test_training.py

Lines changed: 23 additions & 0 deletions
@@ -8,6 +8,7 @@
 from keras.models import Sequential
 from keras import backend as K
 from keras.utils.test_utils import keras_test
+from keras.callbacks import LambdaCallback


 @keras_test

@@ -146,6 +147,28 @@ def test_model_methods():
                               [output_a_np, output_b_np])
     assert len(out) == 4

+    # test starting from non-zero initial epoch
+    trained_epochs = []
+
+    def on_epoch_begin(epoch, logs):
+        trained_epochs.append(epoch)
+    tracker_cb = LambdaCallback(on_epoch_begin=on_epoch_begin)
+    out = model.fit([input_a_np, input_b_np],
+                    [output_a_np, output_b_np], nb_epoch=5, batch_size=4,
+                    initial_epoch=2, callbacks=[tracker_cb])
+    assert trained_epochs == [2, 3, 4]
+
+    # test starting from non-zero initial epoch for generator too
+    trained_epochs = []
+
+    def gen_data(batch_sz):
+        while True:
+            yield ([np.random.random((batch_sz, 3)), np.random.random((batch_sz, 3))],
+                   [np.random.random((batch_sz, 4)), np.random.random((batch_sz, 3))])
+    out = model.fit_generator(gen_data(4), samples_per_epoch=10, nb_epoch=5,
+                              initial_epoch=2, callbacks=[tracker_cb])
+    assert trained_epochs == [2, 3, 4]
+
     # test with a custom metric function
     mse = lambda y_true, y_pred: K.mean(K.pow(y_true - y_pred, 2))

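The "resuming a previous training run" case called out in the docstrings would look roughly like the sketch below. Assumptions not taken from this commit: `model` and `train_gen` stand for a compiled model and a data generator defined elsewhere, and `'tmp_weights.h5'` is a made-up file name; `save_weights`/`load_weights` are existing Keras calls, but the overall workflow is only illustrative.

```python
# Hypothetical resume workflow (illustrative only): 'model' and 'train_gen'
# are assumed to be a compiled model and a data generator defined elsewhere.
model.fit_generator(train_gen, samples_per_epoch=1000, nb_epoch=3)
model.save_weights('tmp_weights.h5')  # stop after epochs 0-2

# Later, after rebuilding and recompiling the same architecture:
model.load_weights('tmp_weights.h5')
model.fit_generator(train_gen, samples_per_epoch=1000, nb_epoch=10,
                    initial_epoch=3)  # picks up at epoch 3, runs through epoch 9
```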