Commit
feat: allow saving/loading multiple target arrays per feature array
tilman151 committed May 22, 2024
1 parent c8a580b commit 75abd7b
Showing 2 changed files with 80 additions and 18 deletions.
52 changes: 34 additions & 18 deletions rul_datasets/reader/saving.py
@@ -8,7 +8,7 @@
from tqdm import tqdm # type: ignore


def save(save_path: str, features: np.ndarray, targets: np.ndarray) -> None:
def save(save_path: str, features: np.ndarray, *targets: np.ndarray) -> None:
"""
Save features and targets of a run to .npy files.
@@ -21,15 +21,20 @@ def save(save_path: str, features: np.ndarray, targets: np.ndarray) -> None:
Args:
save_path: The path including file name to save the arrays to.
features: The feature array to save.
targets: The targets array to save.
targets: The target arrays to save.
"""
feature_path = _get_feature_path(save_path)
np.save(feature_path, features, allow_pickle=False)
target_path = _get_target_path(save_path)
np.save(target_path, targets, allow_pickle=False)
if len(targets) == 1: # keeps backward compat for when only one target was allowed
target_path = _get_target_path(save_path, None)
np.save(target_path, targets[0], allow_pickle=False)
else:
for i, target in enumerate(targets):
target_path = _get_target_path(save_path, i)
np.save(target_path, target, allow_pickle=False)
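
As an illustrative sketch of the new variadic signature (the array names and shapes below are invented for the example), passing a single target keeps the old file layout, while passing several targets writes one indexed file per target:

import numpy as np

from rul_datasets.reader import saving

features = np.random.randn(100, 2, 30)  # hypothetical windowed features
rul = np.linspace(100.0, 1.0, 100)      # hypothetical RUL target per window
stage = np.repeat(np.arange(4), 25)     # hypothetical auxiliary target

# One target: writes run_features.npy and run_targets.npy, as before.
saving.save("run", features, rul)

# Two targets: writes run_features.npy, run_targets_0.npy and run_targets_1.npy.
saving.save("run", features, rul, stage)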


def load(save_path: str, memmap: bool = False) -> Tuple[np.ndarray, np.ndarray]:
def load(save_path: str, memmap: bool = False) -> Tuple[np.ndarray, ...]:
"""
Load features and targets of a run from .npy files.
@@ -50,15 +55,24 @@ def load(save_path: str, memmap: bool = False) -> Tuple[np.ndarray, np.ndarray]:
memmap_mode: Optional[Literal["r"]] = "r" if memmap else None
feature_path = _get_feature_path(save_path)
features = np.load(feature_path, memmap_mode, allow_pickle=False)
target_path = _get_target_path(save_path)
targets = np.load(target_path, memmap_mode, allow_pickle=False)
if os.path.exists(_get_target_path(save_path, None)):
# keeps backward compat for when only one target was allowed
target_path = _get_target_path(save_path, None)
targets = [np.load(target_path, memmap_mode, allow_pickle=False)]
else:
i = 0
targets = []
while os.path.exists(_get_target_path(save_path, i)):
target_path = _get_target_path(save_path, i)
targets.append(np.load(target_path, memmap_mode, allow_pickle=False))
i += 1

return features, targets
return features, *targets
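
Reading a run back, sketched under the assumption that it was cached with the two invented targets from the sketch above; the length of the returned tuple depends on how many target files sit next to the feature file:

from rul_datasets.reader import saving

# The number of returned arrays is not fixed, so star-unpacking is the
# safest way to consume the result.
features, *targets = saving.load("run")

# memmap=True opens the arrays read-only via memory mapping instead of
# loading them into RAM up front.
features, rul, stage = saving.load("run", memmap=True)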


def load_multiple(
save_paths: List[str], memmap: bool = False
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
) -> Tuple[List[np.ndarray], ...]:
"""
Load multiple runs with the [load][rul_datasets.reader.saving.load] function.
@@ -72,11 +86,11 @@ def load_multiple(
"""
if save_paths:
runs = [load(save_path, memmap) for save_path in save_paths]
features, targets = [list(x) for x in zip(*runs)]
features, *targets = [list(x) for x in zip(*runs)]
else:
features, targets = [], []
features, targets = [], [[]]

return features, targets
return features, *targets
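
For several cached runs, a sketch under the same two-target assumption; every returned element is a list with one entry per run, features first, then one list per target kind:

from rul_datasets.reader import saving

save_paths = ["run1", "run2"]  # hypothetical cached runs
feature_list, rul_list, stage_list = saving.load_multiple(save_paths)
assert len(feature_list) == len(rul_list) == len(stage_list) == 2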


def exists(save_path: str) -> bool:
@@ -90,24 +104,26 @@ def exists(save_path: str) -> bool:
Returns:
Whether the files exist
"""
feature_path = _get_feature_path(save_path)
target_path = _get_target_path(save_path)
feature_exists = os.path.exists(_get_feature_path(save_path))
target_no_index_exists = os.path.exists(_get_target_path(save_path, None))
target_index_exists = os.path.exists(_get_target_path(save_path, 0))

return os.path.exists(feature_path) and os.path.exists(target_path)
return feature_exists and (target_no_index_exists or target_index_exists)
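
A short sketch of the cache check with an invented path; exists only reports True once the feature file and at least one target file (legacy un-indexed or indexed) are on disk:

import numpy as np

from rul_datasets.reader import saving

features = np.empty((10, 2, 5))
targets = np.empty((10,))

if not saving.exists("cached_run"):  # no cached_run_*.npy files yet
    saving.save("cached_run", features, targets)
assert saving.exists("cached_run")  # cached_run_features.npy and cached_run_targets.npy now exist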


def _get_feature_path(save_path):
def _get_feature_path(save_path: str) -> str:
if save_path.endswith(".npy"):
save_path = save_path[:-4]
feature_path = f"{save_path}_features.npy"

return feature_path


def _get_target_path(save_path):
def _get_target_path(save_path: str, target_index: Optional[int]) -> str:
if save_path.endswith(".npy"):
save_path = save_path[:-4]
target_path = f"{save_path}_targets.npy"
suffix = "" if target_index is None else f"_{target_index}"
target_path = f"{save_path}_targets{suffix}.npy"

return target_path

46 changes: 46 additions & 0 deletions tests/reader/test_saving.py
@@ -23,6 +23,23 @@ def test_save(tmp_path, file_name):
npt.assert_equal(loaded_targets, targets)


@pytest.mark.parametrize("file_name", ["run", "run.npy", "run.foo"])
def test_save_multi_target(tmp_path, file_name):
features = np.empty((10, 2, 5))
targets_0 = np.empty((10,))
targets_1 = np.empty((10,))
save_path = os.path.join(tmp_path, file_name)
saving.save(save_path, features, targets_0, targets_1)

exp_save_path = save_path.replace(".npy", "")
loaded_features = np.load(exp_save_path + "_features.npy")
loaded_targets_0 = np.load(exp_save_path + "_targets_0.npy")
loaded_targets_1 = np.load(exp_save_path + "_targets_1.npy")
npt.assert_equal(loaded_features, features)
npt.assert_equal(loaded_targets_0, targets_0)
npt.assert_equal(loaded_targets_1, targets_1)


@pytest.mark.parametrize("file_name", ["run", "run.npy", "run.foo"])
def test_load(tmp_path, file_name):
features = np.empty((10, 2, 5))
@@ -37,6 +54,23 @@ def test_load(tmp_path, file_name):
npt.assert_equal(loaded_targets, targets)


@pytest.mark.parametrize("file_name", ["run", "run.npy", "run.foo"])
def test_load_multi_target(tmp_path, file_name):
features = np.empty((10, 2, 5))
targets_0 = np.empty((10,))
targets_1 = np.empty((10,))
exp_file_name = file_name.replace(".npy", "")
np.save(os.path.join(tmp_path, f"{exp_file_name}_features.npy"), features)
np.save(os.path.join(tmp_path, f"{exp_file_name}_targets_0.npy"), targets_0)
np.save(os.path.join(tmp_path, f"{exp_file_name}_targets_1.npy"), targets_1)

save_path = os.path.join(tmp_path, file_name)
loaded_features, loaded_targets_0, loaded_targets_1 = saving.load(save_path)
npt.assert_equal(loaded_features, features)
npt.assert_equal(loaded_targets_0, targets_0)
npt.assert_equal(loaded_targets_1, targets_1)


@mock.patch("rul_datasets.reader.saving.load", return_value=(None, None))
@pytest.mark.parametrize("file_names", [["run1", "run2"], []])
def test_load_multiple(mock_load, file_names):
@@ -59,6 +93,18 @@ def test_exists(tmp_path, file_name):
assert saving.exists(save_path)


@pytest.mark.parametrize("file_name", ["run", "run.npy"])
def test_exists_multi_target(tmp_path, file_name):
save_path = os.path.join(tmp_path, file_name)
assert not saving.exists(save_path)

Path(os.path.join(tmp_path, "run_features.npy")).touch()
assert not saving.exists(save_path)

Path(os.path.join(tmp_path, "run_targets_0.npy")).touch()
assert saving.exists(save_path)


@pytest.mark.parametrize("columns", [[0], [0, 1]])
@pytest.mark.parametrize(
"file_name", ["raw_features.csv", "raw_features_corrupted.csv"]
