diff --git a/sparse_autoencoder/optimizer/adam_with_reset.py b/sparse_autoencoder/optimizer/adam_with_reset.py
index 5a84c941..ae906efd 100644
--- a/sparse_autoencoder/optimizer/adam_with_reset.py
+++ b/sparse_autoencoder/optimizer/adam_with_reset.py
@@ -8,7 +8,7 @@
 from torch import Tensor
 from torch.nn.parameter import Parameter
 from torch.optim import Adam
-from torch.optim.optimizer import params_t
+from torch.optim.optimizer import ParamsT

 from sparse_autoencoder.tensor_types import Axis

@@ -35,7 +35,7 @@ class AdamWithReset(Adam):

     def __init__(  # (extending existing implementation)
         self,
-        params: params_t,
+        params: ParamsT,
         lr: float | Float[Tensor, Axis.names(Axis.SINGLE_ITEM)] = 1e-3,
         betas: tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
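
Note: `torch.optim.optimizer` renamed the `params_t` type alias to `ParamsT` in newer PyTorch releases, which is what this diff tracks. If the package also needs to keep working on older PyTorch versions that only expose the old name, a guarded import along these lines could be used instead of the hard switch above (a minimal sketch, not part of this diff):

```python
# Compatibility sketch (not in this diff): prefer the new `ParamsT` name and
# fall back to the old `params_t` alias on PyTorch versions that predate it.
try:
    from torch.optim.optimizer import ParamsT
except ImportError:  # older torch releases only expose `params_t`
    from torch.optim.optimizer import params_t as ParamsT
```

Either way, the annotation on `params` stays the same (`params: ParamsT`), so the rest of `AdamWithReset.__init__` is unaffected.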