Skip to content

Commit

Permalink
Merge pull request #55 from yschneiderTEKLIA/bump-pytorch-lightning
Browse files Browse the repository at this point in the history
Bump pytorch lightning to 1.1.7
  • Loading branch information
starride-teklia authored Oct 11, 2023
2 parents 3d554b9 + 8d27f7c commit 6c441df
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 7 deletions.
4 changes: 2 additions & 2 deletions benchmarks/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def setup(train_path, fixed_input_height=0):

syms = str(train_path / "syms")
syms_table = SymbolsTable()
for k, v in data_module.syms.items():
syms_table.add(v, k)
for k, v in data_module.syms:
syms_table.add(k, v)
syms_table.save(syms)

model(
Expand Down
7 changes: 4 additions & 3 deletions laia/callbacks/learning_rate.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.utilities import rank_zero_only

import laia.common.logging as log

_logger = log.get_logger(__name__)


class LearningRate(pl.callbacks.LearningRateMonitor):
class LearningRate(LearningRateMonitor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.last_values = None
Expand All @@ -23,8 +24,8 @@ def on_train_start(self, trainer, *args, **kwargs):
self.last_values = {}

@rank_zero_only
def on_epoch_end(self, trainer, *args, **kwargs):
super().on_epoch_end(trainer, *args, **kwargs)
def on_train_epoch_end(self, trainer, *args, **kwargs):
super().on_train_epoch_end(trainer, *args, **kwargs)
for k, v in self.lrs.items():
prev_value = self.last_values.get(k, None)
new_value = v[-1]
Expand Down
2 changes: 1 addition & 1 deletion laia/dummies/dummy_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ class DummyTrainer(pl.Trainer):
def __init__(self, **kwargs):
defaults = {
"checkpoint_callback": False,
"logger": False,
"logger": True,
"weights_summary": None,
"max_epochs": 1,
"limit_train_batches": 10,
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ scipy
matplotlib
# cpu version: nnutils-pytorch
nnutils-pytorch-cuda
pytorch-lightning==1.1.0
pytorch-lightning==1.1.7
torch>=1.13,<1.14
torchvision>=0.14,<0.15
torchaudio>=0.13,<0.14
Expand Down
1 change: 1 addition & 0 deletions tests/callbacks/progress_bar_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def test_progress_bar(tmpdir):
r"100%\|[█]+\| 10/10 \[00:0[0-9]<00:00, "
rf"{float_pattern}it/s, "
rf"loss={float_pattern}, "
r"v_num=0, "
rf"cer={float_pattern}%, "
r"gpu_stats={'gpu_stats': 'baz'}]"
)
Expand Down

0 comments on commit 6c441df

Please sign in to comment.