Skip to content

Commit

Permalink
Add a function to find the intersection of multiple time series
Browse files Browse the repository at this point in the history
  • Loading branch information
ymatzkevich committed Nov 12, 2024
1 parent 26c5f39 commit 9be5adb
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
29 changes: 29 additions & 0 deletions darts/tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scipy.stats import kurtosis, skew

from darts import TimeSeries, concatenate
from darts.timeseries import intersect
from darts.utils.timeseries_generation import constant_timeseries, linear_timeseries
from darts.utils.utils import freqs, generate_index

Expand Down Expand Up @@ -603,40 +604,64 @@ def check_intersect(other, start_, end_, freq_):
s_int_idx = series.slice_intersect_times(other, copy=False)
assert s_int.time_index.equals(s_int_idx)

def check_intersect_sequence(series, other, start_, end_, freq_):
    """Check ``intersect`` on the pair ``[series, other]``.

    The result must equal the pairwise ``slice_intersect`` of the two
    series in both directions. When ``start_`` is ``None`` both results
    are expected to be empty; otherwise both must share the expected
    start time, end time and frequency.
    """
    result = intersect([series, other])
    first_int = result[0]
    second_int = result[1]

    # the sequence-level function must agree with the pairwise method
    pairwise = [
        series.slice_intersect(other),
        other.slice_intersect(series),
    ]
    assert result == pairwise

    if start_ is not None:
        assert first_int.start_time() == second_int.start_time() == start_
        assert first_int.end_time() == second_int.end_time() == end_
        assert first_int.freq == second_int.freq == freq_
    else:
        # empty slice
        assert len(first_int) == 0
        assert len(second_int) == 0

# slice with exact range
startA = start
endA = end
idxA = generate_index(startA, endA, freq=freq_other)
seriesA = TimeSeries.from_series(pd.Series(range(len(idxA)), index=idxA))
check_intersect(seriesA, startA, endA, freq_expected)
check_intersect_sequence(series, seriesA, start, end, freq_expected)

# entire slice within the range
startA = start + freq
endA = startA + 6 * freq_other
idxA = generate_index(startA, endA, freq=freq_other)
seriesA = TimeSeries.from_series(pd.Series(range(len(idxA)), index=idxA))
check_intersect(seriesA, startA, endA, freq_expected)
check_intersect_sequence(series, seriesA, startA, endA, freq_expected)

# start outside of range
startC = start - 4 * freq
endC = start + 4 * freq_other
idxC = generate_index(startC, endC, freq=freq_other)
seriesC = TimeSeries.from_series(pd.Series(range(len(idxC)), index=idxC))
check_intersect(seriesC, start, endC, freq_expected)
check_intersect_sequence(series, seriesC, start, endC, freq_expected)

# end outside of range
startC = start + 4 * freq
endC = end + 4 * freq_other
idxC = generate_index(startC, endC, freq=freq_other)
seriesC = TimeSeries.from_series(pd.Series(range(len(idxC)), index=idxC))
check_intersect(seriesC, startC, end, freq_expected)
check_intersect_sequence(series, seriesC, startC, end, freq_expected)

# small intersect
startE = start + (n_steps - 1) * freq
endE = startE + 2 * freq_other
idxE = generate_index(startE, endE, freq=freq_other)
seriesE = TimeSeries.from_series(pd.Series(range(len(idxE)), index=idxE))
check_intersect(seriesE, startE, end, freq_expected)
check_intersect_sequence(series, seriesE, startE, end, freq_expected)

# No intersect
startG = end + 3 * freq
Expand All @@ -645,6 +670,10 @@ def check_intersect(other, start_, end_, freq_):
seriesG = TimeSeries.from_series(pd.Series(range(len(idxG)), index=idxG))
# for empty slices, we expect the original freq
check_intersect(seriesG, None, None, freq)
check_intersect_sequence(series, seriesG, None, None, freq)

# Empty sequence
assert intersect([]) == []

@staticmethod
def helper_test_shift(test_case, test_series: TimeSeries):
Expand Down
31 changes: 31 additions & 0 deletions darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5629,6 +5629,37 @@ def concatenate(
return TimeSeries.from_xarray(da_concat, fill_missing_dates=False)


def intersect(series: Sequence[TimeSeries]) -> Sequence[TimeSeries]:
    """Returns the intersection with respect to the time index of multiple ``TimeSeries``.

    The underlying data arrays of all series are aligned with ``xarray.align`` (inner join by default), while the
    "component" and "sample" dimensions are excluded from the alignment.

    Parameters
    ----------
    series : Sequence[TimeSeries]
        sequence of ``TimeSeries`` to intersect

    Returns
    -------
    Sequence[TimeSeries]
        Intersected series, in the same order as the input. An empty input sequence yields an empty list.

    Raises
    ------
    IndexError
        If the series do not all have the same time index type (``has_datetime_index`` differs).
    """
    # Nothing to intersect; this also guards the ``series[0]`` lookup below.
    if not series:
        return []

    data_arrays = []
    has_datetime_index = series[0].has_datetime_index
    for ts in series:
        # series with different index types cannot be aligned against each other
        if ts.has_datetime_index != has_datetime_index:
            raise_log(
                IndexError(
                    "The time index type must be the same for all TimeSeries in the Sequence."
                ),
                logger,
            )
        # no copy needed here: the arrays are only read by the alignment
        data_arrays.append(ts.data_array(copy=False))

    # align on the remaining (time) dimension only
    intersected_series = xr.align(*data_arrays, exclude=["component", "sample"])

    return [TimeSeries.from_xarray(array) for array in intersected_series]


def _finite_rows_boundaries(
values: np.ndarray, how: str = "all"
) -> Tuple[Optional[int], Optional[int]]:
Expand Down

0 comments on commit 9be5adb

Please sign in to comment.