Skip to content

Commit

Permalink
Fix indexing bug in AdamW.step(): iterate over original_params and assign the gradient to self.master_weights[i] instead of incorrectly indexing into the individual master_param element
Browse files Browse the repository at this point in the history
  • Loading branch information
tocean committed Jan 9, 2024
1 parent bcfd7e5 commit 103b7c7
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions msamp/optim/adamw.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,12 +302,11 @@ def zero_grad(self, set_to_none=False):
def step(self):
"""Performs a single optimization step."""
# Set gradient of master weight.
for i, master_param in enumerate(self.master_weights):
if master_param is not None:
param = self.original_params[i]
for i, param in enumerate(self.original_params):
if self.master_weights[i] is not None:
grad_meta = param._grad_meta
dtype = Dtypes.qtype_to_dtype[grad_meta.qtype]
master_param[i].grad = ScalingTensor(param.grad.view(dtype), grad_meta)
self.master_weights[i].grad = ScalingTensor(param.grad.view(dtype), grad_meta)
param.grad = None

# call step() to update master weight
Expand Down

0 comments on commit 103b7c7

Please sign in to comment.