From b1ec68ae05463843bce0cc1271491929f6f18fd0 Mon Sep 17 00:00:00 2001 From: Leiay Date: Mon, 8 Aug 2022 16:19:00 -0500 Subject: [PATCH 1/2] fix sgdw weight decay bug --- torch_optimizer/sgdw.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_optimizer/sgdw.py b/torch_optimizer/sgdw.py index a255930..fb38061 100644 --- a/torch_optimizer/sgdw.py +++ b/torch_optimizer/sgdw.py @@ -118,5 +118,5 @@ def step(self, closure: OptLossClosure = None) -> OptFloat: # Apply weight decay if weight_decay != 0: - p.data.add_(weight_decay, alpha=-group['lr']) + p.data.mul_(1 - group['lr'] * weight_decay) return loss From 45e1711aa606d4f760e4ac1b05e7cceedc236eae Mon Sep 17 00:00:00 2001 From: Leiay Date: Tue, 9 Aug 2022 19:41:55 -0500 Subject: [PATCH 2/2] change sgdw weight decay place --- torch_optimizer/sgdw.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_optimizer/sgdw.py b/torch_optimizer/sgdw.py index fb38061..406a76f 100644 --- a/torch_optimizer/sgdw.py +++ b/torch_optimizer/sgdw.py @@ -113,10 +113,10 @@ def step(self, closure: OptLossClosure = None) -> OptFloat: else: d_p = buf - # Apply momentum - p.data.add_(d_p, alpha=-group['lr']) - # Apply weight decay if weight_decay != 0: p.data.mul_(1 - group['lr'] * weight_decay) + + # Apply momentum + p.data.add_(d_p, alpha=-group['lr']) return loss