From ad9361231e94f549da8aad2ee231182e4d30fc8e Mon Sep 17 00:00:00 2001 From: Dennis Bader Date: Sat, 2 Nov 2024 13:40:06 +0100 Subject: [PATCH] Feat/hist fc start stride (#2560) * improve hist fc start point * add tests * update documentation * update changelog * clean up code * fix tests * fix missed lines * improve codecov --- CHANGELOG.md | 2 + darts/models/forecasting/forecasting_model.py | 29 +- .../models/forecasting/test_backtesting.py | 28 +- .../forecasting/test_historical_forecasts.py | 334 +++++++++++++++-- .../utils/historical_forecasts/test_utils.py | 154 ++++++++ darts/timeseries.py | 2 +- darts/utils/historical_forecasts/__init__.py | 2 - ...timized_historical_forecasts_regression.py | 2 + .../optimized_historical_forecasts_torch.py | 1 + darts/utils/historical_forecasts/utils.py | 346 ++++++++++++------ 10 files changed, 723 insertions(+), 177 deletions(-) create mode 100644 darts/tests/utils/historical_forecasts/test_utils.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 780e4af04b..6dd0160d15 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co **Improved** +- Improvements to `ForecastingModel`: Improved `start` handling for historical forecasts, backtest, residuals, and gridsearch. If `start` is not within the trainable / forecastable points, uses the closest valid start point that is a round multiple of `stride` ahead of start. Raises a ValueError, if no valid start point exists. This guarantees that all historical forecasts are `n * stride` points away from start, and will simplify many downstream tasks. [#2560](https://github.com/unit8co/darts/issues/2560) by [Dennis Bader](https://github.com/dennisbader). + **Fixed** - Fixed a bug when using `darts.utils.data.tabularization.create_lagged_component_names()` with target `lags=None`, that did not return any lagged target label component names. 
[#2576](https://github.com/unit8co/darts/pull/2576) by [Dennis Bader](https://github.com/dennisbader). diff --git a/darts/models/forecasting/forecasting_model.py b/darts/models/forecasting/forecasting_model.py index ee271ceed6..f3c3ebf489 100644 --- a/darts/models/forecasting/forecasting_model.py +++ b/darts/models/forecasting/forecasting_model.py @@ -706,11 +706,12 @@ def historical_forecasts( or `retrain` is a Callable and the first trainable point is earlier than the first predictable point. - the first trainable point (given `train_length`) otherwise + Note: If `start` is not within the trainable / forecastable points, uses the closest valid start point that + is a round multiple of `stride` ahead of `start`. Raises a `ValueError`, if no valid start point exists. Note: If the model uses a shifted output (`output_chunk_shift > 0`), then the first predicted point is also - shifted by `output_chunk_shift` points into the future. - Note: Raises a ValueError if `start` yields a time outside the time index of `series`. + shifted by `output_chunk_shift` points into the future. Note: If `start` is outside the possible historical forecasting times, will ignore the parameter - (default behavior with ``None``) and start at the first trainable/predictable point. + (default behavior with ``None``) and start at the first trainable/predictable point. start_format Defines the `start` format. Only effective when `start` is an integer and `series` is indexed with a `pd.RangeIndex`. @@ -1018,6 +1019,7 @@ def retrain_func( historical_forecasts_time_index=historical_forecasts_time_index, start=start, start_format=start_format, + stride=stride, show_warnings=show_warnings, ) @@ -1267,9 +1269,12 @@ def backtest( or `retrain` is a Callable and the first trainable point is earlier than the first predictable point. - the first trainable point (given `train_length`) otherwise - Note: Raises a ValueError if `start` yields a time outside the time index of `series`. 
+ Note: If `start` is not within the trainable / forecastable points, uses the closest valid start point that + is a round multiple of `stride` ahead of `start`. Raises a `ValueError`, if no valid start point exists. + Note: If the model uses a shifted output (`output_chunk_shift > 0`), then the first predicted point is also + shifted by `output_chunk_shift` points into the future. Note: If `start` is outside the possible historical forecasting times, will ignore the parameter - (default behavior with ``None``) and start at the first trainable/predictable point. + (default behavior with ``None``) and start at the first trainable/predictable point. start_format Defines the `start` format. Only effective when `start` is an integer and `series` is indexed with a `pd.RangeIndex`. @@ -1628,9 +1633,12 @@ def gridsearch( or `retrain` is a Callable and the first trainable point is earlier than the first predictable point. - the first trainable point (given `train_length`) otherwise - Note: Raises a ValueError if `start` yields a time outside the time index of `series`. + Note: If `start` is not within the trainable / forecastable points, uses the closest valid start point that + is a round multiple of `stride` ahead of `start`. Raises a `ValueError`, if no valid start point exists. + Note: If the model uses a shifted output (`output_chunk_shift > 0`), then the first predicted point is also + shifted by `output_chunk_shift` points into the future. Note: If `start` is outside the possible historical forecasting times, will ignore the parameter - (default behavior with ``None``) and start at the first trainable/predictable point. + (default behavior with ``None``) and start at the first trainable/predictable point. start_format Only used in expanding window mode. Defines the `start` format. Only effective when `start` is an integer and `series` is indexed with a `pd.RangeIndex`. 
@@ -1924,9 +1932,12 @@ def residuals( or `retrain` is a Callable and the first trainable point is earlier than the first predictable point. - the first trainable point (given `train_length`) otherwise - Note: Raises a ValueError if `start` yields a time outside the time index of `series`. + Note: If `start` is not within the trainable / forecastable points, uses the closest valid start point that + is a round multiple of `stride` ahead of `start`. Raises a `ValueError`, if no valid start point exists. + Note: If the model uses a shifted output (`output_chunk_shift > 0`), then the first predicted point is also + shifted by `output_chunk_shift` points into the future. Note: If `start` is outside the possible historical forecasting times, will ignore the parameter - (default behavior with ``None``) and start at the first trainable/predictable point. + (default behavior with ``None``) and start at the first trainable/predictable point. start_format Defines the `start` format. Only effective when `start` is an integer and `series` is indexed with a `pd.RangeIndex`. 
diff --git a/darts/tests/models/forecasting/test_backtesting.py b/darts/tests/models/forecasting/test_backtesting.py index 60f1b3e525..a70f4c5734 100644 --- a/darts/tests/models/forecasting/test_backtesting.py +++ b/darts/tests/models/forecasting/test_backtesting.py @@ -1,4 +1,5 @@ import itertools +import logging import random from itertools import product @@ -733,8 +734,7 @@ def test_backtest_multiple_series(self): assert round(abs(error[0] - expected[0]), 4) == 0 assert round(abs(error[1] - expected[1]), 4) == 0 - @pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch") - def test_backtest_regression(self): + def test_backtest_regression(self, caplog): np.random.seed(4) gaussian_series = gt(mean=2, length=50) @@ -804,13 +804,26 @@ def test_backtest_regression(self): assert score > 0.9 # Using a too small start value - with pytest.raises(ValueError): - RandomForest(lags=12).backtest(series=target, start=0, forecast_horizon=3) + warning_expected = ( + "`start` position `{0}` corresponding to time `{1}` is before the first " + "predictable/trainable historical forecasting point for series at index: 0. Using the first historical " + "forecasting point `2000-01-15 00:00:00` that lies a round-multiple of `stride=1` ahead of `start`. " + "To hide these warnings, set `show_warnings=False`." 
+ ) + caplog.clear() + with caplog.at_level(logging.WARNING): + _ = RandomForest(lags=12).backtest( + series=target, start=0, forecast_horizon=3 + ) + assert warning_expected.format(0, target.start_time()) in caplog.text + caplog.clear() - with pytest.raises(ValueError): - RandomForest(lags=12).backtest( + with caplog.at_level(logging.WARNING): + _ = RandomForest(lags=12).backtest( series=target, start=0.01, forecast_horizon=3 ) + assert warning_expected.format(0.01, target.start_time()) in caplog.text + caplog.clear() # Using RandomForest's start default value score = RandomForest(lags=12, random_state=0).backtest( @@ -939,7 +952,6 @@ def test_gridsearch_metric_score(self): assert score == recalculated_score, "The metric scores should match" - @pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch") def test_gridsearch_random_search(self): np.random.seed(1) @@ -958,7 +970,6 @@ def test_gridsearch_random_search(self): assert isinstance(result[2], float) assert min(param_range) <= result[1]["lags"] <= max(param_range) - @pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch") def test_gridsearch_n_random_samples_bad_arguments(self): dummy_series = get_dummy_series(ts_length=50) @@ -981,7 +992,6 @@ def test_gridsearch_n_random_samples_bad_arguments(self): params, dummy_series, forecast_horizon=1, n_random_samples=1.5 ) - @pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch") def test_gridsearch_n_random_samples(self): np.random.seed(1) diff --git a/darts/tests/models/forecasting/test_historical_forecasts.py b/darts/tests/models/forecasting/test_historical_forecasts.py index 5a4d201cb3..d3a65f047a 100644 --- a/darts/tests/models/forecasting/test_historical_forecasts.py +++ b/darts/tests/models/forecasting/test_historical_forecasts.py @@ -1,4 +1,6 @@ +import copy import itertools +import logging from itertools import product import numpy as np @@ -20,6 +22,7 @@ NotImportedModule, ) from darts.tests.conftest import TORCH_AVAILABLE, 
tfm_kwargs +from darts.utils import n_steps_between from darts.utils import timeseries_generation as tg if TORCH_AVAILABLE: @@ -802,95 +805,321 @@ def test_historical_forecasts(self, config): "Model cannot be fit/trained with `future_covariates`." ) - def test_sanity_check_invalid_start(self): + def test_sanity_check_start(self): timeidx_ = tg.linear_timeseries(length=10) rangeidx_step1 = tg.linear_timeseries(start=0, length=10, freq=1) rangeidx_step2 = tg.linear_timeseries(start=0, length=10, freq=2) + # invalid start float + model = LinearRegressionModel(lags=1) + with pytest.raises(ValueError) as msg: + model.historical_forecasts(rangeidx_step1, start=1.1) + assert str(msg.value).startswith( + "if `start` is a float, must be between 0.0 and 1.0." + ) + with pytest.raises(ValueError) as msg: + model.historical_forecasts(rangeidx_step1, start=-0.1) + assert str(msg.value).startswith( + "if `start` is a float, must be between 0.0 and 1.0." + ) + + # invalid start type + with pytest.raises(TypeError) as msg: + model.historical_forecasts(rangeidx_step1, start=[0.1]) + assert str(msg.value).startswith( + "`start` must be either `float`, `int`, `pd.Timestamp` or `None`." 
+ ) + + # label_index (timestamp) with range index series + model = LinearRegressionModel(lags=1) + with pytest.raises(ValueError) as msg: + model.historical_forecasts( + rangeidx_step1, start=timeidx_.end_time() + timeidx_.freq + ) + assert str(msg.value).startswith( + "if `start` is a `pd.Timestamp`, all series must be indexed with a `pd.DatetimeIndex`" + ) + # label_index (int), too large with pytest.raises(ValueError) as msg: - LinearRegressionModel(lags=1).historical_forecasts(timeidx_, start=11) - assert str(msg.value).startswith("`start` index `11` is out of bounds") + model.historical_forecasts(timeidx_, start=11) + assert str(msg.value).startswith("`start` position `11` is out of bounds") with pytest.raises(ValueError) as msg: - LinearRegressionModel(lags=1).historical_forecasts( + model.historical_forecasts( rangeidx_step1, start=rangeidx_step1.end_time() + rangeidx_step1.freq ) assert str(msg.value).startswith( - "`start` index `10` is larger than the last index" + "`start` time `10` is larger than the last index" ) with pytest.raises(ValueError) as msg: - LinearRegressionModel(lags=1).historical_forecasts( + model.historical_forecasts( rangeidx_step2, start=rangeidx_step2.end_time() + rangeidx_step2.freq ) assert str(msg.value).startswith( - "`start` index `20` is larger than the last index" + "`start` time `20` is larger than the last index" ) # label_index (timestamp) too high with pytest.raises(ValueError) as msg: - LinearRegressionModel(lags=1).historical_forecasts( + model.historical_forecasts( timeidx_, start=timeidx_.end_time() + timeidx_.freq ) assert str(msg.value).startswith( "`start` time `2000-01-11 00:00:00` is after the last timestamp `2000-01-10 00:00:00`" ) - # label_index, invalid + # label_index (timestamp), before series start and stride does not allow to find valid start point in series with pytest.raises(ValueError) as msg: - LinearRegressionModel(lags=1).historical_forecasts(rangeidx_step2, start=11) - assert 
str(msg.value).startswith("The provided point is not a valid index") + model.historical_forecasts( + timeidx_, + start=timeidx_.start_time() - timeidx_.freq, + stride=len(timeidx_) + 1, + ) + assert str(msg.value) == ( + "`start` time `1999-12-31 00:00:00` is smaller than the first time index `2000-01-01 00:00:00` " + "for series at index: 0, and could not find a valid start point within the time index that lies a " + "round-multiple of `stride=11` ahead of `start` (first inferred start is `2000-01-11 00:00:00`, " + "but last time index is `2000-01-10 00:00:00`." + ) - # label_index, too low + # label_index (timestamp), before trainable/predictable index and stride does not allow to find valid start + # point in series with pytest.raises(ValueError) as msg: - LinearRegressionModel(lags=1).historical_forecasts( - timeidx_, start=timeidx_.start_time() - timeidx_.freq + model.historical_forecasts( + timeidx_, start=timeidx_.start_time(), stride=len(timeidx_) ) - assert str(msg.value).startswith( - "`start` time `1999-12-31 00:00:00` is before the first timestamp `2000-01-01 00:00:00`" + assert str(msg.value) == ( + "`start` time `2000-01-01 00:00:00` is smaller than the first historical forecastable time index " + "`2000-01-04 00:00:00` for series at index: 0, and could not find a valid start point within the " + "historical forecastable time index that lies a round-multiple of `stride=10` ahead of `start` " + "(first inferred start is `2000-01-11 00:00:00`, but last historical forecastable time index is " + "`2000-01-10 00:00:00`." 
) + + # label_index (int), too low and stride does not allow to find valid start point in series with pytest.raises(ValueError) as msg: - LinearRegressionModel(lags=1).historical_forecasts( - rangeidx_step1, start=rangeidx_step1.start_time() - rangeidx_step1.freq + model.historical_forecasts( + rangeidx_step1, + start=rangeidx_step1.start_time() - rangeidx_step1.freq, + stride=len(rangeidx_step1) + 1, ) - assert str(msg.value).startswith( - "`start` index `-1` is smaller than the first index `0`" + assert str(msg.value) == ( + "`start` time `-1` is smaller than the first time index `0` for series at index: 0, and could not " + "find a valid start point within the time index that lies a round-multiple of `stride=11` ahead of " + "`start` (first inferred start is `10`, but last time index is `9`." ) + + # label_index (int), before trainable/predictable index and stride does not allow to find valid start + # point in series with pytest.raises(ValueError) as msg: - LinearRegressionModel(lags=1).historical_forecasts( - rangeidx_step2, start=rangeidx_step2.start_time() - rangeidx_step2.freq + model.historical_forecasts( + rangeidx_step1, + start=rangeidx_step1.start_time(), + stride=len(rangeidx_step1), ) - assert str(msg.value).startswith( - "`start` index `-2` is smaller than the first index `0`" + assert str(msg.value) == ( + "`start` time `0` is smaller than the first historical forecastable time index `3` for series at " + "index: 0, and could not find a valid start point within the historical forecastable time index " + "that lies a round-multiple of `stride=10` ahead of `start` (first inferred start is `10`, but last " + "historical forecastable time index is `9`." 
) - # positional_index, predicting only the last position - LinearRegressionModel(lags=1).historical_forecasts( - timeidx_, start=9, start_format="position" - ) + # positional_index with time index, predicting only the last position + preds = model.historical_forecasts(timeidx_, start=9, start_format="position") + assert len(preds) == 1 + assert preds.start_time() == timeidx_.time_index[9] # positional_index, predicting from the first position with retrain=True - with pytest.raises(ValueError) as msg: - LinearRegressionModel(lags=1).historical_forecasts( - timeidx_, start=-10, start_format="position" - ) - assert str(msg.value).endswith(", resulting in an empty training set.") + preds1 = model.historical_forecasts( + timeidx_, start=-10, start_format="position" + ) + # positional_index, before start of series gives same results + preds2 = model.historical_forecasts( + timeidx_, start=-11, start_format="position" + ) + assert ( + len(preds1) == len(preds2) == len(timeidx_) - model.min_train_series_length + ) + assert ( + preds1.start_time() + == preds2.start_time() + == timeidx_.time_index[model.min_train_series_length] + ) # positional_index, beyond boundaries with pytest.raises(ValueError) as msg: - LinearRegressionModel(lags=1).historical_forecasts( - timeidx_, start=10, start_format="position" - ) + model.historical_forecasts(timeidx_, start=10, start_format="position") assert str(msg.value).startswith( - "`start` index `10` is out of bounds for series of length 10" + "`start` position `10` is out of bounds for series of length 10" ) + + # positional_index with range index, predicting only the last position + preds = model.historical_forecasts( + rangeidx_step2, start=9, start_format="position" + ) + assert len(preds) == 1 + assert preds.start_time() == rangeidx_step2.time_index[9] + + # positional_index, predicting from the first position with retrain=True + preds1 = model.historical_forecasts( + rangeidx_step2, start=-10, start_format="position" + ) + # 
positional_index, before start of series gives same results + preds2 = model.historical_forecasts( + rangeidx_step2, start=-11, start_format="position" + ) + assert ( + len(preds1) + == len(preds2) + == len(rangeidx_step2) - model.min_train_series_length + ) + assert ( + preds1.start_time() + == preds2.start_time() + == rangeidx_step2.time_index[model.min_train_series_length] + ) + + # positional_index, beyond boundaries with pytest.raises(ValueError) as msg: - LinearRegressionModel(lags=1).historical_forecasts( - timeidx_, start=-11, start_format="position" + model.historical_forecasts( + rangeidx_step2, start=10, start_format="position" ) assert str(msg.value).startswith( - "`start` index `-11` is out of bounds for series of length 10" + "`start` position `10` is out of bounds for series of length 10" ) + @pytest.mark.parametrize( + "config", + list( + itertools.product( + [ + ( + "2000-01-01 00:00:00", # start + 1, # stride + "2000-01-01 03:00:00", # expected start + "h", # freq + ), + ("2000-01-01 00:00:00", 2, "2000-01-01 04:00:00", "h"), + ("1999-01-01 00:00:00", 6, "2000-01-01 06:00:00", "h"), + ("2000-01-01 00:00:00", 2, "2000-01-01 08:00:00", "2h"), + # special case where start is not in the frequency -> start will be converted + # to "2000-01-01 00:00:00", and then it's adjusted to be within the historical fc index + ("1999-12-31 23:00:00", 2, "2000-01-01 08:00:00", "2h"), + # integer index + (0, 1, 3, 1), + (0, 2, 4, 1), + (-24, 6, 6, 1), + (0, 2, 8, 2), + # special case where start is not in the frequency -> start will be converted + # to 0, and then it's adjusted to be within the historical fc index + (-1, 2, 8, 2), + ], + ["value", "position"], # start format + [True, False], # retrain + [True, False] if TORCH_AVAILABLE else [False], # use torch model + ) + ), + ) + def test_historical_forecasts_start_too_early(self, caplog, config): + """If start is not within the trainable/forecastable index, it should start a round-multiple of `stride` ahead + of 
`start`. Checks for: + - correct warnings + - datetime / integer index + - different frequencies + - different strides + - start "value" and "position" + - retrain / no-retrain (optimized and non-optimized) + - torch and regression model + """ + # the configuration is defined for `retrain = True` and `start_format = "value"` + ( + (start, stride, start_expected, freq), + start_format, + retrain, + use_torch_model, + ) = config + if isinstance(freq, str): + start, start_expected = pd.Timestamp(start), pd.Timestamp(start_expected) + start_series = pd.Timestamp("2000-01-01 00:00:00") + else: + start_series = 0 + + series = tg.linear_timeseries( + start=start_series, + length=7, + freq=freq, + ) + # when hist fc `start` is not in the valid frequency range, it is converted to a time that is valid. + # e.g. `start="1999-12-31 23:00:00"` with `freq="2h"` is converted to `"2000-01-01 00:00:00"` + start_position = n_steps_between(end=start_series, start=start, freq=freq) + start_time_expected = series.start_time() - start_position * series.freq + + if start_format == "position": + start = -start_position + if start < 0: + # negative position is relative to the end of the series + start -= len(series) + start_format_msg = f"position `{start}` corresponding to time " + else: + start_format_msg = "time " + + if use_torch_model: + kwargs = copy.deepcopy(tfm_kwargs) + kwargs["pl_trainer_kwargs"]["fast_dev_run"] = True + # use ocl=2 to have same `min_train_length` as the regression model + model = BlockRNNModel( + input_chunk_length=1, output_chunk_length=2, n_epochs=1, **kwargs + ) + else: + model = LinearRegressionModel(lags=1) + + model.fit(series) + # if the stride is shorter than the train series length, retrain=False can start earlier + if not retrain and stride <= model.min_train_series_length: + start_expected -= ( + model.min_train_series_length + model.extreme_lags[0] + ) * series.freq + + # label index + warning_expected = ( + f"`start`
{start_format_msg}`{start_time_expected}` is before the first predictable/trainable historical " + f"forecasting point for series at index: 0. Using the first historical forecasting point " + f"`{start_expected}` that lies a round-multiple of `stride={stride}` ahead of `start`. To hide these " + f"warnings, set `show_warnings=False`." + ) + + # check that warning is raised when too early + enable_optimizations = [False] if retrain else [False, True] + for enable_optimization in enable_optimizations: + with caplog.at_level(logging.WARNING): + pred = model.historical_forecasts( + series, + start=start, + stride=stride, + retrain=retrain, + start_format=start_format, + enable_optimization=enable_optimization, + ) + assert warning_expected in caplog.text + assert pred.start_time() == start_expected + caplog.clear() + # but no warning when start is at the right time + warning_short = ( + f"Using the first historical forecasting point `{start_expected}` that lies a round-multiple " + f"of `stride={stride}` ahead of `start`. To hide these warnings, set `show_warnings=False`." + ) + with caplog.at_level(logging.WARNING): + pred = model.historical_forecasts( + series, + start=start_expected, + stride=stride, + retrain=False, + start_format="value", + enable_optimization=True, + ) + assert warning_short not in caplog.text + assert pred.start_time() == start_expected + @pytest.mark.parametrize("config", models_reg_no_cov_cls_kwargs) def test_regression_auto_start_multiple_no_cov(self, config): train_length = 15 @@ -2599,3 +2828,28 @@ def test_sample_weight(self, config): == f"`sample_weight` at series index {invalid_idx} must contain " f"at least all times of the corresponding target `series`." 
) + + def test_historical_forecast_additional_sanity_checks(self): + model = LinearRegressionModel(lags=1) + + # `stride <= 0` + with pytest.raises(ValueError) as err: + _ = model.historical_forecasts( + series=self.ts_pass_train, + stride=0, + ) + assert ( + str(err.value) + == "The provided stride parameter must be a positive integer." + ) + + # start_format="position" but `start` is not `int` + with pytest.raises(ValueError) as err: + _ = model.historical_forecasts( + series=self.ts_pass_train, + start=pd.Timestamp("01-01-2020"), + start_format="position", + ) + assert str(err.value).startswith( + "Since `start_format='position'`, `start` must be an integer, received" + ) diff --git a/darts/tests/utils/historical_forecasts/test_utils.py b/darts/tests/utils/historical_forecasts/test_utils.py new file mode 100644 index 0000000000..fdb14ed1a5 --- /dev/null +++ b/darts/tests/utils/historical_forecasts/test_utils.py @@ -0,0 +1,154 @@ +import itertools + +import pandas as pd +import pytest + +import darts.utils.historical_forecasts.utils as hfc_utils +from darts.models import LinearRegressionModel +from darts.utils.timeseries_generation import linear_timeseries + + +class TestHistoricalForecastsUtils: + model = LinearRegressionModel(lags=1) + + def test_historical_forecasts_check_kwargs(self): + # `hfc_args` not part of `dict_kwargs` works + hfc_args = {"a", "b"} + dict_kwargs = {"c": 0, "d": 0} + out = hfc_utils._historical_forecasts_check_kwargs( + hfc_args=hfc_args, + name_kwargs="some_name", + dict_kwargs=dict_kwargs, + ) + assert out == dict_kwargs + + # `hfc_args` is part of `dict_kwargs` fails + with pytest.raises(ValueError): + _ = hfc_utils._historical_forecasts_check_kwargs( + hfc_args={"c"}, + name_kwargs="some_name", + dict_kwargs=dict_kwargs, + ) + + @pytest.mark.parametrize( + "config", + itertools.product( + [True, False], # retrain + [True, False], # show warnings + [{}, {"some_fit_param": 0}], # fit kwargs + [{}, {"some_predict_param": 0}], # predict 
kwargs + ), + ) + def test_historical_forecasts_sanitize_kwargs(self, config): + retrain, show_warnings, fit_kwargs, pred_kwargs = config + fit_kwargs_out, pred_kwargs_out = ( + hfc_utils._historical_forecasts_sanitize_kwargs( + self.model, + fit_kwargs=fit_kwargs, + predict_kwargs=pred_kwargs, + retrain=retrain, + show_warnings=show_warnings, + ) + ) + assert fit_kwargs_out == fit_kwargs + assert pred_kwargs_out == pred_kwargs + + @pytest.mark.parametrize( + "kwargs", + [ + { + "fit_kwargs": {"series": 0}, + "predict_kwargs": None, + "retrain": True, + "show_warnings": False, + }, + { + "fit_kwargs": None, + "predict_kwargs": {"series": 0}, + "retrain": True, + "show_warnings": False, + }, + ], + ) + def test_historical_forecasts_sanitize_kwargs_invalid(self, kwargs): + with pytest.raises(ValueError): + _ = hfc_utils._historical_forecasts_sanitize_kwargs(self.model, **kwargs) + + def test_historical_forecasts_check_start(self): + """""" + series = linear_timeseries(start=0, length=1) + kwargs = { + "start": 0, + "start_format": "value", + "series_start": 0, + "ref_start": 0, + "ref_end": 0, + "stride": 0, + "series_idx": 0, + "is_historical_forecast": False, + } + # low enough start idx works with any kwargs + hfc_utils._check_start(series, start_idx=0, **kwargs) + + # start idx >= len(series) raises error + with pytest.raises(ValueError): + hfc_utils._check_start(series, start_idx=1, **kwargs) + + @pytest.mark.parametrize( + "config", + [ + (True, pd.Timestamp("2000-01-01"), "value"), + (True, 0.9, "value"), + (True, 0.9, "position"), + (True, 0, "position"), + (True, 0, "value"), + (False, pd.Timestamp("2000-01-01"), "value"), + (False, 0.9, "value"), + (False, 0.9, "position"), + (False, 0, "position"), + ], + ) + def test_historical_forecasts_check_start_invalid(self, config): + """""" + is_dt, start, start_format = config + series = linear_timeseries(start="2000-01-01" if is_dt else 0, length=1) + series_start = series.start_time() + kwargs = { + "start": 
start, + "start_format": start_format, + "series_start": series_start, + "ref_start": 0, + "ref_end": 0, + "stride": 0, + "series_idx": 0, + "is_historical_forecast": False, + } + + # start idx >= len(series) raises error + with pytest.raises(ValueError) as err: + hfc_utils._check_start(series, start_idx=1, **kwargs) + + # make sure we reach the expected error message and message is specific to input + position_msg = f"position `{start}` corresponding to time " + if start_format == "position" or is_dt and not isinstance(start, pd.Timestamp): + assert position_msg in str(err.value) + else: + assert position_msg not in str(err.value) + + @pytest.mark.parametrize( + "config", + [ + (0, 0, 0), + (1, 1, 1), + (1, 10, 1), + (-1, 1, 0), + (-3, 1, 0), + (-1, 2, 1), + (-2, 2, 0), + (-3, 2, 1), + ], + ) + def test_adjust_start(self, config): + """Check relative start position adjustment.""" + start_rel, stride, start_expected = config + assert hfc_utils._adjust_start(start_rel, stride) == start_expected diff --git a/darts/timeseries.py b/darts/timeseries.py index f6fe161464..fe77ad1390 100644 --- a/darts/timeseries.py +++ b/darts/timeseries.py @@ -2170,7 +2170,7 @@ def get_index_at_point( ``pd.Timestamp`` work only on series that are indexed with a ``pd.DatetimeIndex``. In such cases, the returned point will be the index of this timestamp if it is present in the series time index. - It it's not present in the time index, the index of the next timestamp is returned if `after=True` + If it's not present in the time index, the index of the next timestamp is returned if `after=True` (if it exists in the series), otherwise the index of the previous timestamp is returned (if it exists in the series).
diff --git a/darts/utils/historical_forecasts/__init__.py b/darts/utils/historical_forecasts/__init__.py index fcd2ea765f..51145458b1 100644 --- a/darts/utils/historical_forecasts/__init__.py +++ b/darts/utils/historical_forecasts/__init__.py @@ -6,7 +6,6 @@ _check_optimizable_historical_forecasts_global_models, _get_historical_forecast_boundaries, _historical_forecasts_general_checks, - _historical_forecasts_start_warnings, _process_historical_forecast_input, ) @@ -16,6 +15,5 @@ "_check_optimizable_historical_forecasts_global_models", "_get_historical_forecast_boundaries", "_historical_forecasts_general_checks", - "_historical_forecasts_start_warnings", "_process_historical_forecast_input", ] diff --git a/darts/utils/historical_forecasts/optimized_historical_forecasts_regression.py b/darts/utils/historical_forecasts/optimized_historical_forecasts_regression.py index eeef59f04b..13f1c1d36a 100644 --- a/darts/utils/historical_forecasts/optimized_historical_forecasts_regression.py +++ b/darts/utils/historical_forecasts/optimized_historical_forecasts_regression.py @@ -74,6 +74,7 @@ def _optimized_historical_forecasts_last_points_only( start_format=start_format, forecast_horizon=forecast_horizon, overlap_end=overlap_end, + stride=stride, freq=freq, show_warnings=show_warnings, ) @@ -236,6 +237,7 @@ def _optimized_historical_forecasts_all_points( start_format=start_format, forecast_horizon=forecast_horizon, overlap_end=overlap_end, + stride=stride, freq=freq, show_warnings=show_warnings, ) diff --git a/darts/utils/historical_forecasts/optimized_historical_forecasts_torch.py b/darts/utils/historical_forecasts/optimized_historical_forecasts_torch.py index 182516daf3..5dc8951e79 100644 --- a/darts/utils/historical_forecasts/optimized_historical_forecasts_torch.py +++ b/darts/utils/historical_forecasts/optimized_historical_forecasts_torch.py @@ -69,6 +69,7 @@ def _optimized_historical_forecasts( start_format=start_format, forecast_horizon=forecast_horizon, 
overlap_end=overlap_end, + stride=stride, freq=series_.freq, show_warnings=show_warnings, ) diff --git a/darts/utils/historical_forecasts/utils.py b/darts/utils/historical_forecasts/utils.py index 9099f9a44d..00a3b3d618 100644 --- a/darts/utils/historical_forecasts/utils.py +++ b/darts/utils/historical_forecasts/utils.py @@ -1,6 +1,8 @@ from types import SimpleNamespace from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Union +from darts.utils import n_steps_between + try: from typing import Literal except ImportError: @@ -12,7 +14,7 @@ import pandas as pd from numpy.typing import ArrayLike -from darts.logging import get_logger, raise_if_not, raise_log +from darts.logging import get_logger, raise_log from darts.timeseries import TimeSeries from darts.utils.ts_utils import get_series_seq_type, series2seq from darts.utils.utils import generate_index @@ -47,18 +49,18 @@ def _historical_forecasts_general_checks(model, series, kwargs): n = SimpleNamespace(**kwargs) # check forecast horizon - raise_if_not( - n.forecast_horizon > 0, - "The provided forecasting horizon must be a positive integer.", - logger, - ) + if not n.forecast_horizon > 0: + raise_log( + ValueError("The provided forecasting horizon must be a positive integer."), + logger, + ) # check stride - raise_if_not( - n.stride > 0, - "The provided stride parameter must be a positive integer.", - logger, - ) + if not n.stride > 0: + raise_log( + ValueError("The provided stride parameter must be a positive integer."), + logger, + ) series = series2seq(series) @@ -78,100 +80,79 @@ def _historical_forecasts_general_checks(model, series, kwargs): f"`start_format` must be on of ['position', 'value']. Received '{n.start_format}'." 
) ) - if n.start_format == "position": - raise_if_not( - isinstance(n.start, (int, np.int64)), - f"Since `start_format='position'`, `start` must be an integer, received {type(n.start)}.", + if n.start_format == "position" and not isinstance(n.start, (int, np.int64)): + raise_log( + ValueError( + f"Since `start_format='position'`, `start` must be an integer, received {type(n.start)}." + ), logger, ) - - if isinstance(n.start, float): - raise_if_not( - 0.0 <= n.start <= 1.0, - "if `start` is a float, must be between 0.0 and 1.0.", + if isinstance(n.start, float) and not 0.0 <= n.start <= 1.0: + raise_log( + ValueError("if `start` is a float, must be between 0.0 and 1.0."), logger, ) - # verbose error messages - if not isinstance(n.start, pd.Timestamp): - start_value_msg = f"`start` value `{n.start}` corresponding to timestamp" - else: - start_value_msg = "`start` time" for idx, series_ in enumerate(series): # check specifically for int and Timestamp as error by `get_timestamp_at_point` is too generic if isinstance(n.start, pd.Timestamp): - if n.start > series_.end_time(): + if not series_._has_datetime_index: raise_log( ValueError( - f"`start` time `{n.start}` is after the last timestamp `{series_.end_time()}` of the " - f"series at index: {idx}." + "if `start` is a `pd.Timestamp`, all series must be indexed with a `pd.DatetimeIndex`" ), logger, ) - elif n.start < series_.start_time(): + if n.start > series_.end_time(): raise_log( ValueError( - f"`start` time `{n.start}` is before the first timestamp `{series_.start_time()}` of the " + f"`start` time `{n.start}` is after the last timestamp `{series_.end_time()}` of the " f"series at index: {idx}." 
), logger, ) elif isinstance(n.start, (int, np.int64)): - out_of_bound_error = False - if n.start_format == "position": - if (n.start > 0 and n.start >= len(series_)) or ( - n.start < 0 and np.abs(n.start) > len(series_) - ): - out_of_bound_error = True - elif series_.has_datetime_index: + if n.start_format == "position" or series_.has_datetime_index: if n.start >= len(series_): - out_of_bound_error = True - elif n.start < series_.time_index[0]: + raise_log( + ValueError( + f"`start` position `{n.start}` is out of bounds for series of length {len(series_)} " + f"at index: {idx}." + ), + logger, + ) + elif n.start > series_.time_index[-1]: # format "value" and range index raise_log( ValueError( - f"`start` index `{n.start}` is smaller than the first index `{series_.time_index[0]}` " + f"`start` time `{n.start}` is larger than the last index `{series_.time_index[-1]}` " f"for series at index: {idx}." ), logger, ) - elif n.start > series_.time_index[-1]: - raise_log( - ValueError( - f"`start` index `{n.start}` is larger than the last index `{series_.time_index[-1]}` " - f"for series at index: {idx}." - ), - logger, - ) - - if out_of_bound_error: - raise_log( - ValueError( - f"`start` index `{n.start}` is out of bounds for series of length {len(series_)} " - f"at index: {idx}." - ), - logger, - ) - - if n.start_format == "value": - start = series_.get_timestamp_at_point(n.start) - else: - start = series_.time_index[n.start] - if n.retrain is not False and start == series_.start_time(): - raise_log( - ValueError( - f"{start_value_msg} `{start}` is the first timestamp of the series {idx}, resulting in an " - f"empty training set." 
- ), - logger, - ) + # find valid start position relative to the series start time, otherwise raise an error + start_idx, _ = _get_start_index( + series_, idx, n.start, n.start_format, n.stride + ) - # check that overlap_end and start together form a valid combination + # check that `overlap_end` and `start` are a valid combination overlap_end = n.overlap_end if ( not overlap_end - and start + (series_.freq * (n.forecast_horizon - 1)) not in series_ + and start_idx + n.forecast_horizon + model.output_chunk_shift + > len(series_) ): + # verbose error messages + if n.start_format == "position" or ( + not isinstance(n.start, pd.Timestamp) + and series_._has_datetime_index + ): + start_value_msg = ( + f"`start` position `{n.start}` corresponding to time" + ) + else: + start_value_msg = "`start` time" + start = series_._time_index[start_idx] raise_log( ValueError( f"{start_value_msg} `{start}` is too late in the series {idx} to make any predictions with " @@ -293,35 +274,152 @@ def _historical_forecasts_check_kwargs( return dict_kwargs -def _historical_forecasts_start_warnings( - idx: int, - start: Union[pd.Timestamp, int], - start_time_: Union[int, pd.Timestamp], - historical_forecasts_time_index: TimeIndex, +def _get_start_index( + series: TimeSeries, + series_idx: int, + start: Union[pd.Timestamp, int, float], + start_format: Literal["value", "position"], + stride: int, + historical_forecasts_time_index: Optional[TimeIndex] = None, ): - """Warnings when start value provided by user is not within the forecastable indexes boundaries""" - if not isinstance(start, pd.Timestamp): - start_value_msg = f"value `{start}` corresponding to timestamp `{start_time_}`" + """Finds a valid historical forecast start point within either `series` or `historical_forecasts_time_index` + (depending on whether `historical_forecasts_time_index` is passed, denoted as `ref`). + + - If `start` is larger or equal to the first index of `ref`, uses `start` directly. 
+ - If `start` is before the first index of `ref`, tries to find a start point within `ref` that lies a + round-multiple `stride` time steps ahead of `start`. + + Raises an error if the new start index from above is larger than the last index in `ref`. + + Parameters + ---------- + series + A time series. If `historical_forecasts_time_index` is `None`, will use this series' time index as a reference + index. + series_idx + The sequence index of the `series`. + start + The start point for historical forecasts. + start_format + The start format for historical forecasts. + stride + The stride for historical forecasts. + historical_forecasts_time_index + Optionally, the historical forecast index (or the boundaries only) to use as the reference index. + """ + series_start, series_end = series._time_index[0], series._time_index[-1] + has_dti = series._has_datetime_index + # find start position relative to the series start time + if isinstance(start, float): + # fraction of series + rel_start = series.get_index_at_point(start) + elif start_format == "value" and not (isinstance(start, int) and has_dti): + # start is a time stamp for DatetimeIndex, and integer for RangeIndex + rel_start = n_steps_between(start, series_start, freq=series.freq) else: - start_value_msg = f"time `{start_time_}`" + # start is a positional index + start: int + rel_start = start if start >= 0 else len(series) - abs(start) - if start_time_ < historical_forecasts_time_index[0]: - logger.warning( - f"`start` {start_value_msg} is before the first predictable/trainable historical " - f"forecasting point for series at index: {idx}. Ignoring `start` for this series and " - f"beginning at first trainable/predictable time: {historical_forecasts_time_index[0]}. " - f"To hide these warnings, set `show_warnings=False`." 
+ # find actual start time + start_idx = _adjust_start(rel_start, stride) + _check_start( + series=series, + start_idx=start_idx, + start=start, + start_format=start_format, + series_start=series_start, + ref_start=series_start, + ref_end=series_end, + stride=stride, + series_idx=series_idx, + is_historical_forecast=False, + ) + if historical_forecasts_time_index is not None: + hfc_start, hfc_end = ( + historical_forecasts_time_index[0], + historical_forecasts_time_index[-1], + ) + # at this point, we know that `start_idx` is within `series`. Now, find the position of that time step + # relative to the first forecastable point + rel_start_hfc = n_steps_between( + series._time_index[start_idx], hfc_start, freq=series.freq ) + # get the positional index of `hfc_start` in `series` + hfc_start_idx = start_idx - rel_start_hfc + # potentially, adjust the position to be inside the forecastable points + hfc_start_idx += _adjust_start(rel_start_hfc, stride) + _check_start( + series=series, + start_idx=hfc_start_idx, + start=start, + start_format=start_format, + series_start=series_start, + ref_start=hfc_start, + ref_end=hfc_end, + stride=stride, + series_idx=series_idx, + is_historical_forecast=True, + ) + start_idx = hfc_start_idx + return start_idx, rel_start + + +def _adjust_start(rel_start, stride): + """If relative start position `rel_start` is negative, then adjust it to the first non-negative index that lies a + round-multiple of `stride` ahead of `rel_start` + """ + if rel_start >= 0: + start_idx = rel_start else: - logger.warning( - f"`start` {start_value_msg} is after the last trainable/predictable historical " - f"forecasting point for series at index: {idx}. This would results in empty historical " - f"forecasts. Ignoring `start` for this series and beginning at first trainable/" - f"predictable time: {historical_forecasts_time_index[0]}. 
Non-empty forecasts can be " - f"generated by setting `start` value to times between (including): " - f"{(historical_forecasts_time_index[0], historical_forecasts_time_index[-1])}. " - f"To hide these warnings, set `show_warnings=False`." + # if `start` lies before the start time of `series` -> check if there is a valid start point in + # `series` that is a round-multiple of `stride` ahead of `start` + start_idx = ( + rel_start + + (abs(rel_start) // stride + int(abs(rel_start) % stride > 0)) * stride ) + return start_idx + + +def _check_start( + series: TimeSeries, + start_idx: int, + start: Union[pd.Timestamp, int, float], + start_format: Literal["value", "position"], + series_start: Union[pd.Timestamp, int], + ref_start: Union[pd.Timestamp, int], + ref_end: Union[pd.Timestamp, int], + stride: int, + series_idx: int, + is_historical_forecast: bool, +): + """Raises an error if the start index (position) is not within the series.""" + if start_idx < len(series): + return + + if start_format == "position" or ( + not isinstance(start, pd.Timestamp) and series._has_datetime_index + ): + start_format_msg = f"position `{start}` corresponding to time " + if isinstance(start, float): + # fraction of series + start = series.get_index_at_point(start) + else: + start = series.start_time() + start * series.freq + else: + start_format_msg = "time " + ref_msg = "" if not is_historical_forecast else "historical forecastable " + start_new = series_start + start_idx * series.freq + raise_log( + ValueError( + f"`start` {start_format_msg}`{start}` is smaller than the first {ref_msg}time index " + f"`{ref_start}` for series at index: {series_idx}, and could not find a valid start " + f"point within the {ref_msg}time index that lies a round-multiple of `stride={stride}` " + f"ahead of `start` (first inferred start is `{start_new}`, but last {ref_msg}time index " + f"is `{ref_end}`." 
+ ), + logger=logger, + ) def _get_historical_forecastable_time_index( @@ -542,37 +640,51 @@ def _adjust_historical_forecasts_time_index( historical_forecasts_time_index: TimeIndex, start: Optional[Union[pd.Timestamp, float, int]], start_format: Literal["position", "value"], + stride: int, show_warnings: bool, ) -> TimeIndex: """ Shrink the beginning and end of the historical forecasts time index based on the values of `start`, `forecast_horizon` and `overlap_end`. """ + # retrieve actual start # when applicable, shift the start of the forecastable index based on `start` if start is not None: - if start_format == "value": - start_time_ = series.get_timestamp_at_point(start) - else: - start_time_ = series.time_index[start] - # ignore user-defined `start` - if ( - not historical_forecasts_time_index[0] - <= start_time_ - <= historical_forecasts_time_index[-1] - ): - if show_warnings: - _historical_forecasts_start_warnings( - idx=series_idx, - start=start, - start_time_=start_time_, - historical_forecasts_time_index=historical_forecasts_time_index, + # find valid start position relative to the hfc start time, otherwise raise an error + start_idx, start_idx_orig = _get_start_index( + series=series, + series_idx=series_idx, + start=start, + start_format=start_format, + stride=stride, + historical_forecasts_time_index=historical_forecasts_time_index, + ) + start_time = series._time_index[start_idx] + + if start_idx != start_idx_orig and show_warnings: + if start_idx_orig >= 0: + start_time_orig = series._time_index[start_idx_orig] + else: + start_time_orig = series.start_time() + start_idx_orig * series.freq + + if start_format == "position" or ( + not isinstance(start, pd.Timestamp) and series._has_datetime_index + ): + start_value_msg = ( + f"position `{start}` corresponding to time `{start_time_orig}`" ) - else: - historical_forecasts_time_index = ( - max(historical_forecasts_time_index[0], start_time_), - historical_forecasts_time_index[1], + else: + start_value_msg 
= f"time `{start_time_orig}`" + logger.warning( + f"`start` {start_value_msg} is before the first predictable/trainable historical " + f"forecasting point for series at index: {series_idx}. Using the first historical forecasting " + f"point `{start_time}` that lies a round-multiple of `stride={stride}` " + f"ahead of `start`. To hide these warnings, set `show_warnings=False`." ) - + historical_forecasts_time_index = ( + max(historical_forecasts_time_index[0], start_time), + historical_forecasts_time_index[1], + ) return historical_forecasts_time_index @@ -719,6 +831,7 @@ def _get_historical_forecast_boundaries( start_format: Literal["position", "value"], forecast_horizon: int, overlap_end: bool, + stride: int, freq: pd.DateOffset, show_warnings: bool = True, ) -> Tuple[Any, ...]: @@ -748,6 +861,7 @@ def _get_historical_forecast_boundaries( historical_forecasts_time_index=historical_forecasts_time_index, start=start, start_format=start_format, + stride=stride, show_warnings=show_warnings, )