diff --git a/CHANGELOG.md b/CHANGELOG.md
index 95f52f1a76..fa602543ed 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,9 +10,11 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
 ### For users of the library:
 **Improved**
+- Improvements to `TimeSeries`:
+  - Improved the time series frequency inference when using slices or pandas DatetimeIndex as keys for `__getitem__`. [#2152](https://github.com/unit8co/darts/pull/2152) by [DavidKleindienst](https://github.com/DavidKleindienst).
 
 **Fixed**
-
+- Fixed a bug when using a `TorchForecastingModel` with `use_reversible_instance_norm=True` and predicting with `n > output_chunk_length`: the input was normalized multiple times. [#2160](https://github.com/unit8co/darts/pull/2160) by [FourierMourier](https://github.com/FourierMourier).
 
 ### For developers of the library:
 
 ## [0.27.1](https://github.com/unit8co/darts/tree/0.27.1) (2023-12-10)
diff --git a/darts/models/forecasting/pl_forecasting_module.py b/darts/models/forecasting/pl_forecasting_module.py
index 7ade7eaac9..ab98ee59c2 100644
--- a/darts/models/forecasting/pl_forecasting_module.py
+++ b/darts/models/forecasting/pl_forecasting_module.py
@@ -50,7 +50,8 @@ def forward_wrapper(self, *args, **kwargs):
 
         # x is input batch tuple which by definition has the past features in the first element starting with the
         # first n target features
-        x: Tuple = args[0][0]
+        # `args[0][0]` is assumed to be a torch.Tensor; clone it so the in-place normalization below cannot re-normalize the target
+        x: Tuple = args[0][0].clone()
         # apply reversible instance normalization
         x[:, :, : self.n_targets] = self.rin(x[:, :, : self.n_targets])
         # run the forward pass
diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py
index 12729d0519..24b8fd501e 100644
--- a/darts/tests/models/forecasting/test_torch_forecasting_model.py
+++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -707,17 +707,14 @@ def test_load_weights_params_check(self, tmpdir_fn):
         ckpt_path = os.path.join(tmpdir_fn, f"{model_name}.pt")
 
         # barebone model
         model = DLinearModel(
-            input_chunk_length=4,
-            output_chunk_length=1,
-            n_epochs=1,
+            input_chunk_length=4, output_chunk_length=1, n_epochs=1, **tfm_kwargs
         )
         model.fit(self.series[:10])
         model.save(ckpt_path)
 
         # identical model
         loading_model = DLinearModel(
-            input_chunk_length=4,
-            output_chunk_length=1,
+            input_chunk_length=4, output_chunk_length=1, **tfm_kwargs
         )
         loading_model.load_weights(ckpt_path)
@@ -726,21 +723,26 @@
             input_chunk_length=4,
             output_chunk_length=1,
             optimizer_cls=torch.optim.AdamW,
+            **tfm_kwargs,
         )
         loading_model.load_weights(ckpt_path)
 
+        model_summary_kwargs = {
+            "pl_trainer_kwargs": dict(
+                {"enable_model_summary": False}, **tfm_kwargs["pl_trainer_kwargs"]
+            )
+        }
         # different pl_trainer_kwargs
         loading_model = DLinearModel(
             input_chunk_length=4,
             output_chunk_length=1,
-            pl_trainer_kwargs={"enable_model_summary": False},
+            **model_summary_kwargs,
         )
         loading_model.load_weights(ckpt_path)
 
         # different input_chunk_length (tfm parameter)
         loading_model = DLinearModel(
-            input_chunk_length=4 + 1,
-            output_chunk_length=1,
+            input_chunk_length=4 + 1, output_chunk_length=1, **tfm_kwargs
         )
         with pytest.raises(ValueError) as error_msg:
             loading_model.load_weights(ckpt_path)
@@ -754,6 +756,7 @@
             input_chunk_length=4,
             output_chunk_length=1,
             kernel_size=10,
+            **tfm_kwargs,
         )
         with pytest.raises(ValueError) as error_msg:
             loading_model.load_weights(ckpt_path)
diff --git a/darts/tests/test_timeseries.py b/darts/tests/test_timeseries.py
index db5459fee6..2a2b7c2788 100644
--- a/darts/tests/test_timeseries.py
+++ b/darts/tests/test_timeseries.py
@@ -17,7 +17,6 @@
 
 
 class TestTimeSeries:
-
     times = pd.date_range("20130101", "20130110", freq="D")
     pd_series1 = pd.Series(range(10), index=times)
     pd_series2 = pd.Series(range(5, 15), index=times)
@@ -933,6 +932,55 @@ def test_getitem_integer_index(self):
         with pytest.raises(KeyError):
             _ = series[pd.RangeIndex(start, stop=end + 2 * freq, step=freq)]
 
+    def test_getitem_frequency_inference(self):
+        ts = self.series1
+        assert ts.freq == "D"
+        ts_got = ts[1::2]
+        assert ts_got.freq == "2D"
+        ts_got = ts[pd.Timestamp("20130103") :: 2]
+        assert ts_got.freq == "2D"
+
+        idx = pd.DatetimeIndex(["20130102", "20130105", "20130108"])
+        ts_idx = ts[idx]
+        assert ts_idx.freq == "3D"
+
+        # With BusinessDay frequency
+        offset = pd.offsets.BusinessDay()  # Closed on Saturdays & Sundays
+        dates1 = pd.date_range("20231101", "20231126", freq=offset)
+        values1 = np.ones(len(dates1))
+        ts = TimeSeries.from_times_and_values(dates1, values1)
+        assert ts.freq == ts[-4:].freq
+
+        # Using a step parameter
+        assert ts[1::3].freq == 3 * ts.freq
+        assert ts[pd.Timestamp("20231102") :: 4].freq == 4 * ts.freq
+
+        # Indexing with datetime index
+        idx = pd.date_range("20231101", "20231126", freq=offset)
+        assert ts[idx].freq == idx.freq
+
+    def test_getitem_frequency_inference_integer_index(self):
+        start = 2
+        freq = 3
+        ts = TimeSeries.from_times_and_values(
+            times=pd.RangeIndex(
+                start=start, stop=start + freq * len(self.series1), step=freq
+            ),
+            values=self.series1.values(),
+        )
+
+        assert ts.freq == freq
+        ts_got = ts[1::2]
+        assert ts_got.start_time() == start + freq
+        assert ts_got.freq == 2 * freq
+
+        idx = pd.RangeIndex(
+            start=start + 2 * freq, stop=start + 4 * freq, step=2 * freq
+        )
+        ts_idx = ts[idx]
+        assert ts_idx.start_time() == idx[0]
+        assert ts_idx.freq == 2 * freq
+
     def test_fill_missing_dates(self):
         with pytest.raises(ValueError):
             # Series cannot have date holes without automatic filling
@@ -1050,7 +1098,6 @@ def test_fill_missing_dates(self):
 
         series_target = TimeSeries.from_dataframe(df_full, time_col="date")
         for df, df_name in zip([df_full, df_holes], ["full", "holes"]):
-
             # fill_missing_dates will find multiple inferred frequencies (i.e. for 'B' it finds {'B', 'D'})
             if offset_alias in offset_aliases_raise:
                 with pytest.raises(ValueError):
@@ -1519,7 +1566,6 @@ def test_to_csv_stochastic(self, pddf_mock):
 
 
 class TestTimeSeriesConcatenate:
-
     #
     # COMPONENT AXIS TESTS
     #
@@ -1735,7 +1781,6 @@ def test_concatenate_timeseries_method(self):
 
 
 class TestTimeSeriesHierarchy:
-
     components = ["total", "a", "b", "x", "y", "ax", "ay", "bx", "by"]
 
     hierarchy = {
@@ -1912,7 +1957,6 @@ def test_with_string_items(self):
 
 
 class TestTimeSeriesHeadTail:
-
     ts = TimeSeries(
         xr.DataArray(
             np.random.rand(10, 10, 10),
@@ -2185,7 +2229,6 @@ def test_df_named_columns_index(self):
 
 
 class TestSimpleStatistics:
-
     times = pd.date_range("20130101", "20130110", freq="D")
     values = np.random.rand(10, 2, 100)
     ar = xr.DataArray(
diff --git a/darts/timeseries.py b/darts/timeseries.py
index 1792dfe345..0a715d10b8 100644
--- a/darts/timeseries.py
+++ b/darts/timeseries.py
@@ -4899,12 +4899,13 @@ def _check_range():
                 logger,
             )
 
-        def _set_freq_in_xa(xa_: xr.DataArray):
+        def _set_freq_in_xa(xa_: xr.DataArray, freq=None):
             # mutates the DataArray to make sure it contains the freq
             if isinstance(xa_.get_index(self._time_dim), pd.DatetimeIndex):
-                inferred_freq = xa_.get_index(self._time_dim).inferred_freq
-                if inferred_freq is not None:
-                    xa_.get_index(self._time_dim).freq = to_offset(inferred_freq)
+                if freq is None:
+                    freq = xa_.get_index(self._time_dim).inferred_freq
+                if freq is not None:
+                    xa_.get_index(self._time_dim).freq = to_offset(freq)
             else:
                 xa_.get_index(self._time_dim).freq = self._freq
 
@@ -4920,8 +4921,9 @@ def _set_freq_in_xa(xa_: xr.DataArray):
             xa_ = self._xa.sel({self._time_dim: key})
 
             # indexing may discard the freq so we restore it...
-            # TODO: unit-test this
-            _set_freq_in_xa(xa_)
+            # if the DateTimeIndex already has an associated freq, use it
+            # otherwise key.freq is None and the freq will be inferred
+            _set_freq_in_xa(xa_, key.freq)
             return self.__class__(xa_)
 
         elif isinstance(key, pd.RangeIndex):
@@ -4951,18 +4953,43 @@ def _set_freq_in_xa(xa_: xr.DataArray):
                 key.stop, (int, np.int64)
             ):
                 xa_ = self._xa.isel({self._time_dim: key})
-                _set_freq_in_xa(
-                    xa_
-                )  # indexing may discard the freq so we restore it...
+                if isinstance(key.step, (int, np.int64)):
+                    # new frequency is multiple of original
+                    new_freq = key.step * self.freq
+                elif key.step is None:
+                    new_freq = self.freq
+                else:
+                    new_freq = None
+                    raise_log(
+                        ValueError(
+                            f"Invalid slice step={key.step}. Only supports integer steps or `None`."
+                        ),
+                        logger=logger,
+                    )
+                # indexing may discard the freq so we restore it...
+                _set_freq_in_xa(xa_, new_freq)
                 return self.__class__(xa_)
             elif isinstance(key.start, pd.Timestamp) or isinstance(
                 key.stop, pd.Timestamp
            ):
                 _check_dt()
+                if isinstance(key.step, (int, np.int64)):
+                    # new frequency is multiple of original
+                    new_freq = key.step * self.freq
+                elif key.step is None:
+                    new_freq = self.freq
+                else:
+                    new_freq = None
+                    raise_log(
+                        ValueError(
+                            f"Invalid slice step={key.step}. Only supports integer steps or `None`."
+                        ),
+                        logger=logger,
+                    )
 
                 # indexing may discard the freq so we restore it...
                 xa_ = self._xa.sel({self._time_dim: key})
-                _set_freq_in_xa(xa_)
+                _set_freq_in_xa(xa_, new_freq)
                 return self.__class__(xa_)
 
         # handle simple types:
@@ -5030,13 +5057,18 @@ def _set_freq_in_xa(xa_: xr.DataArray):
                     # We have to restore a RangeIndex. But first we need to
                     # check the list is corresponding to a RangeIndex.
                     min_idx, max_idx = min(key), max(key)
-                    raise_if_not(
+                    if not (
                         key[0] == min_idx
                         and key[-1] == max_idx
-                        and max_idx + 1 - min_idx == len(key),
-                        "Indexing a TimeSeries with a list requires the list to contain monotically "
-                        + "increasing integers with no gap.",
-                    )
+                        and max_idx + 1 - min_idx == len(key)
+                    ):
+                        raise_log(
+                            ValueError(
+                                "Indexing a TimeSeries with a list requires the list to "
+                                "contain monotonically increasing integers with no gap."
+                            ),
+                            logger=logger,
+                        )
                     new_idx = orig_idx[min_idx : max_idx + 1]
                     xa_ = xa_.assign_coords({self._time_dim: new_idx})
 
diff --git a/requirements/release.txt b/requirements/release.txt
index d83cefcdc8..fadc093e1c 100644
--- a/requirements/release.txt
+++ b/requirements/release.txt
@@ -5,14 +5,14 @@ ipykernel==5.3.4
 ipywidgets==7.5.1
 jupyterlab==4.0.3
 ipython_genutils==0.2.0
-jinja2==3.0.3
+jinja2==3.1.3
 m2r2==0.3.2
 nbsphinx==0.8.7
 numpydoc==1.1.0
 papermill==2.2.2
 pydata-sphinx-theme==0.7.2
 recommonmark==0.7.1
-sphinx==4.3.2
+sphinx==5.0.0
 sphinx-automodapi==0.14.0
 sphinx_autodoc_typehints==1.12.0
 twine==3.3.0
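As a quick illustration of the `__getitem__` frequency behavior asserted by the new tests, here is a minimal sketch (not part of the diff), assuming a darts build that includes the changes above; the fixture values mirror `TestTimeSeries`:

```python
import numpy as np
import pandas as pd

from darts import TimeSeries

# daily series over ten days, mirroring the `TestTimeSeries` fixture
times = pd.date_range("20130101", "20130110", freq="D")
series = TimeSeries.from_times_and_values(times, np.arange(10, dtype=float))

# slicing with an integer step yields `step * freq` ("2D" here) instead of a re-inferred guess
assert series[1::2].freq == "2D"

# indexing with a DatetimeIndex that carries its own freq keeps that freq
idx = pd.date_range("20130102", "20130108", freq="3D")
assert series[idx].freq == "3D"

# integer-indexed series behave analogously: the slice step multiplies the RangeIndex step
series_int = TimeSeries.from_times_and_values(
    times=pd.RangeIndex(start=2, stop=32, step=3), values=np.arange(10, dtype=float)
)
assert series_int[1::2].freq == 6
```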