Skip to content

Commit

Permalink
#181 Add tests for ParamFile and ParamDataFrame
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhad6 committed May 6, 2024
1 parent 60c0102 commit 684dd33
Show file tree
Hide file tree
Showing 11 changed files with 241 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ build:
- pip install poetry==1.8.2
post_install:
# See https://docs.readthedocs.io/en/stable/build-customization.html#install-dependencies-with-poetry
- VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH poetry install -E pydantic --without dev
- VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH poetry install --all-extras --without dev
- VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH poetry run python -m ipykernel install --user

sphinx:
Expand Down
2 changes: 2 additions & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ All of the following can be imported from `paramdb`.
.. autoclass:: ParamStr
.. autoclass:: ParamNone
.. autoclass:: ParamDataclass
.. autoclass:: ParamFile
.. autoclass:: ParamDataFrame
.. autoclass:: ParamList
:no-members:
.. autoclass:: ParamDict
Expand Down
18 changes: 9 additions & 9 deletions paramdb/_param_data/_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from paramdb._param_data._dataclasses import ParamDataclass

try:
import pandas as pd # type: ignore
import pandas as pd

PANDAS_INSTALLED = True
except ImportError:
Expand Down Expand Up @@ -38,7 +38,7 @@ def _load_data(self, path: str) -> str:
initial_data: InitVar[T | None] = None

# pylint: disable-next=arguments-differ
def __post_init__(self, initial_data: T | None = None) -> None: # type: ignore
def __post_init__(self, initial_data: T | None) -> None:
super().__post_init__()
if initial_data is not None:
self.update_data(initial_data)
Expand All @@ -51,25 +51,25 @@ def _save_data(self, path: str, data: T) -> None:
def _load_data(self, path: str) -> T:
"""Load data from the file at the given path."""

def update_data(self, data: T) -> None:
"""Update the data stored within the file represented by this object."""
self._save_data(self.path, data)
self._update_last_updated()

@property
def data(self) -> T:
"""Data stored in the file represented by this object."""
return self._load_data(self.path)

def update_data(self, data: T) -> None:
"""Update the data stored within the file represented by this object."""
self._save_data(self.path, data)
self._update_last_updated()


if PANDAS_INSTALLED:

class ParamDataFrame(ParamFile[pd.DataFrame]):
"""
Subclass of :py:class:`ParamFile`.
Pandas DataFrame stored in a CSV file (with no index). This class will only be
defined if Pandas is installed.
Parameter data Pandas DataFrame, stored in a CSV file (with no index). This
class will only be defined if Pandas is installed.
"""

def _load_data(self, path: str) -> pd.DataFrame:
Expand Down
30 changes: 28 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pylint = "^3.1.0"
black = "^24.4.0"
pytest = "^8.1.1"
pytest-cov = "^5.0.0"
pandas-stubs = "^2.2.1.240316"

[tool.poetry.group.docs.dependencies]
sphinx = "^7.3.7"
Expand Down
7 changes: 2 additions & 5 deletions tests/_param_data/test_dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tests for the paramdb._param_data._dataclasses module."""

from typing import Union
from typing import Union, cast
from copy import deepcopy
import pydantic
import pytest
Expand Down Expand Up @@ -43,10 +43,7 @@ def fixture_param_dataclass_object(
request: pytest.FixtureRequest,
) -> ParamDataclassObject:
"""Parameter data class object."""
param_dataclass_object: ParamDataclassObject = deepcopy(
request.getfixturevalue(request.param)
)
return param_dataclass_object
return cast(ParamDataclassObject, deepcopy(request.getfixturevalue(request.param)))


def test_param_dataclass_direct_instantiation_fails() -> None:
Expand Down
156 changes: 156 additions & 0 deletions tests/_param_data/test_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""Tests for the paramdb._param_data._files module."""

from typing import Any
import os
from pathlib import Path
import numpy as np
import pandas as pd
import pytest
from paramdb import ParamFile, ParamDataFrame
from tests.helpers import ParamTextFile, capture_start_end_times


def _data_frame(rows: int, columns: int) -> pd.DataFrame:
return pd.DataFrame(
np.random.randn(rows, columns), columns=[f"col{n}" for n in range(columns)]
)


def _assert_data_equals(param_file: ParamFile[Any], data: Any) -> None:
if isinstance(param_file, ParamDataFrame):
pd.testing.assert_frame_equal(param_file.data, data)
else:
assert param_file.data == data


@pytest.fixture(name="param_file_path")
def fixture_param_file_path(tmp_path: Path) -> str:
"""Path to the parameter file."""
return str(tmp_path / "param_file_data")


@pytest.fixture(
name="param_file_path_data",
params=[
(ParamDataFrame, _data_frame(10, 10)),
(ParamDataFrame, _data_frame(100, 10)),
(ParamDataFrame, _data_frame(10, 100)),
(ParamTextFile, ""),
(ParamTextFile, "hello"),
(ParamTextFile, repr(np.random.randn(20, 20))),
],
)
def fixture_param_file_and_data(
request: pytest.FixtureRequest, param_file_path: str
) -> tuple[Any, ParamFile[Any]]:
"""Parameter file, path, and data."""
param_file_class: type[ParamFile[Any]] = request.param[0]
data = request.param[1]
return data, param_file_class(param_file_path, data)


@pytest.fixture(name="data")
def fixture_data(param_file_path_data: tuple[Any, ParamFile[Any]]) -> Any:
"""Parameter file data."""
return param_file_path_data[0]


@pytest.fixture(name="param_file")
def fixture_param_file(
param_file_path_data: tuple[Any, ParamFile[Any]]
) -> ParamFile[Any]:
"""Parameter file."""
return param_file_path_data[1]


@pytest.fixture(name="param_file_existing")
def fixture_param_file_existing(
param_file: ParamFile[Any], param_file_path: str
) -> ParamFile[Any]:
"""Parameter file that points to existing data."""
return type(param_file)(param_file_path)


@pytest.fixture(name="different_data")
def fixture_different_data(param_file: ParamFile[Any]) -> Any:
"""Data that is different than what is stored in the parameter file."""
if isinstance(param_file, ParamDataFrame):
return _data_frame(3, 3)
return "different"


def test_param_file_saves_file(
param_file_path: str, param_file: ParamFile[Any], data: pd.DataFrame
) -> None:
"""Parameter file saves data in a file."""
os.remove(param_file_path)
assert not os.path.exists(param_file_path)
type(param_file)(param_file_path, data)
assert os.path.exists(param_file_path)
if isinstance(param_file, ParamDataFrame):
pd.testing.assert_frame_equal(data, pd.read_csv(param_file_path))
else:
with open(param_file_path, "r", encoding="utf-8") as f:
assert data == f.read()


def test_param_file_path(param_file_path: str, param_file: ParamFile[Any]) -> None:
"""Parameter file returns the correct path."""
assert param_file.path == param_file_path


def test_param_file_data(param_file: ParamFile[Any], data: Any) -> None:
"""Parameter file loads the correct data."""
_assert_data_equals(param_file, data)


def test_param_file_existing_data(
param_file: ParamFile[Any],
param_file_existing: ParamFile[Any],
data: Any,
) -> None:
"""Parameter file pointing to existing data can load that data."""
_assert_data_equals(param_file_existing, data)
_assert_data_equals(param_file_existing, param_file.data)


def test_param_file_update_path(
tmp_path: Path,
param_file: ParamFile[Any],
different_data: Any,
) -> None:
"""Parameter file can update its path."""
different_data_frame_path = str(tmp_path / "different_param_file_data")
type(param_file)(different_data_frame_path, different_data)
with capture_start_end_times() as times:
param_file.path = different_data_frame_path
assert times.start <= param_file.last_updated.timestamp() <= times.end
assert param_file.path == different_data_frame_path
_assert_data_equals(param_file, different_data)


def test_param_file_update_data(
param_file: ParamFile[Any], different_data: Any
) -> None:
"""Parameter file can update the data file it points to."""
with capture_start_end_times() as times:
param_file.update_data(different_data)
assert times.start <= param_file.last_updated.timestamp() <= times.end
_assert_data_equals(param_file, different_data)


def test_param_file_frame_existing_update_data(
param_file: ParamFile[Any],
param_file_existing: ParamFile[Any],
different_data: Any,
) -> None:
"""
Parameter file pointing to existing data can update data, which updates the data
loaded by the original Parameter DataFrame.
"""
with capture_start_end_times() as times:
param_file_existing.update_data(different_data)
assert times.start <= param_file_existing.last_updated.timestamp() <= times.end
assert param_file.last_updated.timestamp() <= times.end
_assert_data_equals(param_file_existing, different_data)
_assert_data_equals(param_file, different_data)
8 changes: 6 additions & 2 deletions tests/_param_data/test_param_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from copy import deepcopy
import pytest
from tests.helpers import ComplexParam, Times, capture_start_end_times
from paramdb import ParamData
from paramdb import ParamData, ParamDataFrame
from paramdb._param_data._param_data import get_param_class


Expand Down Expand Up @@ -45,7 +45,11 @@ def test_get_param_class(param_data: ParamData) -> None:
def test_param_data_initial_last_updated(param_data_type: type[ParamData]) -> None:
"""New parameter data objects are initialized with a last updated timestamp."""
with capture_start_end_times() as times:
new_param_data = param_data_type()
new_param_data: ParamData
if issubclass(param_data_type, ParamDataFrame):
new_param_data = param_data_type("")
else:
new_param_data = param_data_type()
assert new_param_data.last_updated is not None
assert times.start < new_param_data.last_updated.timestamp() < times.end

Expand Down
Loading

0 comments on commit 684dd33

Please sign in to comment.