diff --git a/n3fit/src/n3fit/backends/keras_backend/MetaModel.py b/n3fit/src/n3fit/backends/keras_backend/MetaModel.py index f82ee5ca2c..27a8779e39 100644 --- a/n3fit/src/n3fit/backends/keras_backend/MetaModel.py +++ b/n3fit/src/n3fit/backends/keras_backend/MetaModel.py @@ -166,7 +166,7 @@ def perform_fit(self, x=None, y=None, epochs=1, **kwargs): y = self.target_tensors # Avoids Tensorflow overhead that happens at every epoch, by putting multiple steps in an epoch - steps_per_epoch = self.determine_steps_per_epoch(epochs) + steps_per_epoch = self._determine_steps_per_epoch(epochs) for k, v in x_params.items(): x_params[k] = tf.repeat(v, steps_per_epoch, axis=0) @@ -178,20 +178,19 @@ def perform_fit(self, x=None, y=None, epochs=1, **kwargs): loss_dict = history.history return loss_dict - def determine_steps_per_epoch(self, epochs): - num_replicas = self.output_shape[0][0] - # in this case we're most likely running on the CPU and this is not worth it - if num_replicas == 1: + def _determine_steps_per_epoch(self, epochs): + """Determine how many steps to run in every epoch. + When running a single replica (CPU) or when the number of epochs is < 100 default to 1. + Otherwise run 100 steps per epoch. + + If the number of epochs requested is not divisible by 100 there will be a number + of extra training epochs being run equal to max_epochs % 100 in the worst case. + + """ + num_replicas = self.output_shape[0] + if num_replicas == 1 or epochs < 100: return 1 - # On the GPU, run with - for divisor in [10, 100]: - if epochs % divisor != 0: - steps_per_epoch = divisor // 10 - log.warning( - f"Epochs {epochs} not divisible by {divisor}, using {steps_per_epoch} steps per epoch" - ) - return steps_per_epoch return 100 def predict(self, x=None, **kwargs):