def clip(value, low, high):
    """
    Clip ``value`` between ``low`` and ``high``, included.
    """
    return max(low, min(high, value))


@dataclass
class Range:
    """
    Slice-like description of a span of indices within a time series.

    ``start`` and ``stop`` can each be ``None`` (first index / series
    length), an ``int`` (negative values are counted from the end of the
    series), or a ``pd.Period`` (converted to an offset relative to the
    first period of the series). ``step`` is forwarded unchanged to the
    built-in ``range``.
    """

    start: Optional[Union[int, pd.Period]] = None
    stop: Optional[Union[int, pd.Period]] = None
    step: int = 1

    def _bound_as_int(
        self,
        bound: Optional[Union[int, pd.Period]],
        default: int,
        start: pd.Period,
        length: int,
    ) -> int:
        """
        Convert one bound to an integer position; the result is not clipped.

        ``default`` is returned when ``bound`` is ``None``.
        """
        if bound is None:
            return default
        if isinstance(bound, pd.Period):
            # Number of ``freq`` steps between the series start and the bound.
            return int((bound - start) / start.freq)
        if bound < 0:
            # Negative integers are counted from the end of the series.
            return length + bound
        return bound

    def _start_as_int(self, start: pd.Period, length: int) -> int:
        """
        Return ``self.start`` as an integer position (0 when unset).
        """
        return self._bound_as_int(self.start, 0, start, length)

    def _stop_as_int(self, start: pd.Period, length: int) -> int:
        """
        Return ``self.stop`` as an integer position (``length`` when unset).
        """
        return self._bound_as_int(self.stop, length, start, length)

    def get(self, start: pd.Period, length: int) -> range:
        """
        Build the concrete ``range`` for a series that begins at ``start``
        and has ``length`` entries. Both bounds are clipped to
        ``[0, length]``.
        """
        return range(
            clip(self._start_as_int(start, length), 0, length),
            clip(self._stop_as_int(start, length), 0, length),
            self.step,
        )


@dataclass
class Sampler:
    """
    Base class for samplers that pick indices out of a ``Range``.

    Subclasses implement ``sample``; calling the sampler resolves
    ``range_`` against a series and passes the result to ``sample``.
    """

    range_: Range

    def sample(self, rge: range) -> list:
        """
        Select indices from ``rge``. Must be overridden by subclasses.
        """
        raise NotImplementedError()

    def __call__(self, start: pd.Period, length: int) -> list:
        """
        Sample indices for a series starting at ``start`` with ``length``
        entries.
        """
        return self.sample(self.range_.get(start, length))


@dataclass
class SampleAll(Sampler):
    """
    Sampler that returns every index of the range.
    """

    def sample(self, rge: range) -> list:
        return list(rge)


@dataclass
class SampleOnAverage(Sampler):
    """
    Sampler that keeps each index of the range independently with a
    probability ``average_num_samples / average_length``, where
    ``average_length`` is the running mean of the lengths of all non-empty
    ranges seen so far (including the current one).
    """

    average_num_samples: float = 1.0

    def __post_init__(self):
        # Running statistics over the non-empty ranges passed to ``sample``.
        self.average_length = 0
        self.count = 0

    def sample(self, rge: range) -> list:
        """
        Return the randomly selected elements of ``rge`` as a list of ints.

        Empty ranges yield ``[]`` and do not update the running statistics.
        """
        num_candidates = len(rge)
        if num_candidates == 0:
            return []

        # Incremental update of the mean range length.
        self.average_length = (
            self.count * self.average_length + num_candidates
        ) / (self.count + 1)
        self.count += 1

        p = self.average_num_samples / self.average_length
        (indices,) = np.where(np.random.random_sample(num_candidates) < p)
        # Map positions within the range back to the range's own elements;
        # adding the positions to the minimum is only right for step == 1.
        return (rge.start + rge.step * indices).tolist()