From a3758eb356bec3373c1e5bd885aa844a72657ed2 Mon Sep 17 00:00:00 2001
From: Sean Jude Lyons <72541401+seanjudelyons@users.noreply.github.com>
Date: Wed, 30 Oct 2024 12:47:23 +1100
Subject: [PATCH] added fix to type comparison to enable fused AdamW

---
 model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/model.py b/model.py
index c698f8b601..5672607a8c 100644
--- a/model.py
+++ b/model.py
@@ -279,7 +279,7 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
         print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
         # Create AdamW optimizer and use the fused version if it is available
         fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
-        use_fused = fused_available and device_type == 'cuda'
+        use_fused = fused_available and str(device_type) == 'cuda'
         extra_args = dict(fused=True) if use_fused else dict()
         optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
         print(f"using fused AdamW: {use_fused}")
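
Note (not part of the patch): a minimal sketch of why the str() cast can matter, assuming the caller may pass device_type either as a plain string ('cuda') or as a torch.device object; str() normalizes both to a comparable string so the fused-AdamW path is still selected on CUDA. The helper name wants_fused below is hypothetical, for illustration only.

    import torch

    def wants_fused(device_type):
        # Hypothetical helper mirroring the patched check:
        # fused AdamW is only requested when the device normalizes to 'cuda'.
        return str(device_type) == 'cuda'

    print(wants_fused('cuda'))                # True: plain string, as in the original check
    print(wants_fused(torch.device('cuda')))  # True: str(torch.device('cuda')) == 'cuda'
    print(wants_fused('cpu'))                 # False: fused path is skipped on CPU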