diff --git a/gearai/ga/tf/keras/__init__.py b/gearai/ga/tf/keras/__init__.py index 7de70f3..c1595d0 100644 --- a/gearai/ga/tf/keras/__init__.py +++ b/gearai/ga/tf/keras/__init__.py @@ -38,10 +38,10 @@ def train_step(self,data): # ACCUMULATING THE BATCH GRADIENTS for i in range(len(self.grad_acc)): - self.gradient_acc[i].assign_add(gradients[i]) + self.grad_acc[i].assign_add(gradients[i]) - tf.cond(tf.equal(self.n_acum_step, self.n_grads), self.apply_accu_gradients, lambda: None) + tf.cond(tf.equal(self.n_acum_step, self.no_grads), self.apply_accu_gradients, lambda: None) # UPDATING THE METRICS self.compiled_metrics.update_state(y, y_pred)