Skip to content

Commit

Permalink
feat: better save options for MGTFS + testing
Browse files Browse the repository at this point in the history
  • Loading branch information
CBROWN-ONS committed Jan 19, 2024
1 parent ef49c18 commit 2db6a32
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
25 changes: 16 additions & 9 deletions src/transport_performance/gtfs/multi_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(self, path: Union[str, list, pathlib.Path]) -> None:
def save_feeds(
self,
dir: Union[pathlib.Path, str],
suffix: str = None,
suffix: str = "_new",
file_names: list = None,
overwrite: bool = False,
) -> None:
Expand All @@ -131,7 +131,8 @@ def save_feeds(
file_names : list
A list of save names for the altered GTFS. The list must be the
same length as the number of GTFS instances. Takes priority over
the 'suffix' param.
the 'suffix' param. Names will be used in order of the instances
(access using self.instances()).
overwrite : bool
Whether or not to overwrite the pre-existing saves with matching
paths.
Expand All @@ -146,16 +147,22 @@ def save_feeds(
_type_defence(overwrite, "overwrite", bool)
defence_path = os.path.join(dir, "test.test")
_check_parent_dir_exists(defence_path, "dir", create=True)
save_paths = [
os.path.join(
dir, os.path.splitext(os.path.basename(p))[0] + "_new.zip"
)
for p in self.paths
]
# format save locations
if not file_names:
save_paths = [
os.path.join(
dir,
os.path.splitext(os.path.basename(p))[0] + f"{suffix}.zip",
)
for p in self.paths
]
else:
save_paths = [os.path.join(dir, p) for p in file_names]
# save gtfs
progress = tqdm(zip(save_paths, self.instances), total=len(self.paths))
for path, inst in progress:
progress.set_description(f"Saving at {path}")
inst.save(path)
inst.save(path, overwrite=overwrite)
return None

def clean_feeds(self, clean_kwargs: Union[dict, None] = None) -> None:
Expand Down
13 changes: 13 additions & 0 deletions tests/gtfs/test_multi_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,19 @@ def test_save_feeds(self, multi_gtfs_paths, tmp_path):
assert np.array_equal(
np.sort(expected_paths), np.sort(found_paths)
), "GtfsInstances not saved as expected"
# test saves with filenames
file_name_dir = os.path.join(tmp_path, "filenames")
gtfs.save_feeds(
dir=file_name_dir, file_names=["test1.zip", "test2.zip"]
)
assert len(os.listdir(file_name_dir)) == 2, "Not enough files saved"
found_paths = [
os.path.basename(fpath)
for fpath in glob.glob(file_name_dir + "/*.zip")
]
assert np.array_equal(
found_paths, ["test1.zip", "test2.zip"]
), "File names not saved correctly"

def test_clean_feeds_defences(self, multi_gtfs_fixture):
"""Defensive tests for .clean_feeds()."""
Expand Down

0 comments on commit 2db6a32

Please sign in to comment.