diff --git a/rul_datasets/reader/saving.py b/rul_datasets/reader/saving.py index 8b15b7b..0bde77e 100644 --- a/rul_datasets/reader/saving.py +++ b/rul_datasets/reader/saving.py @@ -8,7 +8,7 @@ from tqdm import tqdm # type: ignore -def save(save_path: str, features: np.ndarray, targets: np.ndarray) -> None: +def save(save_path: str, features: np.ndarray, *targets: np.ndarray) -> None: """ Save features and targets of a run to .npy files. @@ -21,15 +21,20 @@ def save(save_path: str, features: np.ndarray, targets: np.ndarray) -> None: Args: save_path: The path including file name to save the arrays to. features: The feature array to save. - targets: The targets array to save. + targets: The targets arrays to save. """ feature_path = _get_feature_path(save_path) np.save(feature_path, features, allow_pickle=False) - target_path = _get_target_path(save_path) - np.save(target_path, targets, allow_pickle=False) + if len(targets) == 1: # keeps backward compat for when only one target was allowed + target_path = _get_target_path(save_path, None) + np.save(target_path, targets[0], allow_pickle=False) + else: + for i, target in enumerate(targets): + target_path = _get_target_path(save_path, i) + np.save(target_path, target, allow_pickle=False) -def load(save_path: str, memmap: bool = False) -> Tuple[np.ndarray, np.ndarray]: +def load(save_path: str, memmap: bool = False) -> Tuple[np.ndarray, ...]: """ Load features and targets of a run from .npy files. @@ -50,15 +55,24 @@ def load(save_path: str, memmap: bool = False) -> Tuple[np.ndarray, np.ndarray]: memmap_mode: Optional[Literal["r"]] = "r" if memmap else None feature_path = _get_feature_path(save_path) features = np.load(feature_path, memmap_mode, allow_pickle=False) - target_path = _get_target_path(save_path) - targets = np.load(target_path, memmap_mode, allow_pickle=False) + if os.path.exists(_get_target_path(save_path, None)): + # keeps backward compat for when only one target was allowed + target_path = _get_target_path(save_path, None) + targets = [np.load(target_path, memmap_mode, allow_pickle=False)] + else: + i = 0 + targets = [] + while os.path.exists(_get_target_path(save_path, i)): + target_path = _get_target_path(save_path, i) + targets.append(np.load(target_path, memmap_mode, allow_pickle=False)) + i += 1 - return features, targets + return features, *targets def load_multiple( save_paths: List[str], memmap: bool = False -) -> Tuple[List[np.ndarray], List[np.ndarray]]: +) -> Tuple[List[np.ndarray], ...]: """ Load multiple runs with the [load][rul_datasets.reader.saving.load] function. @@ -72,11 +86,11 @@ def load_multiple( """ if save_paths: runs = [load(save_path, memmap) for save_path in save_paths] - features, targets = [list(x) for x in zip(*runs)] + features, *targets = [list(x) for x in zip(*runs)] else: - features, targets = [], [] + features, targets = [], [[]] - return features, targets + return features, *targets def exists(save_path: str) -> bool: @@ -90,13 +104,14 @@ def exists(save_path: str) -> bool: Returns: Whether the files exist """ - feature_path = _get_feature_path(save_path) - target_path = _get_target_path(save_path) + feature_exists = os.path.exists(_get_feature_path(save_path)) + target_no_index_exists = os.path.exists(_get_target_path(save_path, None)) + target_index_exists = os.path.exists(_get_target_path(save_path, 0)) - return os.path.exists(feature_path) and os.path.exists(target_path) + return feature_exists and (target_no_index_exists or target_index_exists) -def _get_feature_path(save_path): +def _get_feature_path(save_path: str) -> str: if save_path.endswith(".npy"): save_path = save_path[:-4] feature_path = f"{save_path}_features.npy" @@ -104,10 +119,11 @@ def _get_feature_path(save_path): return feature_path -def _get_target_path(save_path): +def _get_target_path(save_path: str, target_index: Optional[int]) -> str: if save_path.endswith(".npy"): save_path = save_path[:-4] - target_path = f"{save_path}_targets.npy" + suffix = "" if target_index is None else f"_{target_index}" + target_path = f"{save_path}_targets{suffix}.npy" return target_path diff --git a/tests/reader/test_saving.py b/tests/reader/test_saving.py index 38df780..5cf439d 100644 --- a/tests/reader/test_saving.py +++ b/tests/reader/test_saving.py @@ -23,6 +23,23 @@ def test_save(tmp_path, file_name): npt.assert_equal(loaded_targets, targets) +@pytest.mark.parametrize("file_name", ["run", "run.npy", "run.foo"]) +def test_save_multi_target(tmp_path, file_name): + features = np.empty((10, 2, 5)) + targets_0 = np.empty((10,)) + targets_1 = np.empty((10,)) + save_path = os.path.join(tmp_path, file_name) + saving.save(save_path, features, targets_0, targets_1) + + exp_save_path = save_path.replace(".npy", "") + loaded_features = np.load(exp_save_path + "_features.npy") + loaded_targets_0 = np.load(exp_save_path + "_targets_0.npy") + loaded_targets_1 = np.load(exp_save_path + "_targets_1.npy") + npt.assert_equal(loaded_features, features) + npt.assert_equal(loaded_targets_0, targets_0) + npt.assert_equal(loaded_targets_1, targets_1) + + @pytest.mark.parametrize("file_name", ["run", "run.npy", "run.foo"]) def test_load(tmp_path, file_name): features = np.empty((10, 2, 5)) @@ -37,6 +54,23 @@ def test_load(tmp_path, file_name): npt.assert_equal(loaded_targets, targets) +@pytest.mark.parametrize("file_name", ["run", "run.npy", "run.foo"]) +def test_load_multi_target(tmp_path, file_name): + features = np.empty((10, 2, 5)) + targets_0 = np.empty((10,)) + targets_1 = np.empty((10,)) + exp_file_name = file_name.replace(".npy", "") + np.save(os.path.join(tmp_path, f"{exp_file_name}_features.npy"), features) + np.save(os.path.join(tmp_path, f"{exp_file_name}_targets_0.npy"), targets_0) + np.save(os.path.join(tmp_path, f"{exp_file_name}_targets_1.npy"), targets_1) + + save_path = os.path.join(tmp_path, file_name) + loaded_features, loaded_targets_0, loaded_targets_1 = saving.load(save_path) + npt.assert_equal(loaded_features, features) + npt.assert_equal(loaded_targets_0, targets_0) + npt.assert_equal(loaded_targets_1, targets_1) + + @mock.patch("rul_datasets.reader.saving.load", return_value=(None, None)) @pytest.mark.parametrize("file_names", [["run1", "run2"], []]) def test_load_multiple(mock_load, file_names): @@ -59,6 +93,18 @@ def test_exists(tmp_path, file_name): assert saving.exists(save_path) +@pytest.mark.parametrize("file_name", ["run", "run.npy"]) +def test_exists_multi_target(tmp_path, file_name): + save_path = os.path.join(tmp_path, file_name) + assert not saving.exists(save_path) + + Path(os.path.join(tmp_path, "run_features.npy")).touch() + assert not saving.exists(save_path) + + Path(os.path.join(tmp_path, "run_targets_0.npy")).touch() + assert saving.exists(save_path) + + @pytest.mark.parametrize("columns", [[0], [0, 1]]) @pytest.mark.parametrize( "file_name", ["raw_features.csv", "raw_features_corrupted.csv"]