diff --git a/CHANGELOG.md b/CHANGELOG.md
index f02a87a52e..5817013c90 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
 - Additional boosts for slicing with integers and Timestamps
 - Additional boosts for `from_group_dataframe()` by performing some of the heavy-duty computations on the entire DataFrame, rather than iteratively on the group level.
 - Added option to exclude some `group_cols` from being added as static covariates when using `TimeSeries.from_group_dataframe()` with parameter `drop_group_cols`.
+- Improvements to `TorchForecastingModel`:
+  - Added support for additional lr scheduler configuration parameters for more control ("interval", "frequency", "monitor", "strict", "name"). [#2218](https://github.com/unit8co/darts/pull/2218) by [Dennis Bader](https://github.com/dennisbader).
 
 **Fixed**
 - Fixed a bug in probabilistic `LinearRegressionModel.fit()`, where the `model` attribute was not pointing to all underlying estimators. [#2205](https://github.com/unit8co/darts/pull/2205) by [Antoine Madrona](https://github.com/madtoinou).
diff --git a/darts/models/forecasting/pl_forecasting_module.py b/darts/models/forecasting/pl_forecasting_module.py
index ab98ee59c2..5610a2d3df 100644
--- a/darts/models/forecasting/pl_forecasting_module.py
+++ b/darts/models/forecasting/pl_forecasting_module.py
@@ -402,16 +402,25 @@ def _create_from_cls_and_kwargs(cls, kws):
             lr_sched_kws = {k: v for k, v in self.lr_scheduler_kwargs.items()}
             lr_sched_kws["optimizer"] = optimizer
 
-            # ReduceLROnPlateau requires a metric to "monitor" which must be set separately, most others do not
-            lr_monitor = lr_sched_kws.pop("monitor", None)
+            # lr scheduler can be configured with lightning; defaults below
+            lr_config_params = {
+                "monitor": "val_loss",
+                "interval": "epoch",
+                "frequency": 1,
+                "strict": True,
+                "name": None,
+            }
+            # update config with user params
+            lr_config_params = {
+                k: (v if k not in lr_sched_kws else lr_sched_kws.pop(k))
+                for k, v in lr_config_params.items()
+            }
 
             lr_scheduler = _create_from_cls_and_kwargs(
                 self.lr_scheduler_cls, lr_sched_kws
             )
-            return [optimizer], {
-                "scheduler": lr_scheduler,
-                "monitor": lr_monitor if lr_monitor is not None else "val_loss",
-            }
+
+            return [optimizer], dict({"scheduler": lr_scheduler}, **lr_config_params)
         else:
             return optimizer
diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py
index 24b8fd501e..77ad8aa07a 100644
--- a/darts/tests/models/forecasting/test_torch_forecasting_model.py
+++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -1188,29 +1188,35 @@ def test_optimizers(self):
         # should not raise an error
         model.fit(self.series, epochs=1)
 
-    def test_lr_schedulers(self):
-
-        lr_schedulers = [
+    @pytest.mark.parametrize(
+        "lr_scheduler",
+        [
             (torch.optim.lr_scheduler.StepLR, {"step_size": 10}),
             (
                 torch.optim.lr_scheduler.ReduceLROnPlateau,
-                {"threshold": 0.001, "monitor": "train_loss"},
+                {
+                    "threshold": 0.001,
+                    "monitor": "train_loss",
+                    "interval": "step",
+                    "frequency": 2,
+                },
             ),
             (torch.optim.lr_scheduler.ExponentialLR, {"gamma": 0.09}),
-        ]
-
-        for lr_scheduler_cls, lr_scheduler_kwargs in lr_schedulers:
-            model = RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
-                lr_scheduler_cls=lr_scheduler_cls,
-                lr_scheduler_kwargs=lr_scheduler_kwargs,
-                **tfm_kwargs,
-            )
-            # should not raise an error
-            model.fit(self.series, epochs=1)
+        ],
+    )
+    def test_lr_schedulers(self, lr_scheduler):
+        lr_scheduler_cls, lr_scheduler_kwargs = lr_scheduler
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            lr_scheduler_cls=lr_scheduler_cls,
+            lr_scheduler_kwargs=lr_scheduler_kwargs,
+            **tfm_kwargs,
+        )
+        # should not raise an error
+        model.fit(self.series, epochs=1)
 
     def test_wrong_model_creation_params(self):
         valid_kwarg = {"pl_trainer_kwargs": {}}
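
Usage note (not part of the patch): the sketch below illustrates how a user could pass the new Lightning scheduler configuration keys ("monitor", "interval", "frequency") through `lr_scheduler_kwargs`, mixed with scheduler-specific arguments such as "threshold", exactly the split that the patched `configure_optimizers()` performs. The toy series from `linear_timeseries` and the small model dimensions are illustrative assumptions, not values taken from this PR.

import torch
from darts.models import RNNModel
from darts.utils.timeseries_generation import linear_timeseries

# toy training series, only to make the example self-contained
series = linear_timeseries(length=100)

model = RNNModel(
    input_chunk_length=12,
    model="RNN",
    hidden_dim=10,
    n_rnn_layers=2,
    lr_scheduler_cls=torch.optim.lr_scheduler.ReduceLROnPlateau,
    lr_scheduler_kwargs={
        "threshold": 0.001,       # forwarded to ReduceLROnPlateau
        "monitor": "train_loss",  # Lightning: metric the scheduler watches
        "interval": "epoch",      # Lightning: step the scheduler per epoch
        "frequency": 1,           # Lightning: step it every epoch
    },
)

# the lr scheduler is configured when training starts
model.fit(series, epochs=1)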