diff --git a/rul_datasets/reader/ncmapss.py b/rul_datasets/reader/ncmapss.py
index 1b98f6f..44bfe08 100644
--- a/rul_datasets/reader/ncmapss.py
+++ b/rul_datasets/reader/ncmapss.py
@@ -123,6 +123,7 @@ def __init__(
         truncate_degraded_only: bool = False,
         resolution_seconds: int = 1,
         padding_value: float = 0.0,
+        scaling_range: Optional[Tuple[int, int]] = (0, 1),
     ) -> None:
         """
         Create a new reader for the New C-MAPSS dataset. The maximum RUL value is set
@@ -172,6 +173,7 @@ def __init__(
         self.run_split_dist = run_split_dist or self._get_default_split(self.fd)
         self.resolution_seconds = resolution_seconds
         self.padding_value = padding_value
+        self.scaling_range = scaling_range
 
         if self.resolution_seconds > 1 and window_size is None:
             warnings.warn(
@@ -189,6 +191,7 @@ def hparams(self):
                 "run_split_dist": self.run_split_dist,
                 "feature_select": self.feature_select,
                 "padding_value": self.padding_value,
+                "scaling_range": self.scaling_range,
             }
         )
 
@@ -216,11 +219,13 @@ def prepare_data(self) -> None:
             _download_ncmapss(self._NCMAPSS_ROOT)
         if not os.path.exists(self._get_scaler_path()):
             features, _, _ = self._load_data("dev")
-            scaler = scaling.fit_scaler(features, MinMaxScaler())
+            scaler = scaling.fit_scaler(features, MinMaxScaler(self.scaling_range))
             scaling.save_scaler(scaler, self._get_scaler_path())
 
     def _get_scaler_path(self):
-        file_name = f"scaler_{self.fd}_{self.run_split_dist['dev']}.pkl"
+        file_name = (
+            f"scaler_{self.fd}_{self.run_split_dist['dev']}_{self.scaling_range}.pkl"
+        )
         file_path = os.path.join(self._NCMAPSS_ROOT, file_name)
         return file_path
 
diff --git a/tests/reader/test_ncmapss.py b/tests/reader/test_ncmapss.py
index 9bfc618..c476574 100644
--- a/tests/reader/test_ncmapss.py
+++ b/tests/reader/test_ncmapss.py
@@ -47,6 +47,21 @@ def test_prepare_data(should_run, mocker):
         mock_save_scaler.assert_not_called()
 
 
+@pytest.mark.needs_data
+@pytest.mark.parametrize("scaling_range", [(-1, 1), (0, 1)])
+def test_scaling_range(scaling_range):
+    reader = NCmapssReader(fd=1, scaling_range=scaling_range)
+    reader.prepare_data()
+    features, _ = reader.load_split("dev")
+
+    min_val, max_val = scaling_range
+    for feature in features:
+        flat_features = feature.flatten()
+        np.testing.assert_almost_equal(
+            flat_features, np.clip(flat_features, min_val, max_val)
+        )
+
+
 @pytest.mark.needs_data
 @pytest.mark.parametrize("fd", list(range(1, 8)))
 @pytest.mark.parametrize("split", ["dev", "val", "test"])