diff --git a/torch_optimizer/lookahead.py b/torch_optimizer/lookahead.py
index 39abf07..53afe3f 100644
--- a/torch_optimizer/lookahead.py
+++ b/torch_optimizer/lookahead.py
@@ -57,6 +57,9 @@ def __init__(
 
     def _update(self, group: Dict[str, Any]) -> None:
         for fast in group["params"]:
+            if not fast.requires_grad:
+                continue
+
             param_state = self.state[fast]
             if "slow_param" not in param_state:
                 param_state["slow_param"] = torch.clone(fast.data).detach()
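
For context, a minimal usage sketch of the case this guard appears to target: a parameter frozen with `requires_grad = False` that still sits in the wrapped optimizer's param groups. The model, base optimizer, and hyperparameters below are illustrative, not part of this patch.

```python
import torch
import torch_optimizer

# Illustrative model with one frozen parameter that stays in the param group.
model = torch.nn.Linear(4, 2)
model.bias.requires_grad = False

base = torch.optim.SGD(model.parameters(), lr=0.1)
opt = torch_optimizer.Lookahead(base, k=5, alpha=0.5)

for _ in range(5):  # k steps, so the slow-weight sync in _update() runs at least once
    model.zero_grad()
    model(torch.randn(8, 4)).sum().backward()
    opt.step()

# With the guard above, _update() skips model.bias: no slow_param state is
# created for it and the frozen tensor is left untouched.
```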