From 684dd332947f0c21b89d634d64d457df91ce1dc8 Mon Sep 17 00:00:00 2001 From: Alex Hadley Date: Mon, 6 May 2024 15:28:09 -0700 Subject: [PATCH] #181 Add tests for ParamFile and ParamDataFrame --- .readthedocs.yaml | 2 +- docs/api-reference.md | 2 + paramdb/_param_data/_files.py | 18 +-- poetry.lock | 30 ++++- pyproject.toml | 1 + tests/_param_data/test_dataclasses.py | 7 +- tests/_param_data/test_files.py | 156 ++++++++++++++++++++++++++ tests/_param_data/test_param_data.py | 8 +- tests/conftest.py | 16 ++- tests/helpers.py | 19 +++- tests/test_database.py | 4 + 11 files changed, 241 insertions(+), 22 deletions(-) create mode 100644 tests/_param_data/test_files.py diff --git a/.readthedocs.yaml b/.readthedocs.yaml index bb6f5b0..f9ee352 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -9,7 +9,7 @@ build: - pip install poetry==1.8.2 post_install: # See https://docs.readthedocs.io/en/stable/build-customization.html#install-dependencies-with-poetry - - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH poetry install -E pydantic --without dev + - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH poetry install --all-extras --without dev - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH poetry run python -m ipykernel install --user sphinx: diff --git a/docs/api-reference.md b/docs/api-reference.md index 3f21e4c..edbf045 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -16,6 +16,8 @@ All of the following can be imported from `paramdb`. .. autoclass:: ParamStr .. autoclass:: ParamNone .. autoclass:: ParamDataclass +.. autoclass:: ParamFile +.. autoclass:: ParamDataFrame .. autoclass:: ParamList :no-members: .. autoclass:: ParamDict diff --git a/paramdb/_param_data/_files.py b/paramdb/_param_data/_files.py index 4ccb801..5e2633f 100644 --- a/paramdb/_param_data/_files.py +++ b/paramdb/_param_data/_files.py @@ -7,7 +7,7 @@ from paramdb._param_data._dataclasses import ParamDataclass try: - import pandas as pd # type: ignore + import pandas as pd PANDAS_INSTALLED = True except ImportError: @@ -38,7 +38,7 @@ def _load_data(self, path: str) -> str: initial_data: InitVar[T | None] = None # pylint: disable-next=arguments-differ - def __post_init__(self, initial_data: T | None = None) -> None: # type: ignore + def __post_init__(self, initial_data: T | None) -> None: super().__post_init__() if initial_data is not None: self.update_data(initial_data) @@ -51,16 +51,16 @@ def _save_data(self, path: str, data: T) -> None: def _load_data(self, path: str) -> T: """Load data from the file at the given path.""" - def update_data(self, data: T) -> None: - """Update the data stored within the file represented by this object.""" - self._save_data(self.path, data) - self._update_last_updated() - @property def data(self) -> T: """Data stored in the file represented by this object.""" return self._load_data(self.path) + def update_data(self, data: T) -> None: + """Update the data stored within the file represented by this object.""" + self._save_data(self.path, data) + self._update_last_updated() + if PANDAS_INSTALLED: @@ -68,8 +68,8 @@ class ParamDataFrame(ParamFile[pd.DataFrame]): """ Subclass of :py:class:`ParamFile`. - Pandas DataFrame stored in a CSV file (with no index). This class will only be - defined if Pandas is installed. + Parameter data Pandas DataFrame, stored in a CSV file (with no index). This + class will only be defined if Pandas is installed. """ def _load_data(self, path: str) -> pd.DataFrame: diff --git a/poetry.lock b/poetry.lock index ac7b31f..36f6298 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1427,7 +1427,7 @@ files = [ name = "numpy" version = "1.26.4" description = "Fundamental package for array computing in Python" -optional = true +optional = false python-versions = ">=3.9" files = [ {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"}, @@ -1552,6 +1552,21 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "pandas-stubs" +version = "2.2.1.240316" +description = "Type annotations for pandas" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pandas_stubs-2.2.1.240316-py3-none-any.whl", hash = "sha256:0126a26451a37cb893ea62357ca87ba3d181bd999ec8ba2ca5602e20207d6682"}, + {file = "pandas_stubs-2.2.1.240316.tar.gz", hash = "sha256:236a4f812fb6b1922e9607ff09e427f6d8540c421c9e5a40e3e4ddf7adac7f05"}, +] + +[package.dependencies] +numpy = {version = ">=1.26.0", markers = "python_version < \"3.13\""} +types-pytz = ">=2022.1.1" + [[package]] name = "pandocfilters" version = "1.5.1" @@ -2687,6 +2702,17 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.2)", "pytest-mock", "pytest-mypy-testing"] +[[package]] +name = "types-pytz" +version = "2024.1.0.20240417" +description = "Typing stubs for pytz" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-pytz-2024.1.0.20240417.tar.gz", hash = "sha256:6810c8a1f68f21fdf0f4f374a432487c77645a0ac0b31de4bf4690cf21ad3981"}, + {file = "types_pytz-2024.1.0.20240417-py3-none-any.whl", hash = "sha256:8335d443310e2db7b74e007414e74c4f53b67452c0cb0d228ca359ccfba59659"}, +] + [[package]] name = "typing-extensions" version = "4.11.0" @@ -2844,4 +2870,4 @@ pydantic = ["eval-type-backport", "pydantic"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "3b4eabb8c544880ce3befdb6bd42922b2eab3e82f45e070c6701769fa07dc84d" +content-hash = "7401f0468e9e6a04ccbb540ea2d037b88902056080c6085dc5dde27b879ba724" diff --git a/pyproject.toml b/pyproject.toml index a3824fa..b9fe04f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ pylint = "^3.1.0" black = "^24.4.0" pytest = "^8.1.1" pytest-cov = "^5.0.0" +pandas-stubs = "^2.2.1.240316" [tool.poetry.group.docs.dependencies] sphinx = "^7.3.7" diff --git a/tests/_param_data/test_dataclasses.py b/tests/_param_data/test_dataclasses.py index 8ecf976..9b3c781 100644 --- a/tests/_param_data/test_dataclasses.py +++ b/tests/_param_data/test_dataclasses.py @@ -1,6 +1,6 @@ """Tests for the paramdb._param_data._dataclasses module.""" -from typing import Union +from typing import Union, cast from copy import deepcopy import pydantic import pytest @@ -43,10 +43,7 @@ def fixture_param_dataclass_object( request: pytest.FixtureRequest, ) -> ParamDataclassObject: """Parameter data class object.""" - param_dataclass_object: ParamDataclassObject = deepcopy( - request.getfixturevalue(request.param) - ) - return param_dataclass_object + return cast(ParamDataclassObject, deepcopy(request.getfixturevalue(request.param))) def test_param_dataclass_direct_instantiation_fails() -> None: diff --git a/tests/_param_data/test_files.py b/tests/_param_data/test_files.py new file mode 100644 index 0000000..33ec86c --- /dev/null +++ b/tests/_param_data/test_files.py @@ -0,0 +1,156 @@ +"""Tests for the paramdb._param_data._files module.""" + +from typing import Any +import os +from pathlib import Path +import numpy as np +import pandas as pd +import pytest +from paramdb import ParamFile, ParamDataFrame +from tests.helpers import ParamTextFile, capture_start_end_times + + +def _data_frame(rows: int, columns: int) -> pd.DataFrame: + return pd.DataFrame( + np.random.randn(rows, columns), columns=[f"col{n}" for n in range(columns)] + ) + + +def _assert_data_equals(param_file: ParamFile[Any], data: Any) -> None: + if isinstance(param_file, ParamDataFrame): + pd.testing.assert_frame_equal(param_file.data, data) + else: + assert param_file.data == data + + +@pytest.fixture(name="param_file_path") +def fixture_param_file_path(tmp_path: Path) -> str: + """Path to the parameter file.""" + return str(tmp_path / "param_file_data") + + +@pytest.fixture( + name="param_file_path_data", + params=[ + (ParamDataFrame, _data_frame(10, 10)), + (ParamDataFrame, _data_frame(100, 10)), + (ParamDataFrame, _data_frame(10, 100)), + (ParamTextFile, ""), + (ParamTextFile, "hello"), + (ParamTextFile, repr(np.random.randn(20, 20))), + ], +) +def fixture_param_file_and_data( + request: pytest.FixtureRequest, param_file_path: str +) -> tuple[Any, ParamFile[Any]]: + """Parameter file, path, and data.""" + param_file_class: type[ParamFile[Any]] = request.param[0] + data = request.param[1] + return data, param_file_class(param_file_path, data) + + +@pytest.fixture(name="data") +def fixture_data(param_file_path_data: tuple[Any, ParamFile[Any]]) -> Any: + """Parameter file data.""" + return param_file_path_data[0] + + +@pytest.fixture(name="param_file") +def fixture_param_file( + param_file_path_data: tuple[Any, ParamFile[Any]] +) -> ParamFile[Any]: + """Parameter file.""" + return param_file_path_data[1] + + +@pytest.fixture(name="param_file_existing") +def fixture_param_file_existing( + param_file: ParamFile[Any], param_file_path: str +) -> ParamFile[Any]: + """Parameter file that points to existing data.""" + return type(param_file)(param_file_path) + + +@pytest.fixture(name="different_data") +def fixture_different_data(param_file: ParamFile[Any]) -> Any: + """Data that is different than what is stored in the parameter file.""" + if isinstance(param_file, ParamDataFrame): + return _data_frame(3, 3) + return "different" + + +def test_param_file_saves_file( + param_file_path: str, param_file: ParamFile[Any], data: pd.DataFrame +) -> None: + """Parameter file saves data in a file.""" + os.remove(param_file_path) + assert not os.path.exists(param_file_path) + type(param_file)(param_file_path, data) + assert os.path.exists(param_file_path) + if isinstance(param_file, ParamDataFrame): + pd.testing.assert_frame_equal(data, pd.read_csv(param_file_path)) + else: + with open(param_file_path, "r", encoding="utf-8") as f: + assert data == f.read() + + +def test_param_file_path(param_file_path: str, param_file: ParamFile[Any]) -> None: + """Parameter file returns the correct path.""" + assert param_file.path == param_file_path + + +def test_param_file_data(param_file: ParamFile[Any], data: Any) -> None: + """Parameter file loads the correct data.""" + _assert_data_equals(param_file, data) + + +def test_param_file_existing_data( + param_file: ParamFile[Any], + param_file_existing: ParamFile[Any], + data: Any, +) -> None: + """Parameter file pointing to existing data can load that data.""" + _assert_data_equals(param_file_existing, data) + _assert_data_equals(param_file_existing, param_file.data) + + +def test_param_file_update_path( + tmp_path: Path, + param_file: ParamFile[Any], + different_data: Any, +) -> None: + """Parameter file can update its path.""" + different_data_frame_path = str(tmp_path / "different_param_file_data") + type(param_file)(different_data_frame_path, different_data) + with capture_start_end_times() as times: + param_file.path = different_data_frame_path + assert times.start <= param_file.last_updated.timestamp() <= times.end + assert param_file.path == different_data_frame_path + _assert_data_equals(param_file, different_data) + + +def test_param_file_update_data( + param_file: ParamFile[Any], different_data: Any +) -> None: + """Parameter file can update the data file it points to.""" + with capture_start_end_times() as times: + param_file.update_data(different_data) + assert times.start <= param_file.last_updated.timestamp() <= times.end + _assert_data_equals(param_file, different_data) + + +def test_param_file_frame_existing_update_data( + param_file: ParamFile[Any], + param_file_existing: ParamFile[Any], + different_data: Any, +) -> None: + """ + Parameter file pointing to existing data can update data, which updates the data + loaded by the original Parameter DataFrame. + """ + with capture_start_end_times() as times: + param_file_existing.update_data(different_data) + assert times.start <= param_file_existing.last_updated.timestamp() <= times.end + assert param_file.last_updated.timestamp() <= times.end + _assert_data_equals(param_file_existing, different_data) + _assert_data_equals(param_file, different_data) diff --git a/tests/_param_data/test_param_data.py b/tests/_param_data/test_param_data.py index abe63b8..162556c 100644 --- a/tests/_param_data/test_param_data.py +++ b/tests/_param_data/test_param_data.py @@ -4,7 +4,7 @@ from copy import deepcopy import pytest from tests.helpers import ComplexParam, Times, capture_start_end_times -from paramdb import ParamData +from paramdb import ParamData, ParamDataFrame from paramdb._param_data._param_data import get_param_class @@ -45,7 +45,11 @@ def test_get_param_class(param_data: ParamData) -> None: def test_param_data_initial_last_updated(param_data_type: type[ParamData]) -> None: """New parameter data objects are initialized with a last updated timestamp.""" with capture_start_end_times() as times: - new_param_data = param_data_type() + new_param_data: ParamData + if issubclass(param_data_type, ParamDataFrame): + new_param_data = param_data_type("") + else: + new_param_data = param_data_type() assert new_param_data.last_updated is not None assert times.start < new_param_data.last_updated.timestamp() < times.end diff --git a/tests/conftest.py b/tests/conftest.py index 5f463ce..19892eb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,7 @@ ParamBool, ParamStr, ParamNone, + ParamDataFrame, ParamList, ParamDict, ) @@ -71,6 +72,12 @@ def fixture_param_none() -> ParamNone: return ParamNone() +@pytest.fixture(name="param_data_frame") +def fixture_param_data_frame(string: str) -> ParamDataFrame: + """Parameter DataFrame.""" + return ParamDataFrame(f"{string}.csv") + + @pytest.fixture(name="empty_param") def fixture_empty_param() -> EmptyParam: """Empty parameter data class object.""" @@ -129,6 +136,7 @@ def fixture_complex_param(number: float, string: str) -> ComplexParam: return ComplexParam( number=number, string=string, + param_data_frame=ParamDataFrame(string), empty_param=EmptyParam(), simple_param=SimpleParam(), no_type_validation_param=NoTypeValidationParam(), @@ -153,6 +161,7 @@ def fixture_param_list_contents(number: float, string: str) -> list[Any]: ParamBool(), ParamStr(string), ParamNone(), + ParamDataFrame(string), EmptyParam(), SimpleParam(), NoTypeValidationParam(), @@ -176,6 +185,7 @@ def fixture_param_dict_contents( param_bool: ParamBool, param_str: ParamStr, param_none: ParamNone, + param_data_frame: ParamDataFrame, empty_param: EmptyParam, simple_param: SimpleParam, no_type_validation_param: NoTypeValidationParam, @@ -194,6 +204,7 @@ def fixture_param_dict_contents( "param_bool": deepcopy(param_bool), "param_str": deepcopy(param_str), "param_none": deepcopy(param_none), + "param_data_frame": deepcopy(param_data_frame), "empty_param": deepcopy(empty_param), "simple_param": deepcopy(simple_param), "no_type_validation_param": deepcopy(no_type_validation_param), @@ -239,6 +250,7 @@ def fixture_param_dict(param_dict_contents: dict[str, Any]) -> ParamDict[Any]: "param_bool", "param_str", "param_none", + "param_data_frame", "empty_param", "simple_param", "no_type_validation_param", @@ -274,6 +286,8 @@ def fixture_updated_param_data_and_times( updated_param_data = type(updated_param_data)(updated_param_data.value) elif isinstance(updated_param_data, (ParamNone, EmptyParam)): updated_param_data = type(updated_param_data)() + elif isinstance(updated_param_data, ParamDataFrame): + updated_param_data.path = "" elif isinstance(updated_param_data, SimpleParam): updated_param_data.number += 1 elif isinstance(updated_param_data, SubclassParam): @@ -285,7 +299,7 @@ def fixture_updated_param_data_and_times( if len(updated_param_data) == 0: updated_param_data.append(number) else: - updated_param_data[8].number += 1 + updated_param_data[9].number += 1 elif isinstance(updated_param_data, ParamDict): if len(updated_param_data) == 0: updated_param_data["number"] = number diff --git a/tests/helpers.py b/tests/helpers.py index 0768263..28efe01 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -20,6 +20,8 @@ ParamStr, ParamNone, ParamDataclass, + ParamFile, + ParamDataFrame, ParamList, ParamDict, ) @@ -28,6 +30,18 @@ DEFAULT_STRING = "test" +class ParamTextFile(ParamFile[str]): + """Parameter text file, created using ``ParamFile``.""" + + def _save_data(self, path: str, data: str) -> None: + with open(path, "w", encoding="utf-8") as f: + f.write(data) + + def _load_data(self, path: str) -> str: + with open(path, "r", encoding="utf-8") as f: + return f.read() + + class EmptyParam(ParamDataclass): """Empty parameter data class""" @@ -78,13 +92,14 @@ class ComplexParam(ParamDataclass): number: float = DEFAULT_NUMBER number_init_false: float = field(init=False, default=DEFAULT_NUMBER) string: str = DEFAULT_STRING + list: list[Any] = field(default_factory=list) + dict: dict[str, Any] = field(default_factory=dict) param_int: ParamInt = ParamInt(123) param_float: ParamFloat = ParamFloat(DEFAULT_NUMBER) param_bool: ParamBool = ParamBool(False) param_str: ParamStr = ParamStr(DEFAULT_STRING) param_none: ParamNone = ParamNone() - list: list[Any] = field(default_factory=list) - dict: dict[str, Any] = field(default_factory=dict) + param_data_frame: ParamDataFrame | None = None empty_param: EmptyParam | None = None simple_param: SimpleParam | None = None no_type_validation_param: NoTypeValidationParam | None = None diff --git a/tests/test_database.py b/tests/test_database.py index 51e88c0..4b61ea3 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -33,6 +33,7 @@ ParamStr, ParamNone, ParamDataclass, + ParamDataFrame, ParamList, ParamDict, ParamDB, @@ -263,6 +264,7 @@ def test_commit_and_load_complex( param_bool: ParamBool, param_str: ParamStr, param_none: ParamNone, + param_data_frame: ParamDataFrame, empty_param: EmptyParam, simple_param: SimpleParam, subclass_param: SubclassParam, @@ -284,6 +286,7 @@ class Root(ParamDataclass): param_bool: ParamBool param_str: ParamStr param_none: ParamNone + param_data_frame: ParamDataFrame empty_param: EmptyParam simple_param: SimpleParam subclass_param: SubclassParam @@ -308,6 +311,7 @@ class Root(ParamDataclass): param_bool=param_bool, param_str=param_str, param_none=param_none, + param_data_frame=param_data_frame, empty_param=empty_param, simple_param=simple_param, subclass_param=subclass_param,