Skip to content

Commit

Permalink
Fix/exp smooth constructor args (#2059)
Browse files Browse the repository at this point in the history
* feat: adding support for constructor kwargs

* feat: adding tests

* fix: udpated representation test for ExponentialSmoothing model

* update changelog.md

---------

Co-authored-by: dennisbader <[email protected]>
  • Loading branch information
madtoinou and dennisbader authored Nov 8, 2023
1 parent a5a4306 commit da049e5
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- `XGBModel` now leverages XGBoost's native Quantile Regression support that was released in version 2.0.0 for improved probabilistic forecasts. [#2051](https://github.com/unit8co/darts/pull/2051) by [Dennis Bader](https://github.com/dennisbader).
- Other improvements:
- Added support for time index time zone conversion with parameter `tz` before generating/computing holidays and datetime attributes. Support was added to all Time Axis Encoders (standalone encoders and forecasting models' `add_encoders`, time series generation utils functions `holidays_timeseries()` and `datetime_attribute_timeseries()`, and `TimeSeries` methods `add_datetime_attribute()` and `add_holidays()`. [#2054](https://github.com/unit8co/darts/pull/2054) by [Dennis Bader](https://github.com/dennisbader).
- Added optional keyword arguments dict `kwargs` to `ExponentialSmoothing` that will be passed to the constructor of the underlying `statsmodels.tsa.holtwinters.ExponentialSmoothing` model. [#2059](https://github.com/unit8co/darts/pull/2059) by [Antoine Madrona](https://github.com/madtoinou).

**Fixed**
- Fixed a bug when calling optimized `historical_forecasts()` for a `RegressionModel` trained with unequal component-specific lags. [#2040](https://github.com/unit8co/darts/pull/2040) by [Antoine Madrona](https://github.com/madtoinou).
Expand Down
12 changes: 10 additions & 2 deletions darts/models/forecasting/exponential_smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
---------------------
"""

from typing import Optional
from typing import Any, Dict, Optional

import numpy as np
import statsmodels.tsa.holtwinters as hw
Expand All @@ -24,7 +24,8 @@ def __init__(
seasonal: Optional[SeasonalityMode] = SeasonalityMode.ADDITIVE,
seasonal_periods: Optional[int] = None,
random_state: int = 0,
**fit_kwargs,
kwargs: Optional[Dict[str, Any]] = None,
**fit_kwargs
):

"""Exponential Smoothing
Expand Down Expand Up @@ -61,6 +62,11 @@ def __init__(
seasonal_periods
The number of periods in a complete seasonal cycle, e.g., 4 for quarterly data or 7 for daily
data with a weekly cycle. If not set, inferred from frequency of the series.
kwargs
Some optional keyword arguments that will be used to call
:func:`statsmodels.tsa.holtwinters.ExponentialSmoothing()`.
See `the documentation
<https://www.statsmodels.org/stable/generated/statsmodels.tsa.holtwinters.ExponentialSmoothing.html>`_.
fit_kwargs
Some optional keyword arguments that will be used to call
:func:`statsmodels.tsa.holtwinters.ExponentialSmoothing.fit()`.
Expand Down Expand Up @@ -91,6 +97,7 @@ def __init__(
self.seasonal = seasonal
self.infer_seasonal_periods = seasonal_periods is None
self.seasonal_periods = seasonal_periods
self.constructor_kwargs = dict() if kwargs is None else kwargs
self.fit_kwargs = fit_kwargs
self.model = None
np.random.seed(random_state)
Expand Down Expand Up @@ -120,6 +127,7 @@ def fit(self, series: TimeSeries):
seasonal_periods=seasonal_periods_param,
freq=series.freq if series.has_datetime_index else None,
dates=series.time_index if series.has_datetime_index else None,
**self.constructor_kwargs
)
hw_results = hw_model.fit(**self.fit_kwargs)
self.model = hw_results
Expand Down
79 changes: 65 additions & 14 deletions darts/tests/models/forecasting/test_exponential_smoothing.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,92 @@
import numpy as np
import pytest

from darts import TimeSeries
from darts.models import ExponentialSmoothing
from darts.utils import timeseries_generation as tg


class TestExponentialSmoothing:
def helper_test_seasonality_inference(self, freq_string, expected_seasonal_periods):
series = tg.sine_timeseries(length=200, freq=freq_string)
model = ExponentialSmoothing()
model.fit(series)
assert model.seasonal_periods == expected_seasonal_periods
series = tg.sine_timeseries(length=100, freq="H")

def test_seasonality_inference(self):

# test `seasonal_periods` inference for datetime indices
freq_str_seasonality_periods_tuples = [
@pytest.mark.parametrize(
"freq_string,expected_seasonal_periods",
[
("D", 7),
("H", 24),
("M", 12),
("W", 52),
("Q", 4),
("B", 5),
]
for tuple in freq_str_seasonality_periods_tuples:
self.helper_test_seasonality_inference(*tuple)
],
)
def test_seasonality_inference(
self, freq_string: str, expected_seasonal_periods: int
):
series = tg.sine_timeseries(length=200, freq=freq_string)
model = ExponentialSmoothing()
model.fit(series)
assert model.seasonal_periods == expected_seasonal_periods

# test default selection for integer index
def test_default_parameters(self):
"""Test default selection for integer index"""
series = TimeSeries.from_values(np.arange(1, 30, 1))
model = ExponentialSmoothing()
model.fit(series)
assert model.seasonal_periods == 12

# test whether a model that inferred a seasonality period before will do it again for a new series
def test_multiple_fit(self):
"""Test whether a model that inferred a seasonality period before will do it again for a new series"""
series1 = tg.sine_timeseries(length=100, freq="M")
series2 = tg.sine_timeseries(length=100, freq="D")
model = ExponentialSmoothing()
model.fit(series1)
model.fit(series2)
assert model.seasonal_periods == 7

def test_constructor_kwargs(self):
"""Using kwargs to pass additional parameters to the constructor"""
constructor_kwargs = {
"initialization_method": "known",
"initial_level": 0.5,
"initial_trend": 0.2,
"initial_seasonal": np.arange(1, 25),
}
model = ExponentialSmoothing(kwargs=constructor_kwargs)
model.fit(self.series)
# must be checked separately, name is not consistent
np.testing.assert_array_almost_equal(
model.model.model.params["initial_seasons"],
constructor_kwargs["initial_seasonal"],
)
for param_name in ["initial_level", "initial_trend"]:
assert (
model.model.model.params[param_name] == constructor_kwargs[param_name]
)

def test_fit_kwargs(self):
"""Using kwargs to pass additional parameters to the fit()"""
# using default optimization method
model = ExponentialSmoothing()
model.fit(self.series)
assert model.fit_kwargs == {}
pred = model.predict(n=2)

model_bis = ExponentialSmoothing()
model_bis.fit(self.series)
assert model_bis.fit_kwargs == {}
pred_bis = model_bis.predict(n=2)

# two methods with the same parameters should yield the same forecasts
assert pred.time_index.equals(pred_bis.time_index)
np.testing.assert_array_almost_equal(pred.values(), pred_bis.values())

# change optimization method
model_ls = ExponentialSmoothing(method="least_squares")
model_ls.fit(self.series)
assert model_ls.fit_kwargs == {"method": "least_squares"}
pred_ls = model_ls.predict(n=2)

# forecasts should be slightly different
assert pred.time_index.equals(pred_ls.time_index)
assert all(np.not_equal(pred.values(), pred_ls.values()))
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ def test_model_str_call(self, config):
(
ExponentialSmoothing(),
"ExponentialSmoothing(trend=ModelMode.ADDITIVE, damped=False, seasonal=SeasonalityMode.ADDITIVE, "
+ "seasonal_periods=None, random_state=0)",
+ "seasonal_periods=None, random_state=0, kwargs=None)",
), # no params changed
(
ARIMA(1, 1, 1),
Expand Down

0 comments on commit da049e5

Please sign in to comment.