
Commit

add support for more lr scheduler config parameters to torch models (#2218)

* add support for more lr scheduler config parameters to torch models

* update changelog
dennisbader authored Feb 8, 2024
1 parent 5b05d2b commit 8073de4
Showing 3 changed files with 41 additions and 24 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
 - Additional boosts for slicing with integers and Timestamps
 - Additional boosts for `from_group_dataframe()` by performing some of the heavy-duty computations on the entire DataFrame, rather than iteratively on the group level.
 - Added option to exclude some `group_cols` from being added as static covariates when using `TimeSeries.from_group_dataframe()` with parameter `drop_group_cols`.
+- Improvements to `TorchForecastingModel`:
+  - Added support for additional lr scheduler configuration parameters for more control ("interval", "frequency", "monitor", "strict", "name"). [#2218](https://github.com/unit8co/darts/pull/2218) by [Dennis Bader](https://github.com/dennisbader).
 
 **Fixed**
 - Fixed a bug in probabilistic `LinearRegressionModel.fit()`, where the `model` attribute was not pointing to all underlying estimators. [#2205](https://github.com/unit8co/darts/pull/2205) by [Antoine Madrona](https://github.com/madtoinou).
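
For context, a minimal usage sketch of the new parameters, mirroring the values exercised in the updated test further below (the remaining `RNNModel` arguments are illustrative):

    import torch
    from darts.models import RNNModel

    # "threshold" is forwarded to ReduceLROnPlateau itself; "monitor",
    # "interval", and "frequency" are consumed by the Lightning lr scheduler config
    model = RNNModel(
        input_chunk_length=12,
        model="RNN",
        lr_scheduler_cls=torch.optim.lr_scheduler.ReduceLROnPlateau,
        lr_scheduler_kwargs={
            "threshold": 0.001,
            "monitor": "train_loss",
            "interval": "step",
            "frequency": 2,
        },
    )
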
21 changes: 15 additions & 6 deletions darts/models/forecasting/pl_forecasting_module.py
@@ -402,16 +402,25 @@ def _create_from_cls_and_kwargs(cls, kws):
             lr_sched_kws = {k: v for k, v in self.lr_scheduler_kwargs.items()}
             lr_sched_kws["optimizer"] = optimizer
 
-            # ReduceLROnPlateau requires a metric to "monitor" which must be set separately, most others do not
-            lr_monitor = lr_sched_kws.pop("monitor", None)
+            # the lr scheduler can be configured with Lightning; defaults below
+            lr_config_params = {
+                "monitor": "val_loss",
+                "interval": "epoch",
+                "frequency": 1,
+                "strict": True,
+                "name": None,
+            }
+            # update the config with user-supplied params
+            lr_config_params = {
+                k: (v if k not in lr_sched_kws else lr_sched_kws.pop(k))
+                for k, v in lr_config_params.items()
+            }
 
             lr_scheduler = _create_from_cls_and_kwargs(
                 self.lr_scheduler_cls, lr_sched_kws
             )
-            return [optimizer], {
-                "scheduler": lr_scheduler,
-                "monitor": lr_monitor if lr_monitor is not None else "val_loss",
-            }
+
+            return [optimizer], dict({"scheduler": lr_scheduler}, **lr_config_params)
         else:
             return optimizer

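The key-splitting logic above can be sketched standalone: keys that belong to Lightning's lr scheduler config are popped out of the user-supplied kwargs, and whatever remains is forwarded to the scheduler class itself (the input values here are illustrative):

    # user passed one scheduler kwarg ("gamma") and two config kwargs
    lr_sched_kws = {"gamma": 0.9, "interval": "step", "frequency": 2}

    # defaults, as in the code above
    lr_config_params = {
        "monitor": "val_loss",
        "interval": "epoch",
        "frequency": 1,
        "strict": True,
        "name": None,
    }
    # pop config keys out of the scheduler kwargs, keeping defaults otherwise
    lr_config_params = {
        k: (v if k not in lr_sched_kws else lr_sched_kws.pop(k))
        for k, v in lr_config_params.items()
    }

    assert lr_sched_kws == {"gamma": 0.9}
    assert lr_config_params["interval"] == "step"
    assert lr_config_params["frequency"] == 2
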
42 changes: 24 additions & 18 deletions darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -1188,29 +1188,35 @@ def test_optimizers(self):
         # should not raise an error
         model.fit(self.series, epochs=1)
 
-    def test_lr_schedulers(self):
-
-        lr_schedulers = [
+    @pytest.mark.parametrize(
+        "lr_scheduler",
+        [
             (torch.optim.lr_scheduler.StepLR, {"step_size": 10}),
             (
                 torch.optim.lr_scheduler.ReduceLROnPlateau,
-                {"threshold": 0.001, "monitor": "train_loss"},
+                {
+                    "threshold": 0.001,
+                    "monitor": "train_loss",
+                    "interval": "step",
+                    "frequency": 2,
+                },
             ),
             (torch.optim.lr_scheduler.ExponentialLR, {"gamma": 0.09}),
-        ]
-
-        for lr_scheduler_cls, lr_scheduler_kwargs in lr_schedulers:
-            model = RNNModel(
-                12,
-                "RNN",
-                10,
-                10,
-                lr_scheduler_cls=lr_scheduler_cls,
-                lr_scheduler_kwargs=lr_scheduler_kwargs,
-                **tfm_kwargs,
-            )
-            # should not raise an error
-            model.fit(self.series, epochs=1)
+        ],
+    )
+    def test_lr_schedulers(self, lr_scheduler):
+        lr_scheduler_cls, lr_scheduler_kwargs = lr_scheduler
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            lr_scheduler_cls=lr_scheduler_cls,
+            lr_scheduler_kwargs=lr_scheduler_kwargs,
+            **tfm_kwargs,
+        )
+        # should not raise an error
+        model.fit(self.series, epochs=1)
 
     def test_wrong_model_creation_params(self):
         valid_kwarg = {"pl_trainer_kwargs": {}}
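For the ReduceLROnPlateau case exercised above, `configure_optimizers` ends up returning roughly the following structure to Lightning; the optimizer and its parameters are illustrative, and defaults are kept for the config keys the test does not override:

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.Adam(params)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, threshold=0.001)

    optimizers, lr_scheduler_config = [optimizer], {
        "scheduler": scheduler,
        "monitor": "train_loss",  # overrides the "val_loss" default
        "interval": "step",       # overrides the "epoch" default
        "frequency": 2,           # overrides the default of 1
        "strict": True,           # default kept
        "name": None,             # default kept
    }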
