From ec2beb9da35348cf3b7c11f710fbd166b458e9cc Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Mon, 16 Oct 2023 09:43:21 +0200 Subject: [PATCH] Move from `pytorch_lightning` to `lightning` (#3013) --- .../advanced_topics/howto_pytorch_lightning.md.template | 2 +- requirements/requirements-pytorch.txt | 4 +++- src/gluonts/torch/model/d_linear/estimator.py | 2 +- src/gluonts/torch/model/d_linear/lightning_module.py | 2 +- src/gluonts/torch/model/deepar/lightning_module.py | 2 +- src/gluonts/torch/model/estimator.py | 4 ++-- src/gluonts/torch/model/lag_tst/estimator.py | 2 +- src/gluonts/torch/model/lag_tst/lightning_module.py | 2 +- src/gluonts/torch/model/lightning_util.py | 2 +- src/gluonts/torch/model/mqf2/lightning_module.py | 2 +- src/gluonts/torch/model/patch_tst/estimator.py | 2 +- src/gluonts/torch/model/patch_tst/lightning_module.py | 2 +- src/gluonts/torch/model/simple_feedforward/estimator.py | 2 +- .../torch/model/simple_feedforward/lightning_module.py | 2 +- src/gluonts/torch/model/tft/lightning_module.py | 2 +- src/gluonts/torch/model/wavenet/estimator.py | 2 +- src/gluonts/torch/model/wavenet/lightning_module.py | 2 +- 17 files changed, 20 insertions(+), 18 deletions(-) diff --git a/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template b/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template index 693f070dcf..34e6937b19 100644 --- a/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template +++ b/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template @@ -133,7 +133,7 @@ To train the model using PyTorch Lightning, we only need to extend the class wit ```python -import pytorch_lightning as pl +import lightning.pytorch as pl ``` diff --git a/requirements/requirements-pytorch.txt b/requirements/requirements-pytorch.txt index 16a40f64a3..03f4e997ab 100644 --- a/requirements/requirements-pytorch.txt +++ b/requirements/requirements-pytorch.txt @@ -1,5 +1,7 @@ torch>=1.9,<3 -pytorch-lightning>=1.5,<3 +lightning>=1.8,<2.2 +# Capping `lightning` does not cap `pytorch_lightning`, so we cap manually +pytorch_lightning>=1.8,<2.2 # Need to pin protobuf (for now) # See: https://github.com/PyTorchLightning/pytorch-lightning/issues/13159 protobuf~=3.19.0 diff --git a/src/gluonts/torch/model/d_linear/estimator.py b/src/gluonts/torch/model/d_linear/estimator.py index f62429d162..f8dc0453c1 100644 --- a/src/gluonts/torch/model/d_linear/estimator.py +++ b/src/gluonts/torch/model/d_linear/estimator.py @@ -14,7 +14,7 @@ from typing import Optional, Iterable, Dict, Any import torch -import pytorch_lightning as pl +import lightning.pytorch as pl from gluonts.core.component import validated from gluonts.dataset.common import Dataset diff --git a/src/gluonts/torch/model/d_linear/lightning_module.py b/src/gluonts/torch/model/d_linear/lightning_module.py index 28dccf1b97..bd081b45dd 100644 --- a/src/gluonts/torch/model/d_linear/lightning_module.py +++ b/src/gluonts/torch/model/d_linear/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated diff --git a/src/gluonts/torch/model/deepar/lightning_module.py b/src/gluonts/torch/model/deepar/lightning_module.py index fc676dfab3..8d190e2329 100644 --- a/src/gluonts/torch/model/deepar/lightning_module.py +++ b/src/gluonts/torch/model/deepar/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from torch.optim.lr_scheduler import ReduceLROnPlateau diff --git a/src/gluonts/torch/model/estimator.py b/src/gluonts/torch/model/estimator.py index 003d88f7fa..ba91f0d725 100644 --- a/src/gluonts/torch/model/estimator.py +++ b/src/gluonts/torch/model/estimator.py @@ -15,7 +15,7 @@ import logging import numpy as np -import pytorch_lightning as pl +import lightning.pytorch as pl import torch.nn as nn from gluonts.core.component import validated @@ -217,7 +217,7 @@ def train_model( logger.info( f"Loading best model from {checkpoint.best_model_path}" ) - best_model = training_network.load_from_checkpoint( + best_model = training_network.__class__.load_from_checkpoint( checkpoint.best_model_path ) else: diff --git a/src/gluonts/torch/model/lag_tst/estimator.py b/src/gluonts/torch/model/lag_tst/estimator.py index 330a1ff4b9..96f2b2a603 100644 --- a/src/gluonts/torch/model/lag_tst/estimator.py +++ b/src/gluonts/torch/model/lag_tst/estimator.py @@ -14,7 +14,7 @@ from typing import Optional, Iterable, Dict, Any, List import torch -import pytorch_lightning as pl +import lightning.pytorch as pl from gluonts.core.component import validated from gluonts.dataset.common import Dataset diff --git a/src/gluonts/torch/model/lag_tst/lightning_module.py b/src/gluonts/torch/model/lag_tst/lightning_module.py index 2510944cfa..5c9e70e9e4 100644 --- a/src/gluonts/torch/model/lag_tst/lightning_module.py +++ b/src/gluonts/torch/model/lag_tst/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated diff --git a/src/gluonts/torch/model/lightning_util.py b/src/gluonts/torch/model/lightning_util.py index 6742c8c7cf..73e2396140 100644 --- a/src/gluonts/torch/model/lightning_util.py +++ b/src/gluonts/torch/model/lightning_util.py @@ -13,7 +13,7 @@ from packaging import version -import pytorch_lightning as pl +import lightning.pytorch as pl def has_validation_loop(trainer: pl.Trainer): diff --git a/src/gluonts/torch/model/mqf2/lightning_module.py b/src/gluonts/torch/model/mqf2/lightning_module.py index 6dc824beb4..16916c3c41 100644 --- a/src/gluonts/torch/model/mqf2/lightning_module.py +++ b/src/gluonts/torch/model/mqf2/lightning_module.py @@ -13,7 +13,7 @@ from typing import Dict -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from torch.optim.lr_scheduler import ReduceLROnPlateau diff --git a/src/gluonts/torch/model/patch_tst/estimator.py b/src/gluonts/torch/model/patch_tst/estimator.py index de7b880f36..34dfa5dacb 100644 --- a/src/gluonts/torch/model/patch_tst/estimator.py +++ b/src/gluonts/torch/model/patch_tst/estimator.py @@ -14,7 +14,7 @@ from typing import Optional, Iterable, Dict, Any import torch -import pytorch_lightning as pl +import lightning.pytorch as pl from gluonts.core.component import validated from gluonts.dataset.common import Dataset diff --git a/src/gluonts/torch/model/patch_tst/lightning_module.py b/src/gluonts/torch/model/patch_tst/lightning_module.py index f5e95158b2..d80137ae05 100644 --- a/src/gluonts/torch/model/patch_tst/lightning_module.py +++ b/src/gluonts/torch/model/patch_tst/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated diff --git a/src/gluonts/torch/model/simple_feedforward/estimator.py b/src/gluonts/torch/model/simple_feedforward/estimator.py index e43956d1ad..a909d4ee59 100644 --- a/src/gluonts/torch/model/simple_feedforward/estimator.py +++ b/src/gluonts/torch/model/simple_feedforward/estimator.py @@ -14,7 +14,7 @@ from typing import List, Optional, Iterable, Dict, Any import torch -import pytorch_lightning as pl +import lightning.pytorch as pl from gluonts.core.component import validated from gluonts.dataset.common import Dataset diff --git a/src/gluonts/torch/model/simple_feedforward/lightning_module.py b/src/gluonts/torch/model/simple_feedforward/lightning_module.py index b7cf9a529a..f03473e78d 100644 --- a/src/gluonts/torch/model/simple_feedforward/lightning_module.py +++ b/src/gluonts/torch/model/simple_feedforward/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated diff --git a/src/gluonts/torch/model/tft/lightning_module.py b/src/gluonts/torch/model/tft/lightning_module.py index f6f7daa335..4647d740fd 100644 --- a/src/gluonts/torch/model/tft/lightning_module.py +++ b/src/gluonts/torch/model/tft/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated from gluonts.itertools import select diff --git a/src/gluonts/torch/model/wavenet/estimator.py b/src/gluonts/torch/model/wavenet/estimator.py index ab9fc9db54..e7fea4b0d7 100644 --- a/src/gluonts/torch/model/wavenet/estimator.py +++ b/src/gluonts/torch/model/wavenet/estimator.py @@ -13,7 +13,7 @@ from typing import Any, Dict, List, Optional, Iterable -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import numpy as np diff --git a/src/gluonts/torch/model/wavenet/lightning_module.py b/src/gluonts/torch/model/wavenet/lightning_module.py index daf78e451c..85dd0a6671 100644 --- a/src/gluonts/torch/model/wavenet/lightning_module.py +++ b/src/gluonts/torch/model/wavenet/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated