@@ -760,7 +760,7 @@ def _make_predict_function(self):
     def _fit_loop(self, f, ins, out_labels=[], batch_size=32,
                   nb_epoch=100, verbose=1, callbacks=[],
                   val_f=None, val_ins=None, shuffle=True,
-                  callback_metrics=[]):
+                  callback_metrics=[], initial_epoch=0):
         '''Abstract fit function for f(ins).
         Assume that f returns a list, labeled by out_labels.

@@ -780,6 +780,8 @@ def _fit_loop(self, f, ins, out_labels=[], batch_size=32,
                 passed to the callbacks. They should be the
                 concatenation of list the display names of the outputs of
                 `f` and the list of display names of the outputs of `f_val`.
+            initial_epoch: epoch at which to start training
+                (useful for resuming a previous training run)

         # Returns
             `History` object.
@@ -820,7 +822,7 @@ def _fit_loop(self, f, ins, out_labels=[], batch_size=32,
         callback_model.stop_training = False
         self.validation_data = val_ins

-        for epoch in range(nb_epoch):
+        for epoch in range(initial_epoch, nb_epoch):
             callbacks.on_epoch_begin(epoch)
             if shuffle == 'batch':
                 index_array = batch_shuffle(index_array, batch_size)
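The only behavioral change in the training loop itself is the lower bound of `range`. Epoch indices stay absolute rather than restarting at zero, so callbacks receive the same epoch numbers they would have seen in an uninterrupted run, and only `nb_epoch - initial_epoch` epochs actually execute. A minimal standalone sketch of that semantics (plain Python, not Keras code):

```python
nb_epoch = 10       # one past the index of the final epoch, as in the loop above
initial_epoch = 7   # epoch to resume from

# Epoch numbering is absolute: this prints "Epoch 8/10" through
# "Epoch 10/10", i.e. only nb_epoch - initial_epoch = 3 epochs run.
for epoch in range(initial_epoch, nb_epoch):
    print('Epoch %d/%d' % (epoch + 1, nb_epoch))
```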
@@ -1007,7 +1009,7 @@ def _standardize_user_data(self, x, y,

     def fit(self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[],
             validation_split=0., validation_data=None, shuffle=True,
-            class_weight=None, sample_weight=None):
+            class_weight=None, sample_weight=None, initial_epoch=0):
         '''Trains the model for a fixed number of epochs (iterations on a dataset).

         # Arguments
@@ -1044,6 +1046,8 @@ def fit(self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[],
                 with shape (samples, sequence_length),
                 to apply a different weight to every timestep of every sample.
                 In this case you should make sure to specify sample_weight_mode="temporal" in compile().
+            initial_epoch: epoch at which to start training
+                (useful for resuming a previous training run)


         # Returns
@@ -1127,7 +1131,8 @@ def fit(self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[],
                               batch_size=batch_size, nb_epoch=nb_epoch,
                               verbose=verbose, callbacks=callbacks,
                               val_f=val_f, val_ins=val_ins, shuffle=shuffle,
-                              callback_metrics=callback_metrics)
+                              callback_metrics=callback_metrics,
+                              initial_epoch=initial_epoch)

     def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None):
         '''Returns the loss value and metrics values for the model
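For the public `fit` API, resuming a run means reloading saved weights and passing the epoch you stopped at. Below is a hedged sketch against the Keras 1.x API this commit targets; the toy data, model, and filename are made up for illustration. Note that `nb_epoch` remains the index one past the final epoch, not a count of additional epochs:

```python
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

# Toy regression data and model, purely for illustration.
x = np.random.random((100, 20))
y = np.random.random((100, 1))
model = Sequential([Dense(1, input_dim=20)])
model.compile(optimizer='sgd', loss='mse')

# First run: epochs 0-4, then persist the weights (requires h5py).
model.fit(x, y, nb_epoch=5)
model.save_weights('run.h5')

# Later session: restore and continue. This executes epochs 5-9;
# callbacks and logs see epoch numbers 5..9, not 0..4 again.
model.load_weights('run.h5')
model.fit(x, y, nb_epoch=10, initial_epoch=5)
```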
@@ -1303,7 +1308,8 @@ def predict_on_batch(self, x):
     def fit_generator(self, generator, samples_per_epoch, nb_epoch,
                       verbose=1, callbacks=[],
                       validation_data=None, nb_val_samples=None,
-                      class_weight={}, max_q_size=10, nb_worker=1, pickle_safe=False):
+                      class_weight={}, max_q_size=10, nb_worker=1, pickle_safe=False,
+                      initial_epoch=0):
         '''Fits the model on data generated batch-by-batch by
         a Python generator.
         The generator is run in parallel to the model, for efficiency.
@@ -1339,6 +1345,8 @@ def fit_generator(self, generator, samples_per_epoch, nb_epoch,
                 this implementation relies on multiprocessing, you should not pass
                 non picklable arguments to the generator as they can't be passed
                 easily to children processes.
+            initial_epoch: epoch at which to start training
+                (useful for resuming a previous training run)

         # Returns
             A `History` object.
@@ -1361,7 +1369,7 @@ def generate_arrays_from_file(path):
         ```
         '''
         wait_time = 0.01  # in seconds
-        epoch = 0
+        epoch = initial_epoch

         do_validation = bool(validation_data)
         self._make_train_function()
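The same flag on `fit_generator` lets a generator-based run pick up where it left off: the epoch counter simply starts at `initial_epoch` instead of 0. A sketch under the same assumptions as the `fit` example above (`model` is the compiled model from that example; the generator is hypothetical):

```python
import numpy as np

def batch_generator(batch_size=32):
    # Endless stream of random (inputs, targets) batches.
    while True:
        yield (np.random.random((batch_size, 20)),
               np.random.random((batch_size, 1)))

# Epoch numbering is absolute here as well: with nb_epoch=20 and
# initial_epoch=12, epochs 12 through 19 are run.
model.fit_generator(batch_generator(),
                    samples_per_epoch=1024,
                    nb_epoch=20,
                    initial_epoch=12)
```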