From 8cb04f67982bd3c77016bdd5fbdd928bf7d11bf7 Mon Sep 17 00:00:00 2001
From: Dennis Bader
Date: Fri, 2 Feb 2024 11:12:29 +0100
Subject: [PATCH] Feat/improve timeseries (#2196)

* found major performance boost for time series creation
* first boosted time series version
* improve slicing with integers
* improve slicing with time stamps
* improve slicing with time stamps
* update from_xarray
* improve from_group_dataframe()
* remove test time series
* remove old time series
* add option to drop group columns from from_group_dataframe
* update changelog
* apply suggestions from PR review

---
 CHANGELOG.md                                  |   6 +
 .../test_timeseries_static_covariates.py      |  87 +++-
 darts/timeseries.py                           | 440 ++++++++++--------
 3 files changed, 330 insertions(+), 203 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d23107988e..20c59d3efc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,12 @@
 but cannot always guarantee backwards compatibility. Changes that may **break co
 
 ### For users of the library:
 **Improved**
+- Improvements to `TimeSeries`: [#2196](https://github.com/unit8co/darts/pull/2196) by [Dennis Bader](https://github.com/dennisbader).
+  - 🚀🚀🚀 Significant performance boosts for several `TimeSeries` methods, resulting in increased efficiency across the entire `Darts` library. Up to 2x faster creation times for series indexed with "regular" frequencies (e.g. daily, hourly, ...), and >100x for series indexed with "special" frequencies (e.g. "W-MON", ...). Affects:
+    - All `TimeSeries` creation methods
+    - Additional boosts for slicing with integers and Timestamps
+    - Additional boosts for `from_group_dataframe()` by performing some of the heavy-duty computations on the entire DataFrame, rather than iteratively on the group level.
+  - Added an option to exclude some `group_cols` from being added as static covariates when using `TimeSeries.from_group_dataframe()` with the parameter `drop_group_cols`.
 **Fixed**

diff --git a/darts/tests/test_timeseries_static_covariates.py b/darts/tests/test_timeseries_static_covariates.py
index 917f76a49e..fa188dbb4a 100644
--- a/darts/tests/test_timeseries_static_covariates.py
+++ b/darts/tests/test_timeseries_static_covariates.py
@@ -154,27 +154,92 @@ def test_timeseries_from_longitudinal_df(self):
             )
             assert (ts.static_covariates_values(copy=False) == [[i, j, 1]]).all()
 
-        df = copy.deepcopy(self.df_long_multi)
-        df.loc[:, "non_static"] = np.arange(len(df))
-        # non static columns as static columns should raise an error
-        with pytest.raises(ValueError):
+        # dropping group columns gives the same time series, with those columns excluded from the static covariates
+        # drop first column
+        ts_groups4 = TimeSeries.from_group_dataframe(
+            df=self.df_long_multi,
+            group_cols=["st1", "st2"],
+            static_cols=["constant"],
+            time_col="times",
+            value_cols=value_cols,
+            drop_group_cols=["st1"],
+        )
+        assert len(ts_groups4) == self.n_groups * 2
+        for idx, ts in enumerate(ts_groups4):
+            j = idx % 2
+            assert ts.static_covariates.shape == (1, 2)
+            assert ts.static_covariates.columns.equals(pd.Index(["st2", "constant"]))
+            assert (ts.static_covariates_values(copy=False) == [[j, 1]]).all()
+
+        # drop last column
+        ts_groups5 = TimeSeries.from_group_dataframe(
+            df=self.df_long_multi,
+            group_cols=["st1", "st2"],
+            static_cols=["constant"],
+            time_col="times",
+            value_cols=value_cols,
+            drop_group_cols=["st2"],
+        )
+        assert len(ts_groups5) == self.n_groups * 2
+        for idx, ts in enumerate(ts_groups5):
+            i = idx // 2
+            assert ts.static_covariates.shape == (1, 2)
+            assert ts.static_covariates.columns.equals(pd.Index(["st1", "constant"]))
+            assert (ts.static_covariates_values(copy=False) == [[i, 1]]).all()
+
+        # drop all columns
+        ts_groups6 = TimeSeries.from_group_dataframe(
+            df=self.df_long_multi,
+            group_cols=["st1", "st2"],
+            static_cols=["constant"],
+            time_col="times",
+            value_cols=value_cols,
+            drop_group_cols=["st1", "st2"],
+        )
+        assert len(ts_groups6) == self.n_groups * 2
+        for ts in ts_groups6:
+            assert ts.static_covariates.shape == (1, 1)
+            assert ts.static_covariates.columns.equals(pd.Index(["constant"]))
+            assert (ts.static_covariates_values(copy=False) == [[1]]).all()
+
+        # drop all static covariates (no `static_cols`, all `group_cols` dropped)
+        ts_groups7 = TimeSeries.from_group_dataframe(
+            df=self.df_long_multi,
+            group_cols=["st1", "st2"],
+            time_col="times",
+            value_cols=value_cols,
+            drop_group_cols=["st1", "st2"],
+        )
+        assert len(ts_groups7) == self.n_groups * 2
+        for ts in ts_groups7:
+            assert ts.static_covariates is None
+
+    def test_from_group_dataframe_invalid_drop_cols(self):
+        # drop col is not part of `group_cols`
+        with pytest.raises(ValueError) as err:
             _ = TimeSeries.from_group_dataframe(
-                df=df,
+                df=self.df_long_multi,
                 group_cols=["st1"],
-                static_cols=["non_static"],
                 time_col="times",
-                value_cols=value_cols,
+                value_cols="a",
+                drop_group_cols=["invalid"],
             )
+        assert str(err.value).endswith("received: {'invalid'}.")
 
+    def test_from_group_dataframe_groups_too_short(self):
         # groups that are too short for TimeSeries requirements should raise an error
-        with pytest.raises(ValueError):
+        df = copy.deepcopy(self.df_long_multi)
+        df.loc[:, "non_static"] = np.arange(len(df))
+        with pytest.raises(ValueError) as err:
             _ = TimeSeries.from_group_dataframe(
                 df=df,
-                group_cols=["st1", "non_static"],
-                static_cols=None,
+                group_cols="non_static",
                 time_col="times",
-                value_cols=value_cols,
+                value_cols="a",
             )
+        assert str(err.value).startswith(
+            "The time index of the provided DataArray is missing the freq attribute"
+        )
 
     def test_with_static_covariates_univariate(self):
         ts = linear_timeseries(length=10)

diff --git a/darts/timeseries.py b/darts/timeseries.py
index 0a715d10b8..30d5aac716 100644
--- a/darts/timeseries.py
+++ b/darts/timeseries.py
@@ -75,7 +75,7 @@ class TimeSeries:
-    def __init__(self, xa: xr.DataArray):
+    def __init__(self, xa: xr.DataArray, copy=True):
         """
         Create a TimeSeries from a (well formed) DataArray.
         It is recommended to use the factory methods to create TimeSeries instead.
@@ -91,29 +91,32 @@ def __init__(self, xa: xr.DataArray):
         TimeSeries.from_json : Create from a JSON file.
         TimeSeries.from_xarray : Create from an :class:`xarray.DataArray`.
         """
-        raise_if_not(
-            isinstance(xa, xr.DataArray),
-            "Data must be provided as an xarray DataArray instance. "
-            "If you need to create a TimeSeries from another type "
-            "(e.g. a DataFrame), look at TimeSeries factory methods "
-            "(e.g. TimeSeries.from_dataframe(), "
-            "TimeSeries.from_xarray(), TimeSeries.from_values()"
-            "TimeSeries.from_times_and_values(), etc...).",
-            logger,
-        )
-        raise_if_not(
-            len(xa.shape) == 3,
-            f"TimeSeries require DataArray of dimensionality 3 ({DIMS}).",
-            logger,
-        )
+        if not isinstance(xa, xr.DataArray):
+            raise_log(
+                ValueError(
+                    "Data must be provided as an xarray DataArray instance. "
+                    "If you need to create a TimeSeries from another type "
+                    "(e.g. a DataFrame), look at TimeSeries factory methods "
+                    "(e.g. TimeSeries.from_dataframe(), "
+                    "TimeSeries.from_xarray(), TimeSeries.from_values(), "
+                    "TimeSeries.from_times_and_values(), etc...)."
+                ),
+                logger,
+            )
+        if len(xa.shape) != 3:
+            raise_log(
+                ValueError(
+                    f"TimeSeries requires a DataArray of dimensionality 3 ({DIMS})."
+                ),
+                logger,
+            )
 
         # Ideally values should be np.float, otherwise certain functionalities like diff()
         # relying on np.nan (which is a float) won't work very properly.
-        raise_if_not(
-            np.issubdtype(xa.values.dtype, np.number),
-            "The time series must contain numeric values only.",
-            logger,
-        )
+        if not np.issubdtype(xa.values.dtype, np.number):
+            raise_log(
+                ValueError("The time series must contain numeric values only."), logger
+            )
 
         val_dtype = xa.values.dtype
         if not (
@@ -138,24 +141,22 @@ def __init__(self, xa: xr.DataArray):
 
         # check that columns/component names are unique
         components = xa.get_index(DIMS[1])
-        raise_if_not(
-            len(set(components)) == len(components),
-            "The components (columns) names must be unique. Provided: {}".format(
-                components
-            ),
-            logger,
-        )
+        if len(set(components)) != len(components):
+            raise_log(
+                ValueError(
+                    f"The components (columns) names must be unique. Provided: {components}"
+                ),
+                logger,
+            )
 
-        self._time_dim = str(
-            xa.dims[0]
-        )  # how the time dimension is named; we convert hashable to string
+        # how the time dimension is named; we convert hashable to string
+        self._time_dim = str(xa.dims[0])
 
         # The following sorting returns a copy, which we are relying on.
         # As of xarray 0.18.2, this sorting discards the freq of the index for some reason
         # https://github.com/pydata/xarray/issues/5466
         # We sort only if the time axis is not already sorted (monotonically increasing).
-        self._xa = self._sort_index(xa, copy=True)
-
+        self._xa = self._sort_index(xa, copy=copy)
         self._time_index = self._xa.get_index(self._time_dim)
 
         if not isinstance(self._time_index, VALID_INDEX_TYPES):
@@ -170,74 +171,64 @@ def __init__(self, xa: xr.DataArray):
         self._has_datetime_index = isinstance(self._time_index, pd.DatetimeIndex)
 
         if self._has_datetime_index:
-            freq_tmp = xa.get_index(
-                self._time_dim
-            ).freq  # store original freq (see bug of sortby() above).
-            self._freq: pd.DateOffset = (
-                freq_tmp
-                if freq_tmp is not None
-                else to_offset(self._xa.get_index(self._time_dim).inferred_freq)
-            )
-            raise_if(
-                self._freq is None,
-                "The time index of the provided DataArray is missing the freq attribute, and the frequency could "
-                "not be directly inferred. "
-                "This probably comes from inconsistent date frequencies with missing dates. "
-                "If you know the actual frequency, try setting `fill_missing_dates=True, freq=actual_frequency`. "
-                "If not, try setting `fill_missing_dates=True, freq=None` to see if a frequency can be inferred.",
-                logger,
-            )
+            # store original freq (see bug of sortby() above).
+            freq_tmp = xa.get_index(self._time_dim).freq
+
+            # if the original frequency is known and positive (n > 0 -> increasing time index), the original
+            # array was guaranteed to be sorted, so the new freq must be the same;
+            # otherwise, infer the frequency from the sorted array
+            if freq_tmp is not None and freq_tmp.n > 0:
+                self._freq = freq_tmp
+            else:
+                self._freq = to_offset(self._xa.get_index(self._time_dim).inferred_freq)
+
+            if self._freq is None:
+                raise_log(
+                    ValueError(
+                        "The time index of the provided DataArray is missing the freq attribute, and the frequency "
+                        "could not be directly inferred. This probably comes from inconsistent date frequencies with "
+                        "missing dates. If you know the actual frequency, try setting `fill_missing_dates=True, "
+                        "freq=actual_frequency`. If not, try setting `fill_missing_dates=True, freq=None` to see if a "
+                        "frequency can be inferred."
+                    ),
+                    logger,
+                )
 
             self._freq_str: str = self._freq.freqstr
 
             # reset freq inside the xarray index (see bug of sortby() above).
             self._xa.get_index(self._time_dim).freq = self._freq
-
-            # We have to check manually if the index is complete for non-empty series. Another way could
-            # be to rely on `inferred_freq` being present, but this fails for series of length < 3.
-            if len(self._time_index) > 0:
-                is_index_complete = (
-                    len(
-                        pd.date_range(
-                            self._time_index.min(),
-                            self._time_index.max(),
-                            freq=self._freq,
-                        ).difference(self._time_index)
-                    )
-                    == 0
-                )
-
-                raise_if_not(
-                    is_index_complete,
-                    "Not all timestamps seem to be present in the time index. Does "
-                    "the series contain holes? If you are using a factory method, "
-                    "try specifying `fill_missing_dates=True` "
-                    "or specify the `freq` parameter.",
-                    logger,
-                )
         else:
             self._freq: int = self._time_index.step
             self._freq_str = None
 
         # check static covariates
         static_covariates = self._xa.attrs.get(STATIC_COV_TAG, None)
-        raise_if_not(
+        if not (
             isinstance(static_covariates, (pd.Series, pd.DataFrame))
-            or static_covariates is None,
-            "`static_covariates` must be either a pandas Series, DataFrame or None",
-            logger,
-        )
+            or static_covariates is None
+        ):
+            raise_log(
+                ValueError(
+                    "`static_covariates` must be either a pandas Series, DataFrame or None"
+                ),
+                logger,
+            )
 
+        # check that the static covariates are valid for a multivariate TimeSeries
         if isinstance(static_covariates, pd.DataFrame):
             n_components = len(static_covariates)
-            raise_if(
-                n_components > 1 and n_components != self.n_components,
-                "When passing a multi-row pandas DataFrame, the number of rows must match the number of "
-                "components of the TimeSeries object (multi-component/multi-row static covariates must map to each "
-                "TimeSeries component).",
-                logger,
-            )
-            static_covariates = static_covariates.copy()
+            if n_components > 1 and n_components != self.n_components:
+                raise_log(
+                    ValueError(
+                        "When passing a multi-row pandas DataFrame, the number of rows must match the number of "
+                        "components of the TimeSeries object (multi-component/multi-row static covariates "
+                        "must map to each TimeSeries component)."
+                    ),
+                    logger,
+                )
+            if copy:
+                static_covariates = static_covariates.copy()
         elif isinstance(static_covariates, pd.Series):
             static_covariates = static_covariates.to_frame().T
         else:  # None
@@ -257,20 +248,25 @@ def __init__(self, xa: xr.DataArray):
                 include=np.number, exclude=self.dtype
             ).columns
 
-            changes = {col: self.dtype for col in cols_to_cast}
             # Calling astype is costly even when there's no change...
-            if len(changes) > 0:
-                static_covariates = static_covariates.astype(changes, copy=False)
+            if not cols_to_cast.empty:
+                static_covariates = static_covariates.astype(
+                    {col: self.dtype for col in cols_to_cast}, copy=False
+                )
 
         # handle hierarchy
         hierarchy = self._xa.attrs.get(HIERARCHY_TAG, None)
         self._top_level_component = None
         self._bottom_level_components = None
 
         if hierarchy is not None:
-            raise_if_not(
-                isinstance(hierarchy, dict),
-                "The hierarchy must be a dict mapping (non-top) component names to their parent(s) in the hierarchy.",
-            )
+            if not isinstance(hierarchy, dict):
+                raise_log(
+                    ValueError(
+                        "The hierarchy must be a dict mapping (non-top) component names to their parent(s) "
+                        "in the hierarchy."
+                    ),
+                    logger,
+                )
             # pre-compute grouping informations
             components_set = set(self.components)
             children = set(hierarchy.keys())
@@ -280,27 +276,40 @@ def __init__(self, xa: xr.DataArray):
                 k: ([v] if isinstance(v, str) else v) for k, v in hierarchy.items()
             }
 
-            raise_if_not(
-                all(c in components_set for c in children),
-                "The keys of the hierarchy must be time series components",
-            )
+            if not all(c in components_set for c in children):
+                raise_log(
+                    ValueError(
+                        "The keys of the hierarchy must be time series components"
+                    ),
+                    logger,
+                )
             ancestors = set().union(*hierarchy.values())
-            raise_if_not(
-                all(a in components_set for a in ancestors),
-                "The values of the hierarchy must only contain component names matching those of the series.",
-            )
+            if not all(a in components_set for a in ancestors):
+                raise_log(
+                    ValueError(
+                        "The values of the hierarchy must only contain component names matching those "
+                        "of the series."
+                    ),
+                    logger,
+                )
 
             hierarchy_top = components_set - children
-            raise_if_not(
-                len(hierarchy_top) == 1,
-                "The hierarchy must be such that only one component does "
-                + "not appear as a key (the top level component).",
-            )
+            if len(hierarchy_top) != 1:
+                raise_log(
+                    ValueError(
+                        "The hierarchy must be such that only one component does not appear as a key "
+                        "(the top level component)."
+                    ),
+                    logger,
+                )
             self._top_level_component = hierarchy_top.pop()
-            raise_if_not(
-                self._top_level_component in ancestors,
-                "Invalid hierarchy. Component {} appears as it should be top-level, but "
-                + "does not appear as an ancestor in the hierarchy dict.",
-            )
+            if self._top_level_component not in ancestors:
+                raise_log(
+                    ValueError(
+                        f"Invalid hierarchy. Component {self._top_level_component} should be top-level, but "
+                        "does not appear as an ancestor in the hierarchy dict."
+                    ),
+                    logger,
+                )
             bottom_level = components_set - ancestors
 
             # maintain the same order as the original components
@@ -308,7 +317,7 @@ def __init__(self, xa: xr.DataArray):
             c for c in self.components if c in bottom_level
         ]
 
-        # Store static covariates and hierarchy in attributes (potentially storing None)
+        # store static covariates and hierarchy in attributes (potentially storing None)
         self._xa = _xarray_with_attrs(self._xa, static_covariates, hierarchy)
 
     """
@@ -433,6 +442,15 @@ def _clean_component_list(columns) -> List[str]:
         # by the user than handled by silent renaming, which can change the way things work.
         # TODO: is there a way to just update the component index without re-creating a new DataArray?
+        # -> Answer: yes, but it's slower, e.g.:
+        # ```
+        # xa_ = xa_.assign_coords(
+        #     {
+        #         time_index_name: xa_.get_index(time_index_name),
+        #         DIMS[1]: columns_list,
+        #     }
+        # )
+        # ```
         xa_ = xr.DataArray(
             xa_.values,
             dims=xa_.dims,
@@ -738,7 +756,8 @@ def from_group_dataframe(
         fill_missing_dates: Optional[bool] = False,
         freq: Optional[Union[str, int]] = None,
         fillna_value: Optional[float] = None,
-    ) -> List["TimeSeries"]:
+        drop_group_cols: Optional[Union[List[str], str]] = None,
+    ) -> List[Self]:
         """
         Build a list of TimeSeries instances grouped by a selection of columns from a DataFrame.
         One column (or the DataFrame index) has to represent the time,
@@ -785,6 +804,8 @@ def from_group_dataframe(
             If an integer, represents the step size of the pandas Index or pandas RangeIndex.
         fillna_value
             Optionally, a numeric value to fill missing values (NaNs) with.
+        drop_group_cols
+            Optionally, a string or list of strings with `group_cols` column(s) to exclude from the static covariates.
 
         Returns
         -------
@@ -799,6 +820,27 @@ def from_group_dataframe(
         )
         group_cols = [group_cols] if not isinstance(group_cols, list) else group_cols
+        if drop_group_cols:
+            drop_group_cols = (
+                [drop_group_cols]
+                if not isinstance(drop_group_cols, list)
+                else drop_group_cols
+            )
+            invalid_cols = set(drop_group_cols) - set(group_cols)
+            if invalid_cols:
+                raise_log(
+                    ValueError(
+                        f"Found invalid `drop_group_cols` columns. All columns must be in the passed `group_cols`. "
+                        f"Expected any of: {group_cols}, received: {invalid_cols}."
+                    ),
+                    logger=logger,
+                )
+            drop_group_col_idx = [
+                idx for idx, col in enumerate(group_cols) if col in drop_group_cols
+            ]
+        else:
+            drop_group_cols = []
+            drop_group_col_idx = []
 
         if static_cols is not None:
             static_cols = (
                 [static_cols] if not isinstance(static_cols, list) else static_cols
             )
         else:
             static_cols = []
         static_cov_cols = group_cols + static_cols
+        extract_static_cov_cols = [
+            col for col in static_cov_cols if col not in drop_group_cols
+        ]
+        extract_time_col = [] if time_col is None else [time_col]
+
+        if value_cols is None:
+            value_cols = df.columns.drop(static_cov_cols + extract_time_col).tolist()
+        extract_value_cols = [value_cols] if isinstance(value_cols, str) else value_cols
+
+        df = df[static_cov_cols + extract_value_cols + extract_time_col]
+
+        # sort the entire `df` once here, to avoid sorting each group individually later on
+        if time_col:
+            df.index = pd.DatetimeIndex(df[time_col])
+            df = df.drop(columns=time_col)
+        df = df.sort_index()
 
         # split df by groups, and store group values and static values (static covariates)
         # single elements group columns must be unpacked for same groupby() behavior across different pandas versions
@@ -818,6 +876,17 @@ def from_group_dataframe(
                 if not isinstance(static_cov_vals, tuple)
                 else static_cov_vals
             )
+            # optionally, exclude group columns from the static covariates
+            if drop_group_col_idx:
+                if len(drop_group_col_idx) == len(group_cols):
+                    static_cov_vals = tuple()
+                else:
+                    static_cov_vals = tuple(
+                        val
+                        for idx, val in enumerate(static_cov_vals)
+                        if idx not in drop_group_col_idx
+                    )
+
             # check that for each group there is only one unique value per column in `static_cols`
             if static_cols:
                 static_cols_valid = [
@@ -842,17 +911,17 @@ def from_group_dataframe(
             # store static covariate Series and group DataFrame (without static cov columns)
             splits.append(
                 (
-                    pd.DataFrame([static_cov_vals], columns=static_cov_cols),
-                    group.drop(columns=static_cov_cols),
+                    pd.DataFrame([static_cov_vals], columns=extract_static_cov_cols)
+                    if extract_static_cov_cols
+                    else None,
+                    group[extract_value_cols],
                 )
             )
 
         # create a list with multiple TimeSeries and add static covariates
         return [
-            TimeSeries.from_dataframe(
+            cls.from_dataframe(
                 df=split,
-                time_col=time_col,
-                value_cols=value_cols,
                 fill_missing_dates=fill_missing_dates,
                 freq=freq,
                 fillna_value=fillna_value,
@@ -1248,7 +1317,7 @@ def bottom_level_components(self) -> Optional[List[str]]:
         return self._bottom_level_components
 
     @property
-    def top_level_series(self) -> Optional["TimeSeries"]:
+    def top_level_series(self) -> Optional[Self]:
         """
         The univariate series containing the single top-level component of this series,
         or None if the series has no hierarchy.
@@ -1256,7 +1325,7 @@ def top_level_series(self) -> Optional[Self]:
         return self[self.top_level_component] if self.has_hierarchy else None
 
     @property
-    def bottom_level_series(self) -> Optional[List["TimeSeries"]]:
+    def bottom_level_series(self) -> Optional[List[Self]]:
         """
         The series containing the bottom-level components of this series in the same
         order as they appear in the series, or None if the series has no hierarchy.
@@ -1930,12 +1999,12 @@ def tail(
 
     def concatenate(
         self,
-        other: "TimeSeries",
+        other: Self,
         axis: Optional[Union[str, int]] = 0,
         ignore_time_axis: Optional[bool] = False,
         ignore_static_covariates: bool = False,
         drop_hierarchy: bool = True,
-    ) -> "TimeSeries":
+    ) -> Self:
         """
         Concatenate another timeseries to the current one along given axis.
@@ -2158,7 +2227,7 @@ def get_timestamp_at_point(
 
     def _split_at(
         self, split_point: Union[pd.Timestamp, float, int], after: bool = True
-    ) -> Tuple["TimeSeries", "TimeSeries"]:
+    ) -> Tuple[Self, Self]:
         # Get index with not after in order to avoid moving twice if split_point is not in self
         point_index = self.get_index_at_point(split_point, not after)
         return (
@@ -2168,7 +2237,7 @@ def _split_at(
 
     def split_after(
         self, split_point: Union[pd.Timestamp, float, int]
-    ) -> Tuple["TimeSeries", "TimeSeries"]:
+    ) -> Tuple[Self, Self]:
         """
         Splits the series in two, after a provided `split_point`.
 
@@ -2191,7 +2260,7 @@ def split_after(
 
     def split_before(
         self, split_point: Union[pd.Timestamp, float, int]
-    ) -> Tuple["TimeSeries", "TimeSeries"]:
+    ) -> Tuple[Self, Self]:
         """
         Splits the series in two, before a provided `split_point`.
 
@@ -2386,7 +2455,7 @@ def slice_n_points_before(self, end_ts: Union[pd.Timestamp, int], n: int) -> Sel
             ValueError("start_ts must be an int or a pandas Timestamp."), logger
         )
 
-    def slice_intersect(self, other: "TimeSeries") -> Self:
+    def slice_intersect(self, other: Self) -> Self:
         """
         Return a ``TimeSeries`` slice of this series, where the time index has been intersected with the one
         of the `other` series.
@@ -2629,7 +2698,7 @@ def cumsum(self) -> Self:
         """
         return self.__class__(self._xa.copy().cumsum(axis=0))
 
-    def has_same_time_as(self, other: "TimeSeries") -> bool:
+    def has_same_time_as(self, other: Self) -> bool:
         """
         Checks whether this series has the same time index as `other`.
 
@@ -2647,7 +2716,7 @@ def has_same_time_as(self, other: Self) -> bool:
             return False
         return (other.time_index == self.time_index).all()
 
-    def append(self, other: "TimeSeries") -> Self:
+    def append(self, other: Self) -> Self:
         """
         Appends another series to this series along the time axis.
 
@@ -2743,7 +2812,7 @@ def append_values(self, values: np.ndarray) -> Self:
             )
         )
 
-    def prepend(self, other: "TimeSeries") -> Self:
+    def prepend(self, other: Self) -> Self:
         """
         Prepends (i.e. adds to the beginning) another series to this series along the time axis.
 
@@ -2941,7 +3010,7 @@ def with_hierarchy(self, hierarchy: Dict[str, Union[str, List[str]]]):
             )
         )
 
-    def stack(self, other: "TimeSeries") -> "TimeSeries":
+    def stack(self, other: Self) -> Self:
         """
         Stacks another univariate or multivariate TimeSeries with the same time index on top of
         the current one (along the component axis).
 
@@ -3017,7 +3086,7 @@ def add_datetime_attribute(
         one_hot: bool = False,
         cyclic: bool = False,
         tz: Optional[str] = None,
-    ) -> "TimeSeries":
+    ) -> Self:
         """
         Build a new series with one (or more) additional component(s) that contain an attribute
         of the time index of the series.
@@ -3064,7 +3133,7 @@ def add_holidays(
         prov: str = None,
         state: str = None,
         tz: Optional[str] = None,
-    ) -> "TimeSeries":
+    ) -> Self:
         """
         Adds a binary univariate component to the current series that equals 1 at every index that
         corresponds to selected country's holiday, and 0 otherwise.
 
@@ -3287,7 +3356,7 @@ def window_transform(
         forecasting_safe: Optional[bool] = True,
         keep_non_transformed: Optional[bool] = False,
         include_current: Optional[bool] = True,
-    ) -> "TimeSeries":
+    ) -> Self:
         """
         Applies a moving/rolling, expanding or exponentially weighted window transformation over this ``TimeSeries``.
@@ -4307,7 +4376,7 @@ def quantile(self, quantile: float, **kwargs) -> Self:
 
     def _combine_arrays(
         self,
-        other: Union["TimeSeries", xr.DataArray, np.ndarray],
+        other: Union[Self, xr.DataArray, np.ndarray],
        combine_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
     ) -> Self:
         """
@@ -4884,30 +4953,40 @@ def __getitem__(
         """
 
         def _check_dt():
-            raise_if_not(
-                self._has_datetime_index,
-                "Attempted indexing a series with a DatetimeIndex or a timestamp, "
-                "but the series uses a RangeIndex.",
-                logger,
-            )
+            if not self._has_datetime_index:
+                raise_log(
+                    ValueError(
+                        "Attempted indexing a series with a DatetimeIndex or a timestamp, "
+                        "but the series uses a RangeIndex."
+                    ),
+                    logger,
+                )
 
         def _check_range():
-            raise_if(
-                self._has_datetime_index,
-                "Attempted indexing a series with a RangeIndex, "
-                "but the series uses a DatetimeIndex.",
-                logger,
-            )
+            if self._has_datetime_index:
+                raise_log(
+                    ValueError(
+                        "Attempted indexing a series with a RangeIndex, "
+                        "but the series uses a DatetimeIndex."
+                    ),
+                    logger,
+                )
 
-        def _set_freq_in_xa(xa_: xr.DataArray, freq=None):
+        def _set_freq_in_xa(xa_in: xr.DataArray, freq=None):
             # mutates the DataArray to make sure it contains the freq
-            if isinstance(xa_.get_index(self._time_dim), pd.DatetimeIndex):
+            if isinstance(xa_in.get_index(self._time_dim), pd.DatetimeIndex):
                 if freq is None:
-                    freq = xa_.get_index(self._time_dim).inferred_freq
+                    freq = xa_in.get_index(self._time_dim).inferred_freq
                 if freq is not None:
-                    xa_.get_index(self._time_dim).freq = to_offset(freq)
+                    xa_in.get_index(self._time_dim).freq = freq
                 else:
-                    xa_.get_index(self._time_dim).freq = self._freq
+                    xa_in.get_index(self._time_dim).freq = self._freq
+
+        def _get_freq(xa_in: xr.DataArray):
+            if self._has_datetime_index:
+                return xa_in.get_index(self._time_dim).freq
+            else:
+                return xa_in.get_index(self._time_dim).step
 
         adapt_covs_on_component = (
             True
@@ -4920,7 +4999,7 @@ def _set_freq_in_xa(xa_: xr.DataArray, freq=None):
             _check_dt()
             xa_ = self._xa.sel({self._time_dim: key})
 
-            # indexing may discard the freq so we restore it...
+            # indexing may discard the freq, so we restore it...
             # if the DateTimeIndex already has an associated freq, use it
             # otherwise key.freq is None and the freq will be inferred
             _set_freq_in_xa(xa_, key.freq)
@@ -4953,43 +5032,20 @@ def _set_freq_in_xa(xa_: xr.DataArray, freq=None):
                 key.stop, (int, np.int64)
             ):
                 xa_ = self._xa.isel({self._time_dim: key})
-                if isinstance(key.step, (int, np.int64)):
-                    # new frequency is multiple of original
-                    new_freq = key.step * self.freq
-                elif key.step is None:
-                    new_freq = self.freq
-                else:
-                    new_freq = None
-                    raise_log(
-                        ValueError(
-                            f"Invalid slice step={key.step}. Only supports integer steps or `None`."
-                        ),
-                        logger=logger,
-                    )
-                # indexing may discard the freq so we restore it...
-                _set_freq_in_xa(xa_, new_freq)
+                if _get_freq(xa_) is None:
+                    # indexing discarded the freq; we restore it
+                    freq = key.step * self.freq if key.step else self.freq
+                    _set_freq_in_xa(xa_, freq)
                 return self.__class__(xa_)
             elif isinstance(key.start, pd.Timestamp) or isinstance(
                 key.stop, pd.Timestamp
             ):
                 _check_dt()
-                if isinstance(key.step, (int, np.int64)):
-                    # new frequency is multiple of original
-                    new_freq = key.step * self.freq
-                elif key.step is None:
-                    new_freq = self.freq
-                else:
-                    new_freq = None
-                    raise_log(
-                        ValueError(
-                            f"Invalid slice step={key.step}. Only supports integer steps or `None`."
-                        ),
-                        logger=logger,
-                    )
-
-                # indexing may discard the freq so we restore it...
xa_ = self._xa.sel({self._time_dim: key}) - _set_freq_in_xa(xa_, new_freq) + if _get_freq(xa_) is None: + # indexing discarded the freq; we restore it + freq = key.step * self.freq if key.step else self.freq + _set_freq_in_xa(xa_, freq) return self.__class__(xa_) # handle simple types: @@ -5022,15 +5078,15 @@ def _set_freq_in_xa(xa_: xr.DataArray, freq=None): ) } ) - - _set_freq_in_xa(xa_) # indexing may discard the freq so we restore it... + # indexing may discard the freq, so we restore it... + _set_freq_in_xa(xa_, freq=self.freq) return self.__class__(xa_) elif isinstance(key, pd.Timestamp): _check_dt() - # indexing may discard the freq so we restore it... + # indexing may discard the freq, so we restore it... xa_ = self._xa.sel({self._time_dim: [key]}) - _set_freq_in_xa(xa_) + _set_freq_in_xa(xa_, self.freq) return self.__class__(xa_) # handle lists: @@ -5049,7 +5105,7 @@ def _set_freq_in_xa(xa_: xr.DataArray, freq=None): elif all(isinstance(i, (int, np.int64)) for i in key): xa_ = self._xa.isel({self._time_dim: key}) - # indexing may discard the freq so we restore it... + # indexing may discard the freq, so we restore it... _set_freq_in_xa(xa_) orig_idx = self.time_index @@ -5077,7 +5133,7 @@ def _set_freq_in_xa(xa_: xr.DataArray, freq=None): elif all(isinstance(t, pd.Timestamp) for t in key): _check_dt() - # indexing may discard the freq so we restore it... + # indexing may discard the freq, so we restore it... xa_ = self._xa.sel({self._time_dim: key}) _set_freq_in_xa(xa_) return self.__class__(xa_) @@ -5095,7 +5151,7 @@ def _xarray_with_attrs(xa_, static_covariates, hierarchy): return xa_ -def _concat_static_covs(series: Sequence["TimeSeries"]) -> Optional[pd.DataFrame]: +def _concat_static_covs(series: Sequence[TimeSeries]) -> Optional[pd.DataFrame]: """Concatenates static covariates along component dimension (rows of static covariates). For stacking or concatenating TimeSeries along component dimension (axis=1). @@ -5154,7 +5210,7 @@ def _concat_static_covs(series: Sequence["TimeSeries"]) -> Optional[pd.DataFrame ) -def _concat_hierarchy(series: Sequence["TimeSeries"]): +def _concat_hierarchy(series: Sequence[TimeSeries]): """ Used to concatenate the hierarchies of multiple TimeSeries, when concatenating series along axis 1 (components). This simply merges the hierarchy dictionaries. @@ -5167,7 +5223,7 @@ def _concat_hierarchy(series: Sequence["TimeSeries"]): def concatenate( - series: Sequence["TimeSeries"], + series: Sequence[TimeSeries], axis: Union[str, int] = 0, ignore_time_axis: bool = False, ignore_static_covariates: bool = False,
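
Reviewer note (not part of the patch): a minimal usage sketch of the new `drop_group_cols` option. The DataFrame below is illustrative only; the column names mirror the tests above, and it assumes a darts build with this patch applied.

```python
import numpy as np
import pandas as pd

from darts import TimeSeries

# long format: two groups in "st1", one constant static column, one value column "a"
df = pd.DataFrame(
    {
        "times": list(pd.date_range("2024-01-01", periods=3, freq="D")) * 2,
        "st1": ["A"] * 3 + ["B"] * 3,
        "constant": [1.0] * 6,
        "a": np.arange(6, dtype=float),
    }
)

series_list = TimeSeries.from_group_dataframe(
    df=df,
    group_cols="st1",          # one series per unique "st1" value
    static_cols=["constant"],  # kept as a static covariate on each series
    time_col="times",
    value_cols="a",
    drop_group_cols=["st1"],   # new: "st1" still defines the groups, but is
    # excluded from the resulting static covariates
)

for ts in series_list:
    print(ts.static_covariates)  # only the "constant" column remains
```

With `drop_group_cols=["st1"]`, the grouping behavior is unchanged; only the static covariates attached to each resulting series differ.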
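A second sketch (also not part of the patch, and also assuming it is applied) of the slicing fast path: the new `_get_freq` helper only restores a frequency when xarray's indexing actually dropped it, and for integer slices with a step the restored frequency is `key.step * self.freq`:

```python
import numpy as np
import pandas as pd

from darts import TimeSeries

times = pd.date_range("2024-01-01", periods=6, freq="D")
series = TimeSeries.from_times_and_values(times, np.arange(6, dtype=float))

# integer-location slicing with step=2: the result keeps a valid frequency,
# either preserved by pandas or restored as key.step * self.freq
every_other = series[0:6:2]
print(every_other.freq)  # expected: <2 * Days>
print(len(every_other))  # expected: 3
```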