diff --git a/training/src/eval.py b/training/src/eval.py
index 161a23c89..d4fbca82e 100644
--- a/training/src/eval.py
+++ b/training/src/eval.py
@@ -12,7 +12,7 @@
     Trainer,
     seed_everything,
 )
-from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning.loggers import Logger
 
 from src.utils import utils
 
@@ -97,7 +97,7 @@ def evaluate(config: DictConfig) -> None:
                 callbacks.append(hydra.utils.instantiate(cb_conf))
 
     # Init Lightning loggers
-    logger: List[LightningLoggerBase] = []
+    logger: List[Logger] = []
     if "logger" in config:
         for _, lg_conf in config["logger"].items():
             if lg_conf is not None and "_target_" in lg_conf:
diff --git a/training/src/train.py b/training/src/train.py
index 8c92413e4..424c38b35 100644
--- a/training/src/train.py
+++ b/training/src/train.py
@@ -10,7 +10,7 @@
     Trainer,
     seed_everything,
 )
-from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning.loggers import Logger
 
 from src.utils import utils
 
@@ -59,7 +59,7 @@ def train(config: DictConfig) -> Optional[float]:
                 callbacks.append(hydra.utils.instantiate(cb_conf))
 
     # Init lightning loggers
-    logger: List[LightningLoggerBase] = []
+    logger: List[Logger] = []
     if "logger" in config:
         for _, lg_conf in config.logger.items():
             if lg_conf is not None and "_target_" in lg_conf:
diff --git a/training/src/utils/utils.py b/training/src/utils/utils.py
index 32e64ab4d..96c822576 100644
--- a/training/src/utils/utils.py
+++ b/training/src/utils/utils.py
@@ -134,7 +134,7 @@ def finish(
     datamodule: pl.LightningDataModule,
     trainer: pl.Trainer,
     callbacks: List[pl.Callback],
-    logger: List[pl.loggers.LightningLoggerBase],
+    logger: List[pl.loggers.Logger],
 ) -> None:
     """Makes sure everything closed properly."""
 