diff --git a/setup.py b/setup.py
index 0981222..40aea42 100644
--- a/setup.py
+++ b/setup.py
@@ -7,7 +7,7 @@
 if __name__ == "__main__":
     setup(
         name="tez",
-        version="0.1.7",
+        version="0.1.8",
         description="tez - a simple pytorch trainer",
         long_description=long_description,
         long_description_content_type="text/markdown",
diff --git a/tez/model/model.py b/tez/model/model.py
index 474cc85..8708aaa 100644
--- a/tez/model/model.py
+++ b/tez/model/model.py
@@ -158,28 +158,28 @@ def model_fn(self, data):
 
     def train_one_step(self, data):
         if self.accumulation_steps == 1 and self.batch_index == 0:
-            self.optimizer.zero_grad()
+            self.zero_grad()
         _, loss, metrics = self.model_fn(data)
+        loss = loss / self.accumulation_steps
+        if self.fp16:
+            self.scaler.scale(loss).backward()
+        else:
+            loss.backward()
         if (self.batch_index + 1) % self.accumulation_steps == 0:
-            with torch.set_grad_enabled(True):
-                if self.fp16:
-                    with torch.cuda.amp.autocast():
-                        self.scaler.scale(loss).backward()
-                        self.scaler.step(self.optimizer)
-                        self.scaler.update()
-                else:
-                    loss.backward()
-                    self.optimizer.step()
-                if self.scheduler:
-                    if self.step_scheduler_after == "batch":
-                        if self.step_scheduler_metric is None:
-                            self.scheduler.step()
-                        else:
-                            step_metric = self.name_to_metric(self.step_scheduler_metric)
-                            self.scheduler.step(step_metric)
-
+            if self.fp16:
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
+            else:
+                self.optimizer.step()
+            if self.scheduler:
+                if self.step_scheduler_after == "batch":
+                    if self.step_scheduler_metric is None:
+                        self.scheduler.step()
+                    else:
+                        step_metric = self.name_to_metric(self.step_scheduler_metric)
+                        self.scheduler.step(step_metric)
             if self.batch_index > 0:
-                self.optimizer.zero_grad()
+                self.zero_grad()
         return loss, metrics
 
     def validate_one_step(self, data):
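
Note on the train_one_step change: previously, backward() only ran on batches where (batch_index + 1) % accumulation_steps == 0, so with accumulation_steps > 1 the gradients of the intervening batches were never accumulated; in fp16 mode, backward(), scaler.step(), and scaler.update() also all ran inside torch.cuda.amp.autocast(), whereas PyTorch's AMP documentation recommends autocasting only the forward pass. The rewritten method follows the standard PyTorch gradient-accumulation recipe: divide the loss by accumulation_steps, call backward() on every batch, and step the optimizer/scheduler and zero the gradients once per accumulation window. Below is a minimal, self-contained sketch of that recipe; the name train_epoch and the model/optimizer/loader arguments are illustrative stand-ins, not tez's actual API.

import torch
import torch.nn as nn

def train_epoch(model, optimizer, loader, accumulation_steps=4, fp16=True):
    # Illustrative sketch of the recipe the new train_one_step follows;
    # model/optimizer/loader are hypothetical stand-ins, not tez attributes.
    scaler = torch.cuda.amp.GradScaler(enabled=fp16)
    optimizer.zero_grad()
    for batch_index, (inputs, targets) in enumerate(loader):
        # Autocast wraps only the forward pass and loss computation.
        with torch.cuda.amp.autocast(enabled=fp16):
            outputs = model(inputs)
            loss = nn.functional.cross_entropy(outputs, targets)
        # Average over the accumulation window instead of summing.
        loss = loss / accumulation_steps
        # backward() runs on every batch, so every batch contributes
        # gradients; with enabled=False, scale() is a pass-through.
        scaler.scale(loss).backward()
        if (batch_index + 1) % accumulation_steps == 0:
            # Step and reset only once per accumulation window.
            scaler.step(optimizer)  # unscales grads, skips step on inf/NaN
            scaler.update()
            optimizer.zero_grad()

With enabled=False, the GradScaler calls become pass-throughs (scale() returns the loss unchanged and step() simply calls optimizer.step()), so a single code path can serve both fp16 and fp32 training; the diff instead branches explicitly on self.fp16, which amounts to the same thing.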