Skip to content

Commit

Permalink
feat: add scaling selection to NCMAPSS (#59)
Browse files Browse the repository at this point in the history
* docs: add docstrings

* feat: update ncmapss reader
  • Loading branch information
ZhengyanZhu authored Mar 30, 2024
1 parent 494052f commit c6f9e32
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
9 changes: 7 additions & 2 deletions rul_datasets/reader/ncmapss.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(
truncate_degraded_only: bool = False,
resolution_seconds: int = 1,
padding_value: float = 0.0,
scaling_range: Optional[Tuple[int, int]] = (0, 1),
) -> None:
"""
Create a new reader for the New C-MAPSS dataset. The maximum RUL value is set
Expand Down Expand Up @@ -172,6 +173,7 @@ def __init__(
self.run_split_dist = run_split_dist or self._get_default_split(self.fd)
self.resolution_seconds = resolution_seconds
self.padding_value = padding_value
self.scaling_range = scaling_range

if self.resolution_seconds > 1 and window_size is None:
warnings.warn(
Expand All @@ -189,6 +191,7 @@ def hparams(self):
"run_split_dist": self.run_split_dist,
"feature_select": self.feature_select,
"padding_value": self.padding_value,
"scaling_range": self.scaling_range,
}
)

Expand Down Expand Up @@ -216,11 +219,13 @@ def prepare_data(self) -> None:
_download_ncmapss(self._NCMAPSS_ROOT)
if not os.path.exists(self._get_scaler_path()):
features, _, _ = self._load_data("dev")
scaler = scaling.fit_scaler(features, MinMaxScaler())
scaler = scaling.fit_scaler(features, MinMaxScaler(self.scaling_range))
scaling.save_scaler(scaler, self._get_scaler_path())

def _get_scaler_path(self):
file_name = f"scaler_{self.fd}_{self.run_split_dist['dev']}.pkl"
file_name = (
f"scaler_{self.fd}_{self.run_split_dist['dev']}_{self.scaling_range}.pkl"
)
file_path = os.path.join(self._NCMAPSS_ROOT, file_name)

return file_path
Expand Down
15 changes: 15 additions & 0 deletions tests/reader/test_ncmapss.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,21 @@ def test_prepare_data(should_run, mocker):
mock_save_scaler.assert_not_called()


@pytest.mark.needs_data
@pytest.mark.parametrize("scaling_range", [(-1, 1), (0, 1)])
def test_scaling_range(scaling_range):
reader = NCmapssReader(fd=1, scaling_range=scaling_range)
reader.prepare_data()
features, _ = reader.load_split("dev")

min_val, max_val = scaling_range
for feature in features:
flat_features = feature.flatten()
np.testing.assert_almost_equal(
flat_features, np.clip(flat_features, min_val, max_val)
)


@pytest.mark.needs_data
@pytest.mark.parametrize("fd", list(range(1, 8)))
@pytest.mark.parametrize("split", ["dev", "val", "test"])
Expand Down

0 comments on commit c6f9e32

Please sign in to comment.