diff --git a/rul_datasets/reader/ncmapss.py b/rul_datasets/reader/ncmapss.py index 67a6915..696cf04 100644 --- a/rul_datasets/reader/ncmapss.py +++ b/rul_datasets/reader/ncmapss.py @@ -15,6 +15,7 @@ import tempfile import warnings import zipfile +from pathlib import Path from typing import Tuple, List, Optional, Union, Dict import h5py # type: ignore[import] @@ -23,8 +24,7 @@ from rul_datasets import utils from rul_datasets.reader.data_root import get_data_root -from rul_datasets.reader import AbstractReader, scaling - +from rul_datasets.reader import AbstractReader, scaling, saving NCMAPSS_DRIVE_ID = "1X9pHm2E3U0bZZbXIhJubVGSL3rtzqFkn" @@ -206,23 +206,37 @@ def fds(self) -> List[int]: """Indices of the available sub-datasets.""" return list(self._WINDOW_SIZES) - def prepare_data(self) -> None: + def prepare_data(self, cache: bool = True) -> None: """ - Prepare the N-C-MAPSS dataset. This function needs to be called before using the - dataset for the first time. + Prepare the N-C-MAPSS dataset. This function needs to be called before using + the dataset for the first time. The dataset is cached for faster loading in + the future. This behavior can be disabled to save disk space by setting + `cache` to `False`. The dataset is assumed to be present in the data root directory. The training data is then split into development and validation set. Afterward, a scaler is fit on the development features if it was not already done previously. + + Args: + cache: Whether to cache the data for faster loading in the future. """ if not os.path.exists(self._NCMAPSS_ROOT): _download_ncmapss(self._NCMAPSS_ROOT) + if cache and not self._cache_exists(): + self._cache_data() if not os.path.exists(self._get_scaler_path()): features, _, _ = self._load_data("dev") scaler = scaling.fit_scaler(features, MinMaxScaler(self.scaling_range)) scaling.save_scaler(scaler, self._get_scaler_path()) - def _get_scaler_path(self): + def _cache_data(self) -> None: + os.makedirs(self._get_cache_path(), exist_ok=True) + features, targets, auxiliary = self._load_raw_data() + features, targets, auxiliary = self._split_by_unit(features, targets, auxiliary) + for i, (f, t, a) in enumerate(zip(features, targets, auxiliary)): + saving.save(str(self._get_cache_path() / f"{i}.npy"), f, t, a) + + def _get_scaler_path(self) -> str: file_name = ( f"scaler_{self.fd}_{self.run_split_dist['dev']}_{self.scaling_range}.pkl" ) @@ -264,6 +278,21 @@ def load_complete_split( def _load_data( self, split: str ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: + if self._cache_exists(): + features, targets, auxiliary = self._load_cached_data(split) + else: + features, targets, auxiliary = self._load_original_data(split) + + return features, targets, auxiliary + + def _load_cached_data(self, split: str): + unit_idx = self.run_split_dist[split] + save_paths = [str(self._get_cache_path() / f"{i}.npy") for i in unit_idx] + features, targets, auxiliary = saving.load_multiple(save_paths) + + return features, targets, auxiliary + + def _load_original_data(self, split): features, targets, auxiliary = self._load_raw_data() features, targets, auxiliary = self._split_by_unit(features, targets, auxiliary) features = self._select_units(features, split) @@ -368,6 +397,12 @@ def _calc_default_window_size(self): return max(*max_window_size) + def _cache_exists(self) -> bool: + return saving.exists(str(self._get_cache_path() / "0.npy")) + + def _get_cache_path(self): + return Path(self._NCMAPSS_ROOT) / f"DS{self.fd:02d}" + def _download_ncmapss(data_root): with tempfile.TemporaryDirectory() as tmp_path: diff --git a/tests/reader/test_ncmapss.py b/tests/reader/test_ncmapss.py index fc4477c..333ffd0 100644 --- a/tests/reader/test_ncmapss.py +++ b/tests/reader/test_ncmapss.py @@ -1,13 +1,18 @@ +import os +import shutil +from pathlib import Path + import numpy as np import pytest +import rul_datasets from rul_datasets.reader.ncmapss import NCmapssReader @pytest.fixture() def prepared_ncmapss(): for fd in range(1, 8): - NCmapssReader(fd).prepare_data() + NCmapssReader(fd).prepare_data(cache=False) @pytest.mark.parametrize("fd", list(range(1, 8))) @@ -166,3 +171,31 @@ def test_scaling_range_is_tuple(scaling_range): assert isinstance(reader.scaling_range, tuple) assert reader.scaling_range == (0, 1) + + +@pytest.mark.needs_data +def test_cache(prepared_ncmapss, tmp_path): + ncmapss_files = Path(NCmapssReader._NCMAPSS_ROOT).rglob("*") + linked_files = [] + for file in ncmapss_files: + if str(file).endswith(".h5"): + linked_file = tmp_path / file.name + os.symlink(file, linked_file) + linked_files.append(linked_file) + + reader = NCmapssReader(1) + cached_reader = NCmapssReader(1) + cached_reader._NCMAPSS_ROOT = tmp_path + cached_reader.prepare_data(cache=True) + + # remove linked files so that only cached files can be used + for file in linked_files: + os.remove(file) + + org_features, org_targets = reader.load_split("dev") + cached_features, cached_targets = cached_reader.load_split("dev") + + for org_feat, cached_feat in zip(org_features, cached_features): + np.testing.assert_almost_equal(org_feat, cached_feat) + for org_targ, cached_targ in zip(org_targets, cached_targets): + np.testing.assert_almost_equal(org_targ, cached_targ)