Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/slice intersect multi series #2592

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
29 changes: 29 additions & 0 deletions darts/tests/test_timeseries.py
ymatzkevich marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scipy.stats import kurtosis, skew

from darts import TimeSeries, concatenate
from darts.timeseries import intersect
from darts.utils.timeseries_generation import constant_timeseries, linear_timeseries
from darts.utils.utils import expand_arr, freqs, generate_index

Expand Down Expand Up @@ -603,40 +604,64 @@ def check_intersect(other, start_, end_, freq_):
s_int_idx = series.slice_intersect_times(other, copy=False)
assert s_int.time_index.equals(s_int_idx)

def check_intersect_sequence(series, other, start_, end_, freq_):
    """Check ``intersect`` on the pair ``[series, other]``.

    The result must equal the pairwise ``slice_intersect`` of each series
    with the other. When ``start_`` is ``None`` both results are expected
    to be empty; otherwise both must share the expected start time, end
    time and frequency.
    """
    result = intersect([series, other])
    s_int = result[0]
    o_int = result[1]

    # the sequence-based intersect must agree with pairwise slice_intersect
    expected = [
        series.slice_intersect(other),
        other.slice_intersect(series),
    ]
    assert result == expected

    if start_ is None:  # empty slice
        assert len(s_int) == 0
        assert len(o_int) == 0
        return

    s_start, o_start = s_int.start_time(), o_int.start_time()
    assert s_start == o_start
    assert o_start == start_

    s_end, o_end = s_int.end_time(), o_int.end_time()
    assert s_end == o_end
    assert o_end == end_

    s_freq, o_freq = s_int.freq, o_int.freq
    assert s_freq == o_freq
    assert o_freq == freq_
ymatzkevich marked this conversation as resolved.
Show resolved Hide resolved

# slice with exact range
startA = start
endA = end
idxA = generate_index(startA, endA, freq=freq_other)
seriesA = TimeSeries.from_series(pd.Series(range(len(idxA)), index=idxA))
check_intersect(seriesA, startA, endA, freq_expected)
check_intersect_sequence(series, seriesA, start, end, freq_expected)

# entire slice within the range
startA = start + freq
endA = startA + 6 * freq_other
idxA = generate_index(startA, endA, freq=freq_other)
seriesA = TimeSeries.from_series(pd.Series(range(len(idxA)), index=idxA))
check_intersect(seriesA, startA, endA, freq_expected)
check_intersect_sequence(series, seriesA, startA, endA, freq_expected)

# start outside of range
startC = start - 4 * freq
endC = start + 4 * freq_other
idxC = generate_index(startC, endC, freq=freq_other)
seriesC = TimeSeries.from_series(pd.Series(range(len(idxC)), index=idxC))
check_intersect(seriesC, start, endC, freq_expected)
check_intersect_sequence(series, seriesC, start, endC, freq_expected)

# end outside of range
startC = start + 4 * freq
endC = end + 4 * freq_other
idxC = generate_index(startC, endC, freq=freq_other)
seriesC = TimeSeries.from_series(pd.Series(range(len(idxC)), index=idxC))
check_intersect(seriesC, startC, end, freq_expected)
check_intersect_sequence(series, seriesC, startC, end, freq_expected)

# small intersect
startE = start + (n_steps - 1) * freq
endE = startE + 2 * freq_other
idxE = generate_index(startE, endE, freq=freq_other)
seriesE = TimeSeries.from_series(pd.Series(range(len(idxE)), index=idxE))
check_intersect(seriesE, startE, end, freq_expected)
check_intersect_sequence(series, seriesE, startE, end, freq_expected)

# No intersect
startG = end + 3 * freq
Expand All @@ -645,6 +670,10 @@ def check_intersect(other, start_, end_, freq_):
seriesG = TimeSeries.from_series(pd.Series(range(len(idxG)), index=idxG))
# for empty slices, we expect the original freq
check_intersect(seriesG, None, None, freq)
check_intersect_sequence(series, seriesG, None, None, freq)

# Empty sequence
assert intersect([]) == []

@staticmethod
def helper_test_shift(test_case, test_series: TimeSeries):
Expand Down
33 changes: 33 additions & 0 deletions darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5659,6 +5659,39 @@ def concatenate(
return TimeSeries.from_xarray(da_concat, fill_missing_dates=False)


def intersect(series: Sequence[TimeSeries]):
    """Returns the intersection with respect to the time index of multiple ``TimeSeries``.

    All series must use the same kind of time index as the first one
    (as reported by ``has_datetime_index``). The underlying data arrays are
    aligned together with ``xr.align``, with the "component" and "sample"
    dimensions excluded from the alignment.

    Parameters
    ----------
    series : Sequence[TimeSeries]
        sequence of ``TimeSeries`` to intersect

    Returns
    -------
    Sequence[TimeSeries]
        Intersected series, in the same order as the input. An empty input
        yields an empty list.

    Raises
    ------
    IndexError
        Raised through ``raise_log`` if a series' ``has_datetime_index``
        differs from that of the first series.
    """
    if not series:
        return []

    # the first series defines the index type every other series must match
    reference_is_datetime = series[0].has_datetime_index

    def _checked_array(ts):
        # validate the index type, then expose the underlying array (no copy)
        if ts.has_datetime_index != reference_is_datetime:
            raise_log(
                IndexError(
                    "The time index type must be the same for all TimeSeries in the Sequence."
                ),
                logger,
            )
        return ts.data_array(copy=False)

    arrays = [_checked_array(ts) for ts in series]

    # align everything except the component and sample dimensions
    aligned = xr.align(*arrays, exclude=["component", "sample"])

    return list(map(TimeSeries.from_xarray, aligned))


def _finite_rows_boundaries(
values: np.ndarray, how: str = "all"
) -> tuple[Optional[int], Optional[int]]:
Expand Down
Loading