diff --git a/torch_optimizer/sgdw.py b/torch_optimizer/sgdw.py index a255930..406a76f 100644 --- a/torch_optimizer/sgdw.py +++ b/torch_optimizer/sgdw.py @@ -113,10 +113,10 @@ def step(self, closure: OptLossClosure = None) -> OptFloat: else: d_p = buf - # Apply momentum - p.data.add_(d_p, alpha=-group['lr']) - # Apply weight decay if weight_decay != 0: - p.data.add_(weight_decay, alpha=-group['lr']) + p.data.mul_(1 - group['lr'] * weight_decay) + + # Apply momentum + p.data.add_(d_p, alpha=-group['lr']) return loss