Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Feat/Pass kwargs to the underlying models fit functions #2460

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Added hyperparameters controlling the hidden layer sizes for the feature encoders in `TiDEModel`. [#2408](https://github.com/unit8co/darts/issues/2408) by [eschibli](https://github.com/eschibli).
- Made README's forecasting model support table more colorblind-friendly. [#2433](https://github.com/unit8co/darts/pull/2433)
- Updated the Ray Tune Hyperparameter Optimization example in the [user guide](https://unit8co.github.io/darts/userguide/hyperparameter_optimization.html) to work with the latest `ray` versions (`>=2.31.0`). [#2459](https://github.com/unit8co/darts/pull/2459) by [He Weilin](https://github.com/cnhwl).
- Allowed passing of kwargs to the `fit` functions of `Prophet` and `AutoARIMA`
- 🔴 Restructured the signatures of `ExponentialSmoothing` `__init__` and `fit` functions so that the passing of additional parameters is consistent with other models
- Keyword arguments to be passed to the underlying model's constructor must now be passed as keyword arguments instead of a `dict` to the `ExponentialSmoothing` constructor
- Keyword arguments to be passed to the underlying model's `fit` function must now be passed to the `ExponentialSmoothing.fit` function instead of the constructor
Comment on lines +17 to +20
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once all suggestions have been addressed, we could formulate it as below :)

Suggested change
- Allowed passing of kwargs to the `fit` functions of `Prophet` and `AutoARIMA`
- 🔴 Restructured the signatures of `ExponentialSmoothing` `__init__` and `fit` functions so that the passing of additional parameters is consistent with other models
- Keyword arguments to be passed to the underlying model's constructor must now be passed as keyword arguments instead of a `dict` to the `ExponentialSmoothing` constructor
- Keyword arguments to be passed to the underlying model's `fit` function must now be passed to the `ExponentialSmoothing.fit` function instead of the constructor
- Improvements to `ForecastingModel`: [#2460](https://github.com/unit8co/darts/pull/2460) by [DavidKleindienst](https://github.com/DavidKleindienst).
- All forecasting models now support keyword arguments `**kwargs` when calling `fit()` that will be passed to the underlying model's fit function.
- 🔴 Changes to `ExponentialSmoothing` for a unified API:
- Removed `fit_kwargs` from `__init__()`. They must now be passed as keyword arguments to `fit()`.
- Parameters to be passed to the underlying model's constructor must now be passed as keyword arguments (instead of an explicit `kwargs` parameter).


**Fixed**

Expand Down
8 changes: 5 additions & 3 deletions darts/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,9 @@ def __exit__(self, *_):
os.close(fd)


def execute_and_suppress_output(function, logger, suppression_threshold_level, *args):
def execute_and_suppress_output(
function, logger, suppression_threshold_level, *args, **kwargs
):
"""
This function conditionally executes the given function with the given arguments
based on whether the current level of 'logger' is below, above or equal to
Expand All @@ -207,9 +209,9 @@ def execute_and_suppress_output(function, logger, suppression_threshold_level, *
"""
if logger.level >= suppression_threshold_level:
with SuppressStdoutStderr():
return_value = function(*args)
return_value = function(*args, **kwargs)
else:
return_value = function(*args)
return_value = function(*args, **kwargs)
return return_value


Expand Down
11 changes: 9 additions & 2 deletions darts/models/forecasting/auto_arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,19 @@ def encode_year(idx):
def supports_multivariate(self) -> bool:
return False

def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
def _fit(
self,
series: TimeSeries,
future_covariates: Optional[TimeSeries] = None,
**kwargs,
):
super()._fit(series, future_covariates)
self._assert_univariate(series)
series = self.training_series
self.model.fit(
series.values(), X=future_covariates.values() if future_covariates else None
series.values(),
X=future_covariates.values() if future_covariates else None,
**kwargs,
)
return self

Expand Down
8 changes: 7 additions & 1 deletion darts/models/forecasting/croston.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,12 @@ def encode_year(idx):
def supports_multivariate(self) -> bool:
return False

def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
def _fit(
self,
series: TimeSeries,
future_covariates: Optional[TimeSeries] = None,
**kwargs,
):
super()._fit(series, future_covariates)
self._assert_univariate(series)
series = self.training_series
Expand All @@ -135,6 +140,7 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non
if future_covariates is not None
else None
),
**kwargs,
)

return self
Expand Down
34 changes: 22 additions & 12 deletions darts/models/forecasting/exponential_smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
---------------------
"""

from typing import Any, Dict, Optional
from typing import Optional

import numpy as np
import statsmodels.tsa.holtwinters as hw
Expand All @@ -24,8 +24,7 @@ def __init__(
seasonal: Optional[SeasonalityMode] = SeasonalityMode.ADDITIVE,
seasonal_periods: Optional[int] = None,
random_state: int = 0,
kwargs: Optional[Dict[str, Any]] = None,
**fit_kwargs,
**kwargs,
):
"""Exponential Smoothing

Expand Down Expand Up @@ -66,11 +65,6 @@ def __init__(
:func:`statsmodels.tsa.holtwinters.ExponentialSmoothing()`.
See `the documentation
<https://www.statsmodels.org/stable/generated/statsmodels.tsa.holtwinters.ExponentialSmoothing.html>`_.
fit_kwargs
Some optional keyword arguments that will be used to call
:func:`statsmodels.tsa.holtwinters.ExponentialSmoothing.fit()`.
See `the documentation
<https://www.statsmodels.org/stable/generated/statsmodels.tsa.holtwinters.ExponentialSmoothing.fit.html>`_.

Examples
--------
Expand All @@ -96,12 +90,28 @@ def __init__(
self.seasonal = seasonal
self.infer_seasonal_periods = seasonal_periods is None
self.seasonal_periods = seasonal_periods
self.constructor_kwargs = dict() if kwargs is None else kwargs
self.fit_kwargs = fit_kwargs
self.constructor_kwargs = kwargs
self.model = None
np.random.seed(random_state)

def fit(self, series: TimeSeries):
def fit(self, series: TimeSeries, **kwargs):
"""Fit/train the model on the (single) provided series.

Parameters
----------
series
The model will be trained to forecast this time series.
kwargs
Some optional keyword arguments that will be used to call
:func:`statsmodels.tsa.holtwinters.ExponentialSmoothing.fit()`.
See `the documentation
<https://www.statsmodels.org/stable/generated/statsmodels.tsa.holtwinters.ExponentialSmoothing.fit.html>`_.

Returns
-------
self
Fitted model.
"""
super().fit(series)
self._assert_univariate(series)
series = self.training_series
Expand All @@ -128,7 +138,7 @@ def fit(self, series: TimeSeries):
dates=series.time_index if series.has_datetime_index else None,
**self.constructor_kwargs,
)
hw_results = hw_model.fit(**self.fit_kwargs)
hw_results = hw_model.fit(**kwargs)
self.model = hw_results

if self.infer_seasonal_periods:
Expand Down
18 changes: 15 additions & 3 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2902,7 +2902,12 @@ class FutureCovariatesLocalForecastingModel(LocalForecastingModel, ABC):
All implementations must implement the :func:`_fit()` and :func:`_predict()` methods.
"""

def fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
def fit(
self,
series: TimeSeries,
future_covariates: Optional[TimeSeries] = None,
**kwargs,
):
"""Fit/train the model on the (single) provided series.

Optionally, a future covariates series can be provided as well.
Expand All @@ -2915,6 +2920,8 @@ def fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None
A time series of future-known covariates. This time series will not be forecasted, but can be used by
some models as an input. It must contain at least the same time steps/indices as the target `series`.
If it is longer than necessary, it will be automatically trimmed.
kwargs
Optional keyword arguments that will be passed to the fit function of the underlying model.

Returns
-------
Expand Down Expand Up @@ -2946,10 +2953,15 @@ def fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None

super().fit(series)

return self._fit(series, future_covariates=future_covariates)
return self._fit(series, future_covariates=future_covariates, **kwargs)

@abstractmethod
def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
def _fit(
self,
series: TimeSeries,
future_covariates: Optional[TimeSeries] = None,
**kwargs,
):
"""Fits/trains the model on the provided series.
DualCovariatesModels must implement the fit logic in this method.
"""
Expand Down
11 changes: 8 additions & 3 deletions darts/models/forecasting/prophet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,12 @@ def encode_year(idx):
# Use 0 as default value
self._floor = 0

def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
def _fit(
self,
series: TimeSeries,
future_covariates: Optional[TimeSeries] = None,
**kwargs,
):
super()._fit(series, future_covariates)
self._assert_univariate(series)
series = self.training_series
Expand Down Expand Up @@ -249,10 +254,10 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non

if self.suppress_stdout_stderr:
self._execute_and_suppress_output(
self.model.fit, logger, logging.WARNING, fit_df
self.model.fit, logger, logging.WARNING, fit_df, **kwargs
)
else:
self.model.fit(fit_df)
self.model.fit(fit_df, **kwargs)

return self

Expand Down
8 changes: 7 additions & 1 deletion darts/models/forecasting/sf_auto_arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,19 @@ def encode_year(idx):
super().__init__(add_encoders=add_encoders)
self.model = SFAutoARIMA(*autoarima_args, **autoarima_kwargs)

def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
def _fit(
self,
series: TimeSeries,
future_covariates: Optional[TimeSeries] = None,
**kwargs,
):
super()._fit(series, future_covariates)
self._assert_univariate(series)
series = self.training_series
self.model.fit(
series.values(copy=False).flatten(),
X=future_covariates.values(copy=False) if future_covariates else None,
**kwargs,
)
return self

Expand Down
11 changes: 7 additions & 4 deletions darts/models/forecasting/sf_auto_ets.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,12 @@ def encode_year(idx):
self.model = SFAutoETS(*autoets_args, **autoets_kwargs)
self._linreg = None

def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
def _fit(
self,
series: TimeSeries,
future_covariates: Optional[TimeSeries] = None,
**kwargs,
):
super()._fit(series, future_covariates)
self._assert_univariate(series)
series = self.training_series
Expand All @@ -116,9 +121,7 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non
else:
target = series

self.model.fit(
target.values(copy=False).flatten(),
)
self.model.fit(target.values(copy=False).flatten(), **kwargs)
return self

def _predict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_constructor_kwargs(self):
"initial_trend": 0.2,
"initial_seasonal": np.arange(1, 25),
}
model = ExponentialSmoothing(kwargs=constructor_kwargs)
model = ExponentialSmoothing(**constructor_kwargs)
model.fit(self.series)
# must be checked separately, name is not consistent
np.testing.assert_array_almost_equal(
Expand All @@ -70,22 +70,19 @@ def test_fit_kwargs(self):
# using default optimization method
model = ExponentialSmoothing()
model.fit(self.series)
assert model.fit_kwargs == {}
pred = model.predict(n=2)

model_bis = ExponentialSmoothing()
model_bis.fit(self.series)
assert model_bis.fit_kwargs == {}
pred_bis = model_bis.predict(n=2)

# two methods with the same parameters should yield the same forecasts
assert pred.time_index.equals(pred_bis.time_index)
np.testing.assert_array_almost_equal(pred.values(), pred_bis.values())

# change optimization method
model_ls = ExponentialSmoothing(method="least_squares")
model_ls.fit(self.series)
assert model_ls.fit_kwargs == {"method": "least_squares"}
model_ls = ExponentialSmoothing()
model_ls.fit(self.series, method="least_squares")
pred_ls = model_ls.predict(n=2)

# forecasts should be slightly different
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ def test_model_str_call(self, config):
(
ExponentialSmoothing(),
"ExponentialSmoothing(trend=ModelMode.ADDITIVE, damped=False, seasonal=SeasonalityMode.ADDITIVE, "
+ "seasonal_periods=None, random_state=0, kwargs=None)",
+ "seasonal_periods=None, random_state=0)",
), # no params changed
(
ARIMA(1, 1, 1),
Expand Down
Loading