Skip to content

Commit

Permalink
fix grad acc
Browse files (browse the repository at this point in the history)
  • Loading branch information
abhishekkrthakur committed Aug 16, 2021
1 parent c806e6e commit f5f5609
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
if __name__ == "__main__":
setup(
name="tez",
version="0.1.7",
version="0.1.8",
description="tez - a simple pytorch trainer",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
38 changes: 19 additions & 19 deletions tez/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,28 +158,28 @@ def model_fn(self, data):

def train_one_step(self, data):
if self.accumulation_steps == 1 and self.batch_index == 0:
self.optimizer.zero_grad()
self.zero_grad()
_, loss, metrics = self.model_fn(data)
loss = loss / self.accumulation_steps
if self.fp16:
self.scaler.scale(loss).backward()
else:
loss.backward()
if (self.batch_index + 1) % self.accumulation_steps == 0:
with torch.set_grad_enabled(True):
if self.fp16:
with torch.cuda.amp.autocast():
self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()
else:
loss.backward()
self.optimizer.step()
if self.scheduler:
if self.step_scheduler_after == "batch":
if self.step_scheduler_metric is None:
self.scheduler.step()
else:
step_metric = self.name_to_metric(self.step_scheduler_metric)
self.scheduler.step(step_metric)

if self.fp16:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.optimizer.step()
if self.scheduler:
if self.step_scheduler_after == "batch":
if self.step_scheduler_metric is None:
self.scheduler.step()
else:
step_metric = self.name_to_metric(self.step_scheduler_metric)
self.scheduler.step(step_metric)
if self.batch_index > 0:
self.optimizer.zero_grad()
self.zero_grad()
return loss, metrics

def validate_one_step(self, data):
Expand Down

0 comments on commit f5f5609

Please sign in to comment.