Skip to content

Commit

Permalink
Merge pull request #70 from ibm-granite/stride_2
Browse files Browse the repository at this point in the history
create dataset with stride
  • Loading branch information
wgifford authored Jun 13, 2024
2 parents b5bc3ba + d65fff9 commit a730b5d
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 4 deletions.
5 changes: 4 additions & 1 deletion tests/toolkit/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def test_forecasting_df_dataset_stride(ts_data_with_categorical):

# length check
series_len = len(df) / len(df["id"].unique())
assert len(ds) == ((series_len - prediction_length - context_length + 1) // stride) * len(df["id"].unique())
assert len(ds) == (((series_len - prediction_length - context_length) // stride) + 1) * len(df["id"].unique())

# check proper windows are selected based on chosen stride
ds_past_np = np.array([v["past_values"].numpy() for v in ds])
Expand All @@ -181,12 +181,15 @@ def test_forecasting_df_dataset_stride(ts_data_with_categorical):
[[0.0, 10.0], [1.0, 10.333333], [2.0, 10.666667]],
[[13.0, 14.333333], [14.0, 14.666667], [15.0, 15.0]],
[[26.0, 18.666666], [27.0, 19.0], [28.0, 19.333334]],
[[39.0, 23.0], [40.0, 23.333334], [41.0, 23.666666]],
[[50.0, 26.666666], [51.0, 27.0], [52.0, 27.333334]],
[[63.0, 31.0], [64.0, 31.333334], [65.0, 31.666666]],
[[76.0, 35.333332], [77.0, 35.666668], [78.0, 36.0]],
[[89.0, 39.666668], [90.0, 40.0], [91.0, 40.333332]],
[[100.0, 43.333332], [101.0, 43.666668], [102.0, 44.0]],
[[113.0, 47.666668], [114.0, 48.0], [115.0, 48.333332]],
[[126.0, 52.0], [127.0, 52.333332], [128.0, 52.666668]],
[[139.0, 56.333332], [140.0, 56.666668], [141.0, 57.0]],
]
)

Expand Down
15 changes: 15 additions & 0 deletions tests/toolkit/test_time_series_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,21 @@ def test_get_datasets(ts_data):

assert len(valid) == len(test)

full_lengths = [len(train), len(valid), len(test)]

stride = 3
num_ids = len(ts_data["id"].unique())
# test stride
train, valid, test = get_datasets(
tsp, ts_data, split_config={"train": [0, 1 / 3], "valid": [1 / 3, 2 / 3], "test": [2 / 3, 1]}, stride=stride
)

strided_lengths = [len(train), len(valid), len(test)]

# x is the full dataset length under stride 1 (summed over all IDs)
# x // num_ids is the stride-1 length for each ID; subtract one, floor-divide by stride, then add one to get the strided length per ID
assert [(((x // num_ids) - 1) // stride + 1) * num_ids for x in full_lengths] == strided_lengths

# no id columns, so treat as one big time series
tsp = TimeSeriesPreprocessor(
id_columns=[],
Expand Down
4 changes: 2 additions & 2 deletions tsfm_public/toolkit/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def pad_zero(self, data_df):
)

def __len__(self):
return (len(self.X) - self.context_length - self.prediction_length + 1) // self.stride
return (len(self.X) - self.context_length - self.prediction_length) // self.stride + 1

def __getitem__(self, index: int):
"""
Expand Down Expand Up @@ -484,7 +484,7 @@ def __getitem__(self, index):
return ret

def __len__(self):
return (len(self.X) - self.context_length - self.prediction_length + 1) // self.stride
return (len(self.X) - self.context_length - self.prediction_length) // self.stride + 1


class RegressionDFDataset(BaseConcatDFDataset):
Expand Down
5 changes: 4 additions & 1 deletion tsfm_public/toolkit/time_series_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,7 @@ def get_datasets(
ts_preprocessor: TimeSeriesPreprocessor,
dataset: Union[Dataset, pd.DataFrame],
split_config: Dict[str, Union[List[Union[int, float]], float]] = {"train": 0.7, "test": 0.2},
stride: int = 1,
fewshot_fraction: Optional[float] = None,
fewshot_location: str = FractionLocation.LAST.value,
) -> Tuple[Any]:
Expand All @@ -823,7 +824,8 @@ def get_datasets(
test: 0.2
}
A valid split should not be specified directly; the above implies valid = 0.1
stride (int): Stride used for creating the datasets. It is applied to all of train, validation, and test.
Defaults to 1.
fewshot_fraction (float, optional): When non-null, return this fraction of the original training
dataset. This is done to support fewshot fine-tuning.
fewshot_location (str): Determines where the fewshot data is chosen. Valid options are "first" and "last"
Expand Down Expand Up @@ -878,6 +880,7 @@ def get_datasets(
params = column_specifiers
params["context_length"] = ts_preprocessor.context_length
params["prediction_length"] = ts_preprocessor.prediction_length
params["stride"] = stride

# get torch datasets
train_valid_test = [train_data, valid_data, test_data]
Expand Down

0 comments on commit a730b5d

Please sign in to comment.