Skip to content

Commit 0ae1729

Browse files
committed
Proper treatment of intersection of empty sequence.
1 parent b6f6812 commit 0ae1729

File tree

2 files changed

+26
-108
lines changed

2 files changed

+26
-108
lines changed

darts/tests/test_timeseries.py

Lines changed: 24 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from darts import TimeSeries, concatenate
1313
from darts.timeseries import intersect
1414
from darts.utils.timeseries_generation import constant_timeseries, linear_timeseries
15-
from darts.utils.utils import expand_arr, freqs, generate_index
15+
from darts.utils.utils import freqs, generate_index
1616

1717

1818
class TestTimeSeries:
@@ -791,9 +791,6 @@ def helper_test_prepend_values(test_case, test_series: TimeSeries):
791791
assert test_series.time_index.equals(prepended_sq.time_index)
792792
assert prepended_sq.components.equals(test_series.components)
793793

794-
# component and sample dimension should match
795-
assert prepended._xa.shape[1:] == test_series._xa.shape[1:]
796-
797794
def test_slice(self):
798795
TestTimeSeries.helper_test_slice(self, self.series1)
799796

@@ -829,112 +826,18 @@ def test_append(self):
829826
assert appended.time_index.equals(expected_idx)
830827
assert appended.components.equals(series_1.components)
831828

832-
@pytest.mark.parametrize(
833-
"config",
834-
itertools.product(
835-
[
836-
( # univariate array
837-
np.array([0, 1, 2]).reshape((3, 1, 1)),
838-
np.array([0, 1]).reshape((2, 1, 1)),
839-
),
840-
( # multivariate array
841-
np.array([0, 1, 2, 3, 4, 5]).reshape((3, 2, 1)),
842-
np.array([0, 1, 2, 3]).reshape((2, 2, 1)),
843-
),
844-
( # empty array
845-
np.array([0, 1, 2]).reshape((3, 1, 1)),
846-
np.array([]).reshape((0, 1, 1)),
847-
),
848-
(
849-
# wrong number of components
850-
np.array([0, 1, 2]).reshape((3, 1, 1)),
851-
np.array([0, 1, 2, 3]).reshape((2, 2, 1)),
852-
),
853-
(
854-
# wrong number of samples
855-
np.array([0, 1, 2]).reshape((3, 1, 1)),
856-
np.array([0, 1, 2, 3]).reshape((2, 1, 2)),
857-
),
858-
( # univariate list with times
859-
np.array([0, 1, 2]).reshape((3, 1, 1)),
860-
[0, 1],
861-
),
862-
( # univariate list with times and components
863-
np.array([0, 1, 2]).reshape((3, 1, 1)),
864-
[[0], [1]],
865-
),
866-
( # univariate list with times, components and samples
867-
np.array([0, 1, 2]).reshape((3, 1, 1)),
868-
[[[0]], [[1]]],
869-
),
870-
( # multivar with list has wrong shape
871-
np.array([0, 1, 2, 3]).reshape((2, 2, 1)),
872-
[[1, 2], [3, 4]],
873-
),
874-
( # list with wrong number of components
875-
np.array([0, 1, 2]).reshape((3, 1, 1)),
876-
[[1, 2], [3, 4]],
877-
),
878-
( # list with wrong number of samples
879-
np.array([0, 1, 2]).reshape((3, 1, 1)),
880-
[[[0, 1]], [[1, 2]]],
881-
),
882-
( # multivar input but list has wrong shape
883-
np.array([0, 1, 2, 3]).reshape((2, 2, 1)),
884-
[1, 2],
885-
),
886-
],
887-
[True, False],
888-
["append_values", "prepend_values"],
889-
),
890-
)
891-
def test_append_and_prepend_values(self, config):
892-
(series_vals, vals), is_datetime, method = config
893-
start = "20240101" if is_datetime else 1
894-
series_idx = generate_index(
895-
start=start, length=len(series_vals), name="some_name"
896-
)
897-
series = TimeSeries.from_times_and_values(
898-
times=series_idx,
899-
values=series_vals,
829+
def test_append_values(self):
830+
TestTimeSeries.helper_test_append_values(self, self.series1)
831+
# Check `append_values` deals with `RangeIndex` series correctly:
832+
series = linear_timeseries(start=1, length=5, freq=2)
833+
appended = series.append_values(np.ones((2, 1, 1)))
834+
expected_vals = np.concatenate(
835+
[series.all_values(), np.ones((2, 1, 1))], axis=0
900836
)
901-
902-
# expand if it's a list
903-
vals_arr = np.array(vals) if isinstance(vals, list) else vals
904-
vals_arr = expand_arr(vals_arr, ndim=3)
905-
906-
ts_method = getattr(TimeSeries, method)
907-
908-
if vals_arr.shape[1:] != series_vals.shape[1:]:
909-
with pytest.raises(ValueError) as exc:
910-
_ = ts_method(series, vals)
911-
assert str(exc.value).startswith(
912-
"The (expanded) values must have the same number of components and samples"
913-
)
914-
return
915-
916-
appended = ts_method(series, vals)
917-
918-
if method == "append_values":
919-
expected_vals = np.concatenate([series_vals, vals_arr], axis=0)
920-
expected_idx = generate_index(
921-
start=series.start_time(),
922-
length=len(series_vals) + len(vals),
923-
freq=series.freq,
924-
)
925-
else:
926-
expected_vals = np.concatenate([vals_arr, series_vals], axis=0)
927-
expected_idx = generate_index(
928-
end=series.end_time(),
929-
length=len(series_vals) + len(vals),
930-
freq=series.freq,
931-
)
932-
837+
expected_idx = pd.RangeIndex(start=1, stop=15, step=2)
933838
assert np.allclose(appended.all_values(), expected_vals)
934839
assert appended.time_index.equals(expected_idx)
935840
assert appended.components.equals(series.components)
936-
assert appended._xa.shape[1:] == series._xa.shape[1:]
937-
assert appended.time_index.name == series.time_index.name
938841

939842
def test_prepend(self):
940843
TestTimeSeries.helper_test_prepend(self, self.series1)
@@ -950,6 +853,19 @@ def test_prepend(self):
950853
assert prepended.time_index.equals(expected_idx)
951854
assert prepended.components.equals(series_1.components)
952855

856+
def test_prepend_values(self):
857+
TestTimeSeries.helper_test_prepend_values(self, self.series1)
858+
# Check `prepend_values` deals with `RangeIndex` series correctly:
859+
series = linear_timeseries(start=1, length=5, freq=2)
860+
prepended = series.prepend_values(np.ones((2, 1, 1)))
861+
expected_vals = np.concatenate(
862+
[np.ones((2, 1, 1)), series.all_values()], axis=0
863+
)
864+
expected_idx = pd.RangeIndex(start=-3, stop=11, step=2)
865+
assert np.allclose(prepended.all_values(), expected_vals)
866+
assert prepended.time_index.equals(expected_idx)
867+
assert prepended.components.equals(series.components)
868+
953869
@pytest.mark.parametrize(
954870
"config",
955871
[
@@ -2461,8 +2377,8 @@ def test_time_col_with_tz(self):
24612377
assert list(ts.time_index.tz_localize("CET")) == list(time_range_H)
24622378
assert ts.time_index.tz is None
24632379

2464-
series = pd.Series(data=values, index=time_range_H)
2465-
ts = TimeSeries.from_series(pd_series=series)
2380+
serie = pd.Series(data=values, index=time_range_H)
2381+
ts = TimeSeries.from_series(pd_series=serie)
24662382
assert list(ts.time_index) == list(time_range_H.tz_localize(None))
24672383
assert list(ts.time_index.tz_localize("CET")) == list(time_range_H)
24682384
assert ts.time_index.tz is None

darts/timeseries.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5672,6 +5672,8 @@ def intersect(series: Sequence[TimeSeries]):
56725672
Sequence[TimeSeries]
56735673
Intersected series
56745674
"""
5675+
if not series:
5676+
return []
56755677

56765678
data_arrays = []
56775679
has_datetime_index = series[0].has_datetime_index

0 commit comments

Comments
 (0)