diff --git a/darts/tests/test_timeseries.py b/darts/tests/test_timeseries.py index bd5e1b1562..6c8452436f 100644 --- a/darts/tests/test_timeseries.py +++ b/darts/tests/test_timeseries.py @@ -10,6 +10,7 @@ from scipy.stats import kurtosis, skew from darts import TimeSeries, concatenate +from darts.timeseries import intersect from darts.utils.timeseries_generation import constant_timeseries, linear_timeseries from darts.utils.utils import freqs, generate_index @@ -603,12 +604,32 @@ def check_intersect(other, start_, end_, freq_): s_int_idx = series.slice_intersect_times(other, copy=False) assert s_int.time_index.equals(s_int_idx) + def check_intersect_sequence(series, other, start_, end_, freq_): + intersected_series = intersect([series, other]) + s_int = intersected_series[0] + o_int = intersected_series[1] + + assert intersected_series == [ + series.slice_intersect(other), + other.slice_intersect(series), + ] + + if start_ is None: # empty slice + assert len(s_int) == 0 + assert len(o_int) == 0 + return + + assert s_int.start_time() == o_int.start_time() == start_ + assert s_int.end_time() == o_int.end_time() == end_ + assert s_int.freq == o_int.freq == freq_ + # slice with exact range startA = start endA = end idxA = generate_index(startA, endA, freq=freq_other) seriesA = TimeSeries.from_series(pd.Series(range(len(idxA)), index=idxA)) check_intersect(seriesA, startA, endA, freq_expected) + check_intersect_sequence(series, seriesA, start, end, freq_expected) # entire slice within the range startA = start + freq @@ -616,6 +637,7 @@ def check_intersect(other, start_, end_, freq_): idxA = generate_index(startA, endA, freq=freq_other) seriesA = TimeSeries.from_series(pd.Series(range(len(idxA)), index=idxA)) check_intersect(seriesA, startA, endA, freq_expected) + check_intersect_sequence(series, seriesA, startA, endA, freq_expected) # start outside of range startC = start - 4 * freq @@ -623,6 +645,7 @@ def check_intersect(other, start_, end_, freq_): idxC = generate_index(startC, endC, freq=freq_other) seriesC = TimeSeries.from_series(pd.Series(range(len(idxC)), index=idxC)) check_intersect(seriesC, start, endC, freq_expected) + check_intersect_sequence(series, seriesC, start, endC, freq_expected) # end outside of range startC = start + 4 * freq @@ -630,6 +653,7 @@ def check_intersect(other, start_, end_, freq_): idxC = generate_index(startC, endC, freq=freq_other) seriesC = TimeSeries.from_series(pd.Series(range(len(idxC)), index=idxC)) check_intersect(seriesC, startC, end, freq_expected) + check_intersect_sequence(series, seriesC, startC, end, freq_expected) # small intersect startE = start + (n_steps - 1) * freq @@ -637,6 +661,7 @@ def check_intersect(other, start_, end_, freq_): idxE = generate_index(startE, endE, freq=freq_other) seriesE = TimeSeries.from_series(pd.Series(range(len(idxE)), index=idxE)) check_intersect(seriesE, startE, end, freq_expected) + check_intersect_sequence(series, seriesE, startE, end, freq_expected) # No intersect startG = end + 3 * freq @@ -645,6 +670,10 @@ def check_intersect(other, start_, end_, freq_): seriesG = TimeSeries.from_series(pd.Series(range(len(idxG)), index=idxG)) # for empty slices, we expect the original freq check_intersect(seriesG, None, None, freq) + check_intersect_sequence(series, seriesG, None, None, freq) + + # Empty sequence + assert intersect([]) == [] @staticmethod def helper_test_shift(test_case, test_series: TimeSeries): diff --git a/darts/timeseries.py b/darts/timeseries.py index 5f7878eb56..5b25d2ad60 100644 --- a/darts/timeseries.py +++ b/darts/timeseries.py @@ -5629,6 +5629,37 @@ def concatenate( return TimeSeries.from_xarray(da_concat, fill_missing_dates=False) +def intersect(series: Sequence[TimeSeries]): + """Returns the intersection with respect to the time index of multiple ``TimeSeries``. + + Parameters + ---------- + series : Sequence[TimeSeries] + sequence of ``TimeSeries`` to intersect + + Returns + ------- + Sequence[TimeSeries] + Intersected series + """ + + data_arrays = [] + has_datetime_index = series[0].has_datetime_index + for ts in series: + if ts.has_datetime_index != has_datetime_index: + raise_log( + IndexError( + "The time index type must be the same for all TimeSeries in the Sequence." + ), + logger, + ) + data_arrays.append(ts.data_array(copy=False)) + + intersected_series = xr.align(*data_arrays, exclude=["component", "sample"]) + + return [TimeSeries.from_xarray(array) for array in intersected_series] + + def _finite_rows_boundaries( values: np.ndarray, how: str = "all" ) -> Tuple[Optional[int], Optional[int]]: