diff --git a/rul_datasets/reader/femto.py b/rul_datasets/reader/femto.py index b6ec809..105ebb6 100644 --- a/rul_datasets/reader/femto.py +++ b/rul_datasets/reader/femto.py @@ -125,10 +125,26 @@ def __init__( "and 'max_rul' in conjunction." ) + self.run_split_dist = run_split_dist self.first_time_to_predict = first_time_to_predict self.norm_rul = norm_rul - self._preparator = FemtoPreparator(self.fd, self._FEMTO_ROOT, run_split_dist) + self._preparator = FemtoPreparator( + self.fd, self._FEMTO_ROOT, self.run_split_dist + ) + + @property + def hparams(self): + hparams = super().hparams + hparams.update( + { + "first_time_to_predict": self.first_time_to_predict, + "norm_rul": self.norm_rul, + "run_split_dist": self.run_split_dist, + } + ) + + return hparams @property def dataset_name(self) -> str: diff --git a/rul_datasets/reader/xjtu_sy.py b/rul_datasets/reader/xjtu_sy.py index f31b621..78a5288 100644 --- a/rul_datasets/reader/xjtu_sy.py +++ b/rul_datasets/reader/xjtu_sy.py @@ -118,10 +118,26 @@ def __init__( "and 'max_rul' in conjunction." ) + self.run_split_dist = run_split_dist self.first_time_to_predict = first_time_to_predict self.norm_rul = norm_rul - self._preparator = XjtuSyPreparator(self.fd, self._XJTU_SY_ROOT, run_split_dist) + self._preparator = XjtuSyPreparator( + self.fd, self._XJTU_SY_ROOT, self.run_split_dist + ) + + @property + def hparams(self): + hparams = super().hparams + hparams.update( + { + "first_time_to_predict": self.first_time_to_predict, + "norm_rul": self.norm_rul, + "run_split_dist": self.run_split_dist, + } + ) + + return hparams @property def dataset_name(self) -> str: diff --git a/tests/reader/test_femto.py b/tests/reader/test_femto.py index 71ff79c..76925ec 100644 --- a/tests/reader/test_femto.py +++ b/tests/reader/test_femto.py @@ -5,22 +5,28 @@ from rul_datasets import reader -@pytest.fixture(scope="module", autouse=True) -def prepare_femto(): - for fd in range(1, 4): - reader.FemtoReader(fd).prepare_data() +def test_additional_hparams(): + femto = reader.FemtoReader(1, first_time_to_predict=[10] * 5, norm_rul=True) + assert femto.hparams["first_time_to_predict"] == [10] * 5 + assert femto.hparams["norm_rul"] + assert femto.hparams["run_split_dist"] is None @pytest.mark.needs_data -class TestFEMTOLoader: +class TestFemtoReader: NUM_CHANNELS = 2 + @pytest.fixture(scope="class", autouse=True) + def prepare_femto(self): + for fd in range(1, 4): + reader.FemtoReader(fd).prepare_data() + @pytest.mark.parametrize("fd", [1, 2, 3]) @pytest.mark.parametrize("window_size", [2560, 1500, 1000, 100]) @pytest.mark.parametrize("split", ["dev", "val", "test"]) def test_run_shape_and_dtype(self, fd, window_size, split): - femto_loader = reader.FemtoReader(fd, window_size=window_size) - features, targets = femto_loader.load_split(split) + femto_reader = reader.FemtoReader(fd, window_size=window_size) + features, targets = femto_reader.load_split(split) for run, run_target in zip(features, targets): self._assert_run_correct(run, run_target, window_size) diff --git a/tests/reader/test_xjtu_sy.py b/tests/reader/test_xjtu_sy.py index 2ff070c..a5343c3 100644 --- a/tests/reader/test_xjtu_sy.py +++ b/tests/reader/test_xjtu_sy.py @@ -7,16 +7,22 @@ from rul_datasets.reader.xjtu_sy import _download_xjtu_sy -@pytest.fixture(scope="module", autouse=True) -def prepare_xjtu_sy(): - for fd in range(1, 4): - reader.XjtuSyReader(fd).prepare_data() +def test_additional_hparams(): + femto = reader.XjtuSyReader(1, first_time_to_predict=[10] * 5, norm_rul=True) + assert femto.hparams["first_time_to_predict"] == [10] * 5 + assert femto.hparams["norm_rul"] + assert femto.hparams["run_split_dist"] is None @pytest.mark.needs_data class TestXjtuSyLoader: NUM_CHANNELS = 2 + @pytest.fixture(scope="class", autouse=True) + def prepare_xjtu_sy(self): + for fd in range(1, 4): + reader.XjtuSyReader(fd).prepare_data() + def test_default_window_size(self): xjtu = reader.XjtuSyReader(1) assert xjtu.window_size == reader.XjtuSyPreparator.DEFAULT_WINDOW_SIZE