Skip to content

Commit

Permalink
Feat/hist fc start stride (#2560)
Browse files Browse the repository at this point in the history
* improve hist fc start point

* add tests

* update documentation

* update changelog

* clean up code

* fix tests

* fix missed lines

* improve codecov
  • Loading branch information
dennisbader authored Nov 2, 2024
1 parent c116405 commit ad93612
Show file tree
Hide file tree
Showing 10 changed files with 723 additions and 177 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

**Improved**

- Improvements to `ForecastingModel`: Improved `start` handling for historical forecasts, backtest, residuals, and gridsearch. If `start` is not within the trainable / forecastable points, uses the closest valid start point that is a round multiple of `stride` ahead of start. Raises a ValueError, if no valid start point exists. This guarantees that all historical forecasts are `n * stride` points away from start, and will simplify many downstream tasks. [#2560](https://github.com/unit8co/darts/issues/2560) by [Dennis Bader](https://github.com/dennisbader).

**Fixed**

- Fixed a bug when using `darts.utils.data.tabularization.create_lagged_component_names()` with target `lags=None`, that did not return any lagged target label component names. [#2576](https://github.com/unit8co/darts/pull/2576) by [Dennis Bader](https://github.com/dennisbader).
Expand Down
29 changes: 20 additions & 9 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,11 +706,12 @@ def historical_forecasts(
or `retrain` is a Callable and the first trainable point is earlier than the first predictable point.
- the first trainable point (given `train_length`) otherwise
Note: If `start` is not within the trainable / forecastable points, uses the closest valid start point that
is a round multiple of `stride` ahead of `start`. Raises a `ValueError`, if no valid start point exists.
Note: If the model uses a shifted output (`output_chunk_shift > 0`), then the first predicted point is also
shifted by `output_chunk_shift` points into the future.
Note: Raises a ValueError if `start` yields a time outside the time index of `series`.
shifted by `output_chunk_shift` points into the future.
Note: If `start` is outside the possible historical forecasting times, will ignore the parameter
(default behavior with ``None``) and start at the first trainable/predictable point.
(default behavior with ``None``) and start at the first trainable/predictable point.
start_format
Defines the `start` format. Only effective when `start` is an integer and `series` is indexed with a
`pd.RangeIndex`.
Expand Down Expand Up @@ -1018,6 +1019,7 @@ def retrain_func(
historical_forecasts_time_index=historical_forecasts_time_index,
start=start,
start_format=start_format,
stride=stride,
show_warnings=show_warnings,
)

Expand Down Expand Up @@ -1267,9 +1269,12 @@ def backtest(
or `retrain` is a Callable and the first trainable point is earlier than the first predictable point.
- the first trainable point (given `train_length`) otherwise
Note: Raises a ValueError if `start` yields a time outside the time index of `series`.
Note: If `start` is not within the trainable / forecastable points, uses the closest valid start point that
is a round multiple of `stride` ahead of `start`. Raises a `ValueError`, if no valid start point exists.
Note: If the model uses a shifted output (`output_chunk_shift > 0`), then the first predicted point is also
shifted by `output_chunk_shift` points into the future.
Note: If `start` is outside the possible historical forecasting times, will ignore the parameter
(default behavior with ``None``) and start at the first trainable/predictable point.
(default behavior with ``None``) and start at the first trainable/predictable point.
start_format
Defines the `start` format. Only effective when `start` is an integer and `series` is indexed with a
`pd.RangeIndex`.
Expand Down Expand Up @@ -1628,9 +1633,12 @@ def gridsearch(
or `retrain` is a Callable and the first trainable point is earlier than the first predictable point.
- the first trainable point (given `train_length`) otherwise
Note: Raises a ValueError if `start` yields a time outside the time index of `series`.
Note: If `start` is not within the trainable / forecastable points, uses the closest valid start point that
is a round multiple of `stride` ahead of `start`. Raises a `ValueError`, if no valid start point exists.
Note: If the model uses a shifted output (`output_chunk_shift > 0`), then the first predicted point is also
shifted by `output_chunk_shift` points into the future.
Note: If `start` is outside the possible historical forecasting times, will ignore the parameter
(default behavior with ``None``) and start at the first trainable/predictable point.
(default behavior with ``None``) and start at the first trainable/predictable point.
start_format
Only used in expanding window mode. Defines the `start` format. Only effective when `start` is an integer
and `series` is indexed with a `pd.RangeIndex`.
Expand Down Expand Up @@ -1924,9 +1932,12 @@ def residuals(
or `retrain` is a Callable and the first trainable point is earlier than the first predictable point.
- the first trainable point (given `train_length`) otherwise
Note: Raises a ValueError if `start` yields a time outside the time index of `series`.
Note: If `start` is not within the trainable / forecastable points, uses the closest valid start point that
is a round multiple of `stride` ahead of `start`. Raises a `ValueError`, if no valid start point exists.
Note: If the model uses a shifted output (`output_chunk_shift > 0`), then the first predicted point is also
shifted by `output_chunk_shift` points into the future.
Note: If `start` is outside the possible historical forecasting times, will ignore the parameter
(default behavior with ``None``) and start at the first trainable/predictable point.
(default behavior with ``None``) and start at the first trainable/predictable point.
start_format
Defines the `start` format. Only effective when `start` is an integer and `series` is indexed with a
`pd.RangeIndex`.
Expand Down
28 changes: 19 additions & 9 deletions darts/tests/models/forecasting/test_backtesting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import logging
import random
from itertools import product

Expand Down Expand Up @@ -733,8 +734,7 @@ def test_backtest_multiple_series(self):
assert round(abs(error[0] - expected[0]), 4) == 0
assert round(abs(error[1] - expected[1]), 4) == 0

@pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch")
def test_backtest_regression(self):
def test_backtest_regression(self, caplog):
np.random.seed(4)

gaussian_series = gt(mean=2, length=50)
Expand Down Expand Up @@ -804,13 +804,26 @@ def test_backtest_regression(self):
assert score > 0.9

# Using a too small start value
with pytest.raises(ValueError):
RandomForest(lags=12).backtest(series=target, start=0, forecast_horizon=3)
warning_expected = (
"`start` position `{0}` corresponding to time `{1}` is before the first "
"predictable/trainable historical forecasting point for series at index: 0. Using the first historical "
"forecasting point `2000-01-15 00:00:00` that lies a round-multiple of `stride=1` ahead of `start`. "
"To hide these warnings, set `show_warnings=False`."
)
caplog.clear()
with caplog.at_level(logging.WARNING):
_ = RandomForest(lags=12).backtest(
series=target, start=0, forecast_horizon=3
)
assert warning_expected.format(0, target.start_time()) in caplog.text
caplog.clear()

with pytest.raises(ValueError):
RandomForest(lags=12).backtest(
with caplog.at_level(logging.WARNING):
_ = RandomForest(lags=12).backtest(
series=target, start=0.01, forecast_horizon=3
)
assert warning_expected.format(0.01, target.start_time()) in caplog.text
caplog.clear()

# Using RandomForest's start default value
score = RandomForest(lags=12, random_state=0).backtest(
Expand Down Expand Up @@ -939,7 +952,6 @@ def test_gridsearch_metric_score(self):

assert score == recalculated_score, "The metric scores should match"

@pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch")
def test_gridsearch_random_search(self):
np.random.seed(1)

Expand All @@ -958,7 +970,6 @@ def test_gridsearch_random_search(self):
assert isinstance(result[2], float)
assert min(param_range) <= result[1]["lags"] <= max(param_range)

@pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch")
def test_gridsearch_n_random_samples_bad_arguments(self):
dummy_series = get_dummy_series(ts_length=50)

Expand All @@ -981,7 +992,6 @@ def test_gridsearch_n_random_samples_bad_arguments(self):
params, dummy_series, forecast_horizon=1, n_random_samples=1.5
)

@pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch")
def test_gridsearch_n_random_samples(self):
np.random.seed(1)

Expand Down
Loading

0 comments on commit ad93612

Please sign in to comment.