diff --git a/model.py b/model.py
index c698f8b601..5672607a8c 100644
--- a/model.py
+++ b/model.py
@@ -279,7 +279,7 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
         print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
         # Create AdamW optimizer and use the fused version if it is available
         fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
-        use_fused = fused_available and device_type == 'cuda'
+        use_fused = fused_available and str(device_type) == 'cuda'
         extra_args = dict(fused=True) if use_fused else dict()
         optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
         print(f"using fused AdamW: {use_fused}")
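
For context, a minimal sketch of why wrapping device_type in str() can matter, assuming (the diff itself does not say) that callers may pass either a plain string or a torch.device object:

    import torch

    # A plain string behaves as before.
    device_type = 'cuda'
    print(str(device_type) == 'cuda')        # True

    # If a torch.device is passed instead, str() normalizes it back to its
    # string form, so the check that gates the fused AdamW still fires.
    device_type = torch.device('cuda')
    print(str(device_type) == 'cuda')        # True

    # Caveat: a device constructed with an explicit index stringifies with
    # that index, so it would still not match the bare 'cuda' literal here.
    print(str(torch.device('cuda:0')))       # 'cuda:0'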