diff --git a/src/nhflodata/get_paths.py b/src/nhflodata/get_paths.py index eaeafa6..c6ee0ff 100644 --- a/src/nhflodata/get_paths.py +++ b/src/nhflodata/get_paths.py @@ -1,8 +1,10 @@ """Functions to get paths to data sets.""" +from __future__ import annotations import logging import os import re +from pathlib import Path import yaml from yaml.loader import SafeLoader @@ -110,6 +112,26 @@ def get_data_dir(): return os.path.join(os.path.dirname(__file__), "data") +def get_latest_data_paths() -> list[Path]: + """ + Get paths to all latest data versions in the repository. + + Returns + ------- + List[Path] + List of Path objects representing all found directories + + Examples + -------- + >>> folders = get_latest_data_paths() + >>> print(folders[0]) + ./data/subfolder1/v1.2.3 + """ + dataset_names = sorted(get_repository_data().keys()) + dataset_paths = [get_abs_data_path(name, version="latest", location="mockup") for name in dataset_names] + return [Path(path) for path in dataset_paths] + + def get_repository_path(): """Return the path to the repository.yaml file from data/repository.yaml.""" # from importlib.resources import files diff --git a/tests/test_forbidden_file_formats.py b/tests/test_forbidden_file_formats.py index de0a22a..e2fc388 100644 --- a/tests/test_forbidden_file_formats.py +++ b/tests/test_forbidden_file_formats.py @@ -6,27 +6,7 @@ import pytest -from nhflodata.get_paths import get_abs_data_path, get_repository_data - - -def get_latest_data_paths() -> list[Path]: - """ - Get paths to all latest data versions in the repository. - - Returns - ------- - List[Path] - List of Path objects representing all found directories - - Examples - -------- - >>> folders = get_latest_data_paths() - >>> print(folders[0]) - ./data/subfolder1/v1.2.3 - """ - dataset_names = sorted(get_repository_data().keys()) - dataset_paths = [get_abs_data_path(name, version="latest", location="mockup") for name in dataset_names] - return [Path(path) for path in dataset_paths] +from nhflodata.get_paths import get_latest_data_paths def find_files_by_extension(folder: Path, extensions: set[str]) -> list[Path]: @@ -62,7 +42,7 @@ def find_files_by_extension(folder: Path, extensions: set[str]) -> list[Path]: normalized_extensions = {(ext if ext.startswith(".") else f".{ext}").lower() for ext in extensions} # Get all files in the folder and its subfolders - all_files = folder.glob("**/*") + all_files = Path(folder).glob("**/*") # Check each file's extension case-insensitively for file_path in all_files: