Skip to content

Commit 0b7fc80

Browse files
committed
Add ResampleAggregate transformer
1 parent 7621454 commit 0b7fc80

File tree

2 files changed

+139
-0
lines changed

2 files changed

+139
-0
lines changed

src/adtk/transformer/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
DoubleRollingAggregate,
1717
Retrospect,
1818
RollingAggregate,
19+
ResampleAggregate,
1920
StandardScale,
2021
)
2122
from ._transformer_hd import (
@@ -54,6 +55,7 @@ def print_all_models() -> None:
5455

5556
__all__ = [
5657
"RollingAggregate",
58+
"ResampleAggregate",
5759
"DoubleRollingAggregate",
5860
"ClassicSeasonalDecomposition",
5961
"Retrospect",

src/adtk/transformer/_transformer_1d.py

+137
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,143 @@ def _predict_core(self, s: pd.Series) -> pd.Series:
127127
return (s - mean) / std
128128

129129

130+
131+
class ResampleAggregate(_NonTrainableUnivariateTransformer):
    """Transformer that resamples a time series to a specified frequency and
    aggregates the values falling into each resampling bin using a
    user-selected operation.

    Parameters
    ----------
    rule: str
        The new frequency for resampling. Examples: '1D', '2H', '30min'.

    agg: str or function
        Aggregation method applied to the values of each resampling bin.
        If str, must be one of supported built-in methods:

        - 'mean': mean of all values in a bin.
        - 'median': median of all values in a bin.
        - 'sum': summation of all values in a bin.
        - 'min': minimum of all values in a bin.
        - 'max': maximum of all values in a bin.
        - 'std': sample standard deviation of all values in a bin.
        - 'var': sample variance of all values in a bin.
        - 'skew': skewness of all values in a bin.
        - 'kurt': kurtosis of all values in a bin.
        - 'count': number of non-nan values in a bin.
        - 'nnz': number of non-zero values in a bin.
        - 'nunique': number of unique non-nan values in a bin.
        - 'quantile': quantile of all values in a bin. Requires percentile
          parameter `q` in parameter `agg_params`, which is a float or a
          list of float between 0 and 1 inclusive. A list of floats yields
          a DataFrame with one column per quantile, named 'q<value>'.
        - 'iqr': interquartile range, i.e. difference between 75% and 25%
          quantiles.
        - 'idr': interdecile range, i.e. difference between 90% and 10%
          quantiles.

        If function, it is handed to the pandas resampler's `agg` as is and
        should reduce the values of one bin to a scalar. `agg_params` is
        not forwarded to a user-given function.
        NOTE(review): the annotation also allows a 1D numpy array return;
        how pandas lays out such output is not handled here -- confirm
        before relying on it.

        Default: 'mean'

    agg_params: dict, optional
        Parameters of aggregation function. Currently only key `q` (for
        'quantile') is read. Default: None.

    """

    # Built-in aggregations the pandas resampler understands by name.
    _NATIVE_AGGS = (
        "mean",
        "median",
        "sum",
        "min",
        "max",
        "count",
        "std",
        "var",
        "skew",
        "kurt",
    )

    # Every built-in name accepted by `agg`; only used to build the error
    # message, so it must list exactly what `_predict_core` implements.
    _SUPPORTED_AGGS = (
        "mean",
        "median",
        "sum",
        "min",
        "max",
        "quantile",
        "iqr",
        "idr",
        "count",
        "nnz",
        "nunique",
        "std",
        "var",
        "skew",
        "kurt",
    )

    def __init__(
        self,
        rule: str,
        agg: Union[
            str, Callable[[pd.Series], Union[float, np.ndarray]]
        ] = "mean",
        agg_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.rule = rule
        self.agg = agg
        self.agg_params = agg_params

    @property
    def _param_names(self) -> Tuple[str, ...]:
        # Names must match the attributes assigned in __init__; the
        # frequency is stored as `rule` (there is no `freq` attribute).
        return ("rule", "agg", "agg_params")

    def _predict_core(self, s: pd.Series) -> Union[pd.Series, pd.DataFrame]:
        """Resample `s` with frequency `rule` and aggregate every bin.

        Returns a Series carrying the name of `s`, or a DataFrame when
        several quantiles are requested.

        Raises
        ------
        ValueError
            If the index of `s` is not monotonic, or `agg` is neither a
            supported built-in name nor a callable.
        KeyError
            If `agg` is 'quantile' and `agg_params` has no key `q`.

        """
        index = s.index
        if not (
            index.is_monotonic_increasing or index.is_monotonic_decreasing
        ):
            raise ValueError("Time series must have a monotonic time index. ")

        agg = self.agg
        agg_params = {} if self.agg_params is None else self.agg_params

        resampler = s.resample(rule=self.rule)

        if agg in self._NATIVE_AGGS:
            s_resample = resampler.agg(agg)
        elif agg == "nunique":
            # NaNs are dropped first so that they are not counted as values.
            s_resample = resampler.agg(lambda x: len(np.unique(x.dropna())))
        elif agg == "nnz":
            s_resample = resampler.agg(np.count_nonzero)
        elif agg == "quantile":
            q = agg_params["q"]
            if hasattr(q, "__iter__"):
                # One column per requested quantile.
                s_resample = pd.concat(
                    [
                        resampler.quantile(level).rename("q{}".format(level))
                        for level in q
                    ],
                    axis=1,
                )
            else:
                s_resample = resampler.quantile(q)
        elif agg == "iqr":
            s_resample = resampler.quantile(0.75) - resampler.quantile(0.25)
        elif agg == "idr":
            s_resample = resampler.quantile(0.9) - resampler.quantile(0.1)
        elif callable(agg):
            s_resample = resampler.agg(agg)
        else:
            raise ValueError(
                f"Attribute agg must be one of {list(self._SUPPORTED_AGGS)}"
            )

        # Keep the input's name on single-column output; multi-quantile
        # output is a DataFrame whose columns are already named.
        if isinstance(s_resample, pd.Series):
            s_resample.name = s.name

        return s_resample
265+
266+
130267
class RollingAggregate(_NonTrainableUnivariateTransformer):
131268
"""Transformer that rolls a sliding window along a time series, and
132269
aggregates using a user-selected operation.

0 commit comments

Comments
 (0)