feat: add caching to NCMAPSS
saves data as npy files that load faster than h5
tilman151 committed May 22, 2024
1 parent 75abd7b commit e3af229
Showing 2 changed files with 75 additions and 7 deletions.
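
In short, `prepare_data` gains a `cache` flag that writes each unit's arrays to `.npy` files under the data root, and `_load_data` then prefers that cache over re-parsing the h5 files. A minimal usage sketch of the new flag from a user's point of view (the `fd=1` choice and the "dev" split are just examples taken from the tests below):

from rul_datasets.reader.ncmapss import NCmapssReader

reader = NCmapssReader(1)  # sub-dataset DS01
reader.prepare_data(cache=True)  # first call converts the h5 data into per-unit .npy files
features, targets = reader.load_split("dev")  # later loads read the faster .npy cache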
47 changes: 41 additions & 6 deletions rul_datasets/reader/ncmapss.py
@@ -15,6 +15,7 @@
import tempfile
import warnings
import zipfile
from pathlib import Path
from typing import Tuple, List, Optional, Union, Dict

import h5py # type: ignore[import]
@@ -23,8 +24,7 @@

from rul_datasets import utils
from rul_datasets.reader.data_root import get_data_root
-from rul_datasets.reader import AbstractReader, scaling
from rul_datasets.reader import AbstractReader, scaling, saving

NCMAPSS_DRIVE_ID = "1X9pHm2E3U0bZZbXIhJubVGSL3rtzqFkn"

@@ -206,23 +206,37 @@ def fds(self) -> List[int]:
"""Indices of the available sub-datasets."""
return list(self._WINDOW_SIZES)

def prepare_data(self) -> None:
def prepare_data(self, cache: bool = True) -> None:
"""
Prepare the N-C-MAPSS dataset. This function needs to be called before using the
dataset for the first time.
Prepare the N-C-MAPSS dataset. This function needs to be called before using
the dataset for the first time. The dataset is cached for faster loading in
the future. This behavior can be disabled to save disk space by setting
`cache` to `False`.
The dataset is assumed to be present in the data root directory. The training
data is then split into development and validation set. Afterward, a scaler
is fit on the development features if it was not already done previously.
Args:
cache: Whether to cache the data for faster loading in the future.
"""
if not os.path.exists(self._NCMAPSS_ROOT):
_download_ncmapss(self._NCMAPSS_ROOT)
if cache and not self._cache_exists():
self._cache_data()
if not os.path.exists(self._get_scaler_path()):
features, _, _ = self._load_data("dev")
scaler = scaling.fit_scaler(features, MinMaxScaler(self.scaling_range))
scaling.save_scaler(scaler, self._get_scaler_path())

def _get_scaler_path(self):
def _cache_data(self) -> None:
os.makedirs(self._get_cache_path(), exist_ok=True)
features, targets, auxiliary = self._load_raw_data()
features, targets, auxiliary = self._split_by_unit(features, targets, auxiliary)
for i, (f, t, a) in enumerate(zip(features, targets, auxiliary)):
saving.save(str(self._get_cache_path() / f"{i}.npy"), f, t, a)

def _get_scaler_path(self) -> str:
file_name = (
f"scaler_{self.fd}_{self.run_split_dist['dev']}_{self.scaling_range}.pkl"
)
@@ -264,6 +278,21 @@ def load_complete_split(
    def _load_data(
        self, split: str
    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
        if self._cache_exists():
            features, targets, auxiliary = self._load_cached_data(split)
        else:
            features, targets, auxiliary = self._load_original_data(split)

        return features, targets, auxiliary

    def _load_cached_data(self, split: str):
        unit_idx = self.run_split_dist[split]
        save_paths = [str(self._get_cache_path() / f"{i}.npy") for i in unit_idx]
        features, targets, auxiliary = saving.load_multiple(save_paths)

        return features, targets, auxiliary

    def _load_original_data(self, split):
        features, targets, auxiliary = self._load_raw_data()
        features, targets, auxiliary = self._split_by_unit(features, targets, auxiliary)
        features = self._select_units(features, split)
@@ -368,6 +397,12 @@ def _calc_default_window_size(self):

        return max(*max_window_size)

    def _cache_exists(self) -> bool:
        return saving.exists(str(self._get_cache_path() / "0.npy"))

    def _get_cache_path(self):
        return Path(self._NCMAPSS_ROOT) / f"DS{self.fd:02d}"


def _download_ncmapss(data_root):
    with tempfile.TemporaryDirectory() as tmp_path:
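For context, the caching pattern used here boils down to writing one `.npy` file per unit and reading back only the units a split needs. A rough standalone sketch of that idea follows; the helper names and the separate feature/target files are illustrative, not the library's actual `saving` module:

import numpy as np
from pathlib import Path


def cache_units(cache_dir, features, targets):
    # Hypothetical helper: write one pair of .npy files per unit;
    # .npy loads much faster than re-parsing the original h5 files.
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    for i, (f, t) in enumerate(zip(features, targets)):
        np.save(cache_dir / f"{i}_features.npy", f)
        np.save(cache_dir / f"{i}_targets.npy", t)


def load_cached_units(cache_dir, unit_idx):
    # Hypothetical helper: load only the units requested for a split,
    # e.g. the unit indices assigned to "dev" in the run split.
    cache_dir = Path(cache_dir)
    features = [np.load(cache_dir / f"{i}_features.npy") for i in unit_idx]
    targets = [np.load(cache_dir / f"{i}_targets.npy") for i in unit_idx]
    return features, targets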
35 changes: 34 additions & 1 deletion tests/reader/test_ncmapss.py
@@ -1,13 +1,18 @@
import os
import shutil
from pathlib import Path

import numpy as np
import pytest

import rul_datasets
from rul_datasets.reader.ncmapss import NCmapssReader


@pytest.fixture()
def prepared_ncmapss():
    for fd in range(1, 8):
-        NCmapssReader(fd).prepare_data()
        NCmapssReader(fd).prepare_data(cache=False)


@pytest.mark.parametrize("fd", list(range(1, 8)))
@@ -166,3 +171,31 @@ def test_scaling_range_is_tuple(scaling_range):

    assert isinstance(reader.scaling_range, tuple)
    assert reader.scaling_range == (0, 1)


@pytest.mark.needs_data
def test_cache(prepared_ncmapss, tmp_path):
    ncmapss_files = Path(NCmapssReader._NCMAPSS_ROOT).rglob("*")
    linked_files = []
    for file in ncmapss_files:
        if str(file).endswith(".h5"):
            linked_file = tmp_path / file.name
            os.symlink(file, linked_file)
            linked_files.append(linked_file)

    reader = NCmapssReader(1)
    cached_reader = NCmapssReader(1)
    cached_reader._NCMAPSS_ROOT = tmp_path
    cached_reader.prepare_data(cache=True)

    # remove linked files so that only cached files can be used
    for file in linked_files:
        os.remove(file)

    org_features, org_targets = reader.load_split("dev")
    cached_features, cached_targets = cached_reader.load_split("dev")

    for org_feat, cached_feat in zip(org_features, cached_features):
        np.testing.assert_almost_equal(org_feat, cached_feat)
    for org_targ, cached_targ in zip(org_targets, cached_targets):
        np.testing.assert_almost_equal(org_targ, cached_targ)
