diff --git a/benchmarks/common.py b/benchmarks/common.py
index 0ea6144c..95bce329 100644
--- a/benchmarks/common.py
+++ b/benchmarks/common.py
@@ -17,8 +17,8 @@ def setup(train_path, fixed_input_height=0):
 
     syms = str(train_path / "syms")
     syms_table = SymbolsTable()
-    for k, v in data_module.syms.items():
-        syms_table.add(v, k)
+    for k, v in data_module.syms:
+        syms_table.add(k, v)
     syms_table.save(syms)
 
     model(
diff --git a/laia/callbacks/learning_rate.py b/laia/callbacks/learning_rate.py
index c103ab90..47a33200 100644
--- a/laia/callbacks/learning_rate.py
+++ b/laia/callbacks/learning_rate.py
@@ -1,4 +1,5 @@
 import pytorch_lightning as pl
+from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 from pytorch_lightning.utilities import rank_zero_only
 
 import laia.common.logging as log
@@ -6,7 +7,7 @@
 _logger = log.get_logger(__name__)
 
 
-class LearningRate(pl.callbacks.LearningRateMonitor):
+class LearningRate(LearningRateMonitor):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.last_values = None
@@ -23,8 +24,8 @@ def on_train_start(self, trainer, *args, **kwargs):
         self.last_values = {}
 
     @rank_zero_only
-    def on_epoch_end(self, trainer, *args, **kwargs):
-        super().on_epoch_end(trainer, *args, **kwargs)
+    def on_train_epoch_end(self, trainer, *args, **kwargs):
+        super().on_train_epoch_end(trainer, *args, **kwargs)
         for k, v in self.lrs.items():
             prev_value = self.last_values.get(k, None)
             new_value = v[-1]
diff --git a/laia/dummies/dummy_trainer.py b/laia/dummies/dummy_trainer.py
index c2034d49..0e612152 100644
--- a/laia/dummies/dummy_trainer.py
+++ b/laia/dummies/dummy_trainer.py
@@ -5,7 +5,7 @@ class DummyTrainer(pl.Trainer):
     def __init__(self, **kwargs):
         defaults = {
             "checkpoint_callback": False,
-            "logger": False,
+            "logger": True,
             "weights_summary": None,
             "max_epochs": 1,
             "limit_train_batches": 10,
diff --git a/requirements.txt b/requirements.txt
index ce95a4be..989c57d1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,7 @@ scipy
 matplotlib
 # cpu version: nnutils-pytorch
 nnutils-pytorch-cuda
-pytorch-lightning==1.1.0
+pytorch-lightning==1.1.7
 torch>=1.13,<1.14
 torchvision>=0.14,<0.15
 torchaudio>=0.13,<0.14
diff --git a/tests/callbacks/progress_bar_test.py b/tests/callbacks/progress_bar_test.py
index 67b48391..a4397912 100644
--- a/tests/callbacks/progress_bar_test.py
+++ b/tests/callbacks/progress_bar_test.py
@@ -75,6 +75,7 @@ def test_progress_bar(tmpdir):
         r"100%\|[█]+\| 10/10 \[00:0[0-9]<00:00, "
         rf"{float_pattern}it/s, "
         rf"loss={float_pattern}, "
+        r"v_num=0, "
         rf"cer={float_pattern}%, "
         r"gpu_stats={'gpu_stats': 'baz'}]"
     )