diff --git a/src/matgl/utils/training.py b/src/matgl/utils/training.py index 17953e12..27e15283 100644 --- a/src/matgl/utils/training.py +++ b/src/matgl/utils/training.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: import dgl import numpy as np - from torch.optim import Optimizer + from torch.optim import LRScheduler, Optimizer class MatglLightningModuleMixin: @@ -147,7 +147,7 @@ def __init__( data_std: float = 1.0, loss: str = "mse_loss", optimizer: Optimizer | None = None, - scheduler=None, + scheduler: LRScheduler | None = None, lr: float = 0.001, decay_steps: int = 1000, decay_alpha: float = 0.01, @@ -270,7 +270,7 @@ def __init__( loss: str = "mse_loss", loss_params: dict | None = None, optimizer: Optimizer | None = None, - scheduler=None, + scheduler: LRScheduler | None = None, lr: float = 0.001, decay_steps: int = 1000, decay_alpha: float = 0.01,