Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisbader authored Jan 17, 2024
2 parents 31065f8 + 962fd78 commit 0f0dbe1
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 34 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

### For users of the library:
**Improved**
- Improvements to `TimeSeries`:
- Improved the time series frequency inference when using slices or pandas DatetimeIndex as keys for `__getitem__`. [#2152](https://github.com/unit8co/darts/pull/2152) by [DavidKleindienst](https://github.com/DavidKleindienst).

**Fixed**

- Fixed a bug when using a `TorchForecastingModel` with `use_reversible_instance_norm=True` and predicting with `n > output_chunk_length`. The input normalized multiple times. [#2160](https://github.com/unit8co/darts/pull/2160) by [FourierMourier](https://github.com/FourierMourier).
### For developers of the library:

## [0.27.1](https://github.com/unit8co/darts/tree/0.27.1) (2023-12-10)
Expand Down
3 changes: 2 additions & 1 deletion darts/models/forecasting/pl_forecasting_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def forward_wrapper(self, *args, **kwargs):

# x is input batch tuple which by definition has the past features in the first element starting with the
# first n target features
x: Tuple = args[0][0]
# assuming `args[0][0]` is torch.Tensor we could clone it to prevent target re-normalization
x: Tuple = args[0][0].clone()
# apply reversible instance normalization
x[:, :, : self.n_targets] = self.rin(x[:, :, : self.n_targets])
# run the forward pass
Expand Down
19 changes: 11 additions & 8 deletions darts/tests/models/forecasting/test_torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,17 +707,14 @@ def test_load_weights_params_check(self, tmpdir_fn):
ckpt_path = os.path.join(tmpdir_fn, f"{model_name}.pt")
# barebone model
model = DLinearModel(
input_chunk_length=4,
output_chunk_length=1,
n_epochs=1,
input_chunk_length=4, output_chunk_length=1, n_epochs=1, **tfm_kwargs
)
model.fit(self.series[:10])
model.save(ckpt_path)

# identical model
loading_model = DLinearModel(
input_chunk_length=4,
output_chunk_length=1,
input_chunk_length=4, output_chunk_length=1, **tfm_kwargs
)
loading_model.load_weights(ckpt_path)

Expand All @@ -726,21 +723,26 @@ def test_load_weights_params_check(self, tmpdir_fn):
input_chunk_length=4,
output_chunk_length=1,
optimizer_cls=torch.optim.AdamW,
**tfm_kwargs,
)
loading_model.load_weights(ckpt_path)

model_summary_kwargs = {
"pl_trainer_kwargs": dict(
{"enable_model_sumamry": False}, **tfm_kwargs["pl_trainer_kwargs"]
)
}
# different pl_trainer_kwargs
loading_model = DLinearModel(
input_chunk_length=4,
output_chunk_length=1,
pl_trainer_kwargs={"enable_model_summary": False},
**model_summary_kwargs,
)
loading_model.load_weights(ckpt_path)

# different input_chunk_length (tfm parameter)
loading_model = DLinearModel(
input_chunk_length=4 + 1,
output_chunk_length=1,
input_chunk_length=4 + 1, output_chunk_length=1, **tfm_kwargs
)
with pytest.raises(ValueError) as error_msg:
loading_model.load_weights(ckpt_path)
Expand All @@ -754,6 +756,7 @@ def test_load_weights_params_check(self, tmpdir_fn):
input_chunk_length=4,
output_chunk_length=1,
kernel_size=10,
**tfm_kwargs,
)
with pytest.raises(ValueError) as error_msg:
loading_model.load_weights(ckpt_path)
Expand Down
55 changes: 49 additions & 6 deletions darts/tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@


class TestTimeSeries:

times = pd.date_range("20130101", "20130110", freq="D")
pd_series1 = pd.Series(range(10), index=times)
pd_series2 = pd.Series(range(5, 15), index=times)
Expand Down Expand Up @@ -933,6 +932,55 @@ def test_getitem_integer_index(self):
with pytest.raises(KeyError):
_ = series[pd.RangeIndex(start, stop=end + 2 * freq, step=freq)]

def test_getitem_frequency_inferrence(self):
ts = self.series1
assert ts.freq == "D"
ts_got = ts[1::2]
assert ts_got.freq == "2D"
ts_got = ts[pd.Timestamp("20130103") :: 2]
assert ts_got.freq == "2D"

idx = pd.DatetimeIndex(["20130102", "20130105", "20130108"])
ts_idx = ts[idx]
assert ts_idx.freq == "3D"

# With BusinessDay frequency
offset = pd.offsets.BusinessDay() # Closed on Saturdays & Sundays
dates1 = pd.date_range("20231101", "20231126", freq=offset)
values1 = np.ones(len(dates1))
ts = TimeSeries.from_times_and_values(dates1, values1)
assert ts.freq == ts[-4:].freq

# Using a step parameter
assert ts[1::3].freq == 3 * ts.freq
assert ts[pd.Timestamp("20231102") :: 4].freq == 4 * ts.freq

# Indexing with datetime index
idx = pd.date_range("20231101", "20231126", freq=offset)
assert ts[idx].freq == idx.freq

def test_getitem_frequency_inferrence_integer_index(self):
start = 2
freq = 3
ts = TimeSeries.from_times_and_values(
times=pd.RangeIndex(
start=start, stop=start + freq * len(self.series1), step=freq
),
values=self.series1.values(),
)

assert ts.freq == freq
ts_got = ts[1::2]
assert ts_got.start_time() == start + freq
assert ts_got.freq == 2 * freq

idx = pd.RangeIndex(
start=start + 2 * freq, stop=start + 4 * freq, step=2 * freq
)
ts_idx = ts[idx]
assert ts_idx.start_time() == idx[0]
assert ts_idx.freq == 2 * freq

def test_fill_missing_dates(self):
with pytest.raises(ValueError):
# Series cannot have date holes without automatic filling
Expand Down Expand Up @@ -1050,7 +1098,6 @@ def test_fill_missing_dates(self):

series_target = TimeSeries.from_dataframe(df_full, time_col="date")
for df, df_name in zip([df_full, df_holes], ["full", "holes"]):

# fill_missing_dates will find multiple inferred frequencies (i.e. for 'B' it finds {'B', 'D'})
if offset_alias in offset_aliases_raise:
with pytest.raises(ValueError):
Expand Down Expand Up @@ -1519,7 +1566,6 @@ def test_to_csv_stochastic(self, pddf_mock):


class TestTimeSeriesConcatenate:

#
# COMPONENT AXIS TESTS
#
Expand Down Expand Up @@ -1735,7 +1781,6 @@ def test_concatenate_timeseries_method(self):


class TestTimeSeriesHierarchy:

components = ["total", "a", "b", "x", "y", "ax", "ay", "bx", "by"]

hierarchy = {
Expand Down Expand Up @@ -1912,7 +1957,6 @@ def test_with_string_items(self):


class TestTimeSeriesHeadTail:

ts = TimeSeries(
xr.DataArray(
np.random.rand(10, 10, 10),
Expand Down Expand Up @@ -2185,7 +2229,6 @@ def test_df_named_columns_index(self):


class TestSimpleStatistics:

times = pd.date_range("20130101", "20130110", freq="D")
values = np.random.rand(10, 2, 100)
ar = xr.DataArray(
Expand Down
64 changes: 48 additions & 16 deletions darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4899,12 +4899,13 @@ def _check_range():
logger,
)

def _set_freq_in_xa(xa_: xr.DataArray):
def _set_freq_in_xa(xa_: xr.DataArray, freq=None):
# mutates the DataArray to make sure it contains the freq
if isinstance(xa_.get_index(self._time_dim), pd.DatetimeIndex):
inferred_freq = xa_.get_index(self._time_dim).inferred_freq
if inferred_freq is not None:
xa_.get_index(self._time_dim).freq = to_offset(inferred_freq)
if freq is None:
freq = xa_.get_index(self._time_dim).inferred_freq
if freq is not None:
xa_.get_index(self._time_dim).freq = to_offset(freq)
else:
xa_.get_index(self._time_dim).freq = self._freq

Expand All @@ -4920,8 +4921,9 @@ def _set_freq_in_xa(xa_: xr.DataArray):
xa_ = self._xa.sel({self._time_dim: key})

# indexing may discard the freq so we restore it...
# TODO: unit-test this
_set_freq_in_xa(xa_)
# if the DateTimeIndex already has an associated freq, use it
# otherwise key.freq is None and the freq will be inferred
_set_freq_in_xa(xa_, key.freq)

return self.__class__(xa_)
elif isinstance(key, pd.RangeIndex):
Expand Down Expand Up @@ -4951,18 +4953,43 @@ def _set_freq_in_xa(xa_: xr.DataArray):
key.stop, (int, np.int64)
):
xa_ = self._xa.isel({self._time_dim: key})
_set_freq_in_xa(
xa_
) # indexing may discard the freq so we restore it...
if isinstance(key.step, (int, np.int64)):
# new frequency is multiple of original
new_freq = key.step * self.freq
elif key.step is None:
new_freq = self.freq
else:
new_freq = None
raise_log(
ValueError(
f"Invalid slice step={key.step}. Only supports integer steps or `None`."
),
logger=logger,
)
# indexing may discard the freq so we restore it...
_set_freq_in_xa(xa_, new_freq)
return self.__class__(xa_)
elif isinstance(key.start, pd.Timestamp) or isinstance(
key.stop, pd.Timestamp
):
_check_dt()
if isinstance(key.step, (int, np.int64)):
# new frequency is multiple of original
new_freq = key.step * self.freq
elif key.step is None:
new_freq = self.freq
else:
new_freq = None
raise_log(
ValueError(
f"Invalid slice step={key.step}. Only supports integer steps or `None`."
),
logger=logger,
)

# indexing may discard the freq so we restore it...
xa_ = self._xa.sel({self._time_dim: key})
_set_freq_in_xa(xa_)
_set_freq_in_xa(xa_, new_freq)
return self.__class__(xa_)

# handle simple types:
Expand Down Expand Up @@ -5030,13 +5057,18 @@ def _set_freq_in_xa(xa_: xr.DataArray):
# We have to restore a RangeIndex. But first we need to
# check the list is corresponding to a RangeIndex.
min_idx, max_idx = min(key), max(key)
raise_if_not(
key[0] == min_idx
if (
not key[0] == min_idx
and key[-1] == max_idx
and max_idx + 1 - min_idx == len(key),
"Indexing a TimeSeries with a list requires the list to contain monotically "
+ "increasing integers with no gap.",
)
and max_idx + 1 - min_idx == len(key)
):
raise_log(
ValueError(
"Indexing a TimeSeries with a list requires the list to "
"contain monotonically increasing integers with no gap."
),
logger=logger,
)
new_idx = orig_idx[min_idx : max_idx + 1]
xa_ = xa_.assign_coords({self._time_dim: new_idx})

Expand Down
4 changes: 2 additions & 2 deletions requirements/release.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ ipykernel==5.3.4
ipywidgets==7.5.1
jupyterlab==4.0.3
ipython_genutils==0.2.0
jinja2==3.0.3
jinja2==3.1.3
m2r2==0.3.2
nbsphinx==0.8.7
numpydoc==1.1.0
papermill==2.2.2
pydata-sphinx-theme==0.7.2
recommonmark==0.7.1
sphinx==4.3.2
sphinx==5.0.0
sphinx-automodapi==0.14.0
sphinx_autodoc_typehints==1.12.0
twine==3.3.0

0 comments on commit 0f0dbe1

Please sign in to comment.