diff --git a/CHANGELOG.md b/CHANGELOG.md index b97a9cf..b7237cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - If Pydantic is installed, parameter data classes automatically have Pydantic type validation enabled. +- Parameter primitives classes: `ParamInt`, `ParamFloat`, `ParamBool`, `ParamStr`, and + `ParamNone`. ### Changed diff --git a/docs/api-reference.md b/docs/api-reference.md index 93b668b..3f21e4c 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -10,11 +10,20 @@ All of the following can be imported from `paramdb`. ```{eval-rst} .. autoclass:: ParamData +.. autoclass:: ParamInt +.. autoclass:: ParamFloat +.. autoclass:: ParamBool +.. autoclass:: ParamStr +.. autoclass:: ParamNone .. autoclass:: ParamDataclass .. autoclass:: ParamList + :no-members: .. autoclass:: ParamDict + :no-members: .. autoclass:: ParentType + :no-members: .. autoclass:: RootType + :no-members: ``` ## Database diff --git a/docs/conf.py b/docs/conf.py index c3e4b9c..278964a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -25,7 +25,11 @@ # Autodoc options # See https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#configuration autodoc_default_options = {"members": True, "member-order": "bysource"} -autodoc_inherit_docstrings = False +autodoc_type_aliases = { + "ConvertibleToInt": "ConvertibleToInt", + "ConvertibleToFloat": "ConvertibleToFloat", +} +# autodoc_inherit_docstrings = False add_module_names = False diff --git a/docs/parameter-data.md b/docs/parameter-data.md index 47a93c5..ba5c4f0 100644 --- a/docs/parameter-data.md +++ b/docs/parameter-data.md @@ -19,7 +19,7 @@ defines some core functionality for this data, including the {py:class}`ParamData` are automatically registered with ParamDB so that they can be loaded to and from JSON, which is how they are stored in the database. -All of the classes described on this page are subclasses of {py:class}`ParamData`. +All of the "Param" classes described on this page are subclasses of {py:class}`ParamData`. ```{important} Any data that is going to be stored in a ParamDB database must be a JSON serializable @@ -28,6 +28,46 @@ type (`str`, `int`, `float`, `bool`, `None`, `dict`, or `list`), a [`datetime`], a `TypeError` will be raised when they are committed to the database. ``` +## Primitives + +Primitives are the building blocks of parameter data. While builtin primitive types can +be used in a ParamDB (`int`, `float`, `str`, `bool`, and `None`), they will not store a +{py:class}`~ParamData.last_updated` time and will not have {py:class}`~ParamData.parent` +or {py:class}`~ParamData.root` properties. When these features are desired, we can wrap +primitive values in the following types: + +- {py:class}`ParamInt` for integers +- {py:class}`ParamFloat` for float +- {py:class}`ParamBool` for booleans +- {py:class}`ParamStr` for strings +- {py:class}`ParamNone` for `None` + +For example: + +```{jupyter-execute} +from paramdb import ParamInt + +param_int = ParamInt(123) +param_int +``` + +```{jupyter-execute} +print(param_int.last_updated) +``` + +````{tip} +Methods from the builtin primitive types work on parameter primitives, with the caveat +that they return the builtin type. For example: + +```{jupyter-execute} +param_int + 123 +``` + +```{jupyter-execute} +type(param_int + 123) +``` +```` + ## Data Classes A parameter data class is defined from the base class {py:class}`ParamDataclass`. This @@ -37,18 +77,18 @@ function is generated. An example of a defining a custom parameter Data Class is below. ```{jupyter-execute} -from paramdb import ParamDataclass +from paramdb import ParamFloat, ParamDataclass class CustomParam(ParamDataclass): - value: float + value: ParamFloat -custom_param = CustomParam(value=1.23) +custom_param = CustomParam(value=ParamFloat(1.23)) ``` These properties can then be accessed and updated. ```{jupyter-execute} -custom_param.value += 0.004 +custom_param.value = ParamFloat(1.234) custom_param.value ``` @@ -85,13 +125,13 @@ decorator. For example: ```{jupyter-execute} class ParamWithProperty(ParamDataclass): - value: int + value: ParamInt @property def value_cubed(self) -> int: return self.value ** 3 -param_with_property = ParamWithProperty(value=16) +param_with_property = ParamWithProperty(value=ParamInt(16)) param_with_property.value_cubed ``` @@ -115,15 +155,15 @@ Parameter data track when any of their properties were last updated, and this va accessed by the read-only {py:attr}`~ParamData.last_updated` property. For example: ```{jupyter-execute} -custom_param.last_updated +print(custom_param.last_updated) ``` ```{jupyter-execute} import time time.sleep(1) -custom_param.value += 1 -custom_param.last_updated +custom_param.value = ParamFloat(4.56) +print(custom_param.last_updated) ``` Parameter dataclasses can also be nested, in which case the @@ -136,14 +176,14 @@ class NestedParam(ParamDataclass): value: float child_param: CustomParam -nested_param = NestedParam(value=1.23, child_param=CustomParam(value=4.56)) -nested_param.last_updated +nested_param = NestedParam(value=1.23, child_param=CustomParam(value=ParamFloat(4.56))) +print(nested_param.last_updated) ``` ```{jupyter-execute} time.sleep(1) -nested_param.child_param.value += 1 -nested_param.last_updated +nested_param.child_param.value = ParamFloat(2) +print(nested_param.last_updated) ``` You can access the parent of any parameter data using the {py:attr}`ParamData.parent` @@ -207,18 +247,18 @@ properly. For example: ```{jupyter-execute} from paramdb import ParamList -param_list = ParamList([CustomParam(value=1), CustomParam(value=2), CustomParam(value=3)]) +param_list = ParamList([ParamInt(1), ParamInt(2), ParamInt(3)]) param_list[1].parent is param_list ``` ```{jupyter-execute} -param_list.last_updated +print(param_list.last_updated) ``` ```{jupyter-execute} time.sleep(1) -param_list[1].value += 1 -param_list.last_updated +param_list[1] = ParamInt(4) +print(param_list.last_updated) ``` ### Parameter Dictionaries @@ -231,28 +271,24 @@ example: ```{jupyter-execute} from paramdb import ParamDict -param_dict = ParamDict( - p1=CustomParam(value=1.23), - p2=CustomParam(value=4.56), - p3=CustomParam(value=7.89), -) +param_dict = ParamDict(p1=ParamFloat(1.23), p2=ParamFloat(4.56), p3=ParamFloat(7.89)) param_dict.p2.root == param_dict ``` ```{jupyter-execute} -param_list.last_updated +print(param_dict.last_updated) ``` ```{jupyter-execute} time.sleep(1) -param_list[1].value += 1 -param_list.last_updated +param_dict.p2 = ParamFloat(0) +print(param_dict.last_updated) ``` Parameter collections can also be subclassed to provide custom functionality. For example: ```{jupyter-execute} -class CustomDict(ParamDict[CustomParam]): +class CustomDict(ParamDict[ParamFloat]): @property def total(self) -> float: return sum(param.value for param in self.values()) diff --git a/paramdb/__init__.py b/paramdb/__init__.py index 7232745..56f0a8d 100644 --- a/paramdb/__init__.py +++ b/paramdb/__init__.py @@ -1,6 +1,13 @@ """Python package for storing and retrieving experiment parameters.""" from paramdb._param_data._param_data import ParamData +from paramdb._param_data._primitives import ( + ParamInt, + ParamBool, + ParamFloat, + ParamStr, + ParamNone, +) from paramdb._param_data._dataclasses import ParamDataclass from paramdb._param_data._collections import ParamList, ParamDict from paramdb._param_data._type_mixins import ParentType, RootType @@ -8,6 +15,11 @@ __all__ = [ "ParamData", + "ParamInt", + "ParamBool", + "ParamFloat", + "ParamStr", + "ParamNone", "ParamDataclass", "ParamList", "ParamDict", diff --git a/paramdb/_database.py b/paramdb/_database.py index f2de3d0..42af752 100644 --- a/paramdb/_database.py +++ b/paramdb/_database.py @@ -24,7 +24,7 @@ _ASTROPY_INSTALLED = False T = TypeVar("T") -SelectT = TypeVar("SelectT", bound=Select[Any]) +_SelectT = TypeVar("_SelectT", bound=Select[Any]) CLASS_NAME_KEY = "__type" """ @@ -51,27 +51,6 @@ def _full_class_name(cls: type) -> str: return f"{cls.__module__}.{cls.__name__}" -def _to_dict(obj: Any) -> Any: - """ - Convert the given object into a dictionary to be passed to ``json.dumps()``. - - Note that objects within the dictionary do not need to be JSON serializable, - since they will be recursively processed by ``json.dumps()``. - """ - class_full_name = _full_class_name(type(obj)) - class_full_name_dict = {CLASS_NAME_KEY: class_full_name} - if isinstance(obj, datetime): - return class_full_name_dict | {"isoformat": obj.isoformat()} - if _ASTROPY_INSTALLED and isinstance(obj, Quantity): - return class_full_name_dict | {"value": obj.value, "unit": str(obj.unit)} - if isinstance(obj, ParamData): - return {CLASS_NAME_KEY: type(obj).__name__} | obj.to_dict() - raise TypeError( - f"'{class_full_name}' object {repr(obj)} is not JSON serializable, so the" - " commit failed" - ) - - def _from_dict(json_dict: dict[str, Any]) -> Any: """ If the given dictionary created by ``json.loads()`` has the key ``CLASS_NAME_KEY``, @@ -96,9 +75,36 @@ def _from_dict(json_dict: dict[str, Any]) -> Any: ) +def _preprocess_json(obj: Any) -> Any: + """ + Preprocess the given object and its children into a JSON-serializable format. + Compared with ``json.dumps()``, this function can define custom logic for dealing + with subclasses of ``int``, ``float``, and ``str``. + """ + if isinstance(obj, ParamData): + return {CLASS_NAME_KEY: type(obj).__name__} | _preprocess_json(obj.to_dict()) + if isinstance(obj, (int, float, bool, str)) or obj is None: + return obj + if isinstance(obj, (list, tuple)): + return [_preprocess_json(value) for value in obj] + if isinstance(obj, dict): + return {key: _preprocess_json(value) for key, value in obj.items()} + class_full_name = _full_class_name(type(obj)) + class_full_name_dict = {CLASS_NAME_KEY: class_full_name} + if isinstance(obj, datetime): + return class_full_name_dict | {"isoformat": obj.isoformat()} + if _ASTROPY_INSTALLED and isinstance(obj, Quantity): + return class_full_name_dict | {"value": obj.value, "unit": str(obj.unit)} + raise TypeError( + f"'{class_full_name}' object {repr(obj)} is not JSON serializable, so the" + " commit failed" + ) + + def _encode(obj: Any) -> bytes: """Encode the given object into bytes that will be stored in the database.""" - return _compress(json.dumps(obj, default=_to_dict)) + # pylint: disable=no-member + return _compress(json.dumps(_preprocess_json(obj))) def _decode(data: bytes, load_classes: bool) -> Any: @@ -194,7 +200,7 @@ def _index_error(self, commit_id: int | None) -> IndexError: else f"commit {commit_id} does not exist in database" f" '{self._path}'" ) - def _select_commit(self, select_stmt: SelectT, commit_id: int | None) -> SelectT: + def _select_commit(self, select_stmt: _SelectT, commit_id: int | None) -> _SelectT: """ Modify the given ``_Snapshot`` select statement to return the commit specified by the given commit ID, or the latest commit if the commit ID is None. @@ -206,8 +212,8 @@ def _select_commit(self, select_stmt: SelectT, commit_id: int | None) -> SelectT ) def _select_slice( - self, select_stmt: SelectT, start: int | None, end: int | None - ) -> SelectT: + self, select_stmt: _SelectT, start: int | None, end: int | None + ) -> _SelectT: """ Modify the given Snapshot select statement to sort by commit ID and return the slice specified by the given start and end indices. diff --git a/paramdb/_param_data/_collections.py b/paramdb/_param_data/_collections.py index c7439ba..46142cc 100644 --- a/paramdb/_param_data/_collections.py +++ b/paramdb/_param_data/_collections.py @@ -18,14 +18,14 @@ from paramdb._param_data._param_data import ParamData T = TypeVar("T") -CollectionT = TypeVar("CollectionT", bound=Collection[Any]) +_CollectionT = TypeVar("_CollectionT", bound=Collection[Any]) # pylint: disable-next=abstract-method -class _ParamCollection(ParamData, Generic[CollectionT]): +class _ParamCollection(ParamData, Generic[_CollectionT]): """Base class for parameter collections.""" - _contents: CollectionT + _contents: _CollectionT def __len__(self) -> int: return len(self._contents) @@ -41,12 +41,12 @@ def __eq__(self, other: Any) -> bool: def __repr__(self) -> str: return f"{type(self).__name__}({self._contents})" - def _to_json(self) -> CollectionT: + def _to_json(self) -> _CollectionT: return self._contents @classmethod @abstractmethod - def _from_json(cls, json_data: CollectionT) -> Self: ... + def _from_json(cls, json_data: _CollectionT) -> Self: ... class ParamList(_ParamCollection[list[T]], MutableSequence[T], Generic[T]): diff --git a/paramdb/_param_data/_dataclasses.py b/paramdb/_param_data/_dataclasses.py index 3dd248d..717f7d3 100644 --- a/paramdb/_param_data/_dataclasses.py +++ b/paramdb/_param_data/_dataclasses.py @@ -1,4 +1,4 @@ -"""Base classes for parameter dataclasses.""" +"""Base class for parameter dataclasses.""" from __future__ import annotations from typing import Any @@ -73,7 +73,7 @@ def __init_subclass__( if pydantic_config is not None: # Merge new Pydantic config with the old one cls.__pydantic_config = cls.__pydantic_config | pydantic_config - cls.__base_setattr = object.__setattr__ # type: ignore + cls.__base_setattr = super().__setattr__ # type: ignore if _PYDANTIC_INSTALLED and cls.__type_validation: # Transform the class into a Pydantic data class, with custom handling for # validate_assignment diff --git a/paramdb/_param_data/_param_data.py b/paramdb/_param_data/_param_data.py index 843fbf0..3d50546 100644 --- a/paramdb/_param_data/_param_data.py +++ b/paramdb/_param_data/_param_data.py @@ -54,7 +54,7 @@ def _update_last_updated(self) -> None: # Continue up the chain of parents, stopping if we reach a last updated # timestamp that is more recent than the new one - while current and not ( + while current is not None and not ( current._last_updated and current._last_updated >= new_last_updated ): super(ParamData, current).__setattr__("_last_updated", new_last_updated) @@ -69,8 +69,8 @@ def _to_json(self) -> Any: The last updated timestamp is handled separately and does not need to be saved here. - Note that objects within the dictionary do not need to be JSON serializable, - since they will be recursively processed by ``json.dumps()``. + Note that objects within a list or dictionary returned by this function do not + need to be JSON serializable, since they will be processed recursively. """ @classmethod diff --git a/paramdb/_param_data/_primitives.py b/paramdb/_param_data/_primitives.py new file mode 100644 index 0000000..f5c4683 --- /dev/null +++ b/paramdb/_param_data/_primitives.py @@ -0,0 +1,188 @@ +"""Parameter data primitive classes.""" + +from __future__ import annotations +from typing import ( + Union, + Protocol, + TypeVar, + Generic, + SupportsInt, + SupportsFloat, + SupportsIndex, + Any, + overload, +) +from abc import abstractmethod +from typing_extensions import Self, Buffer +from paramdb._param_data._param_data import ParamData + +_T = TypeVar("_T") + + +# Based on https://github.com/python/typeshed/blob/main/stdlib/_typeshed/__init__.pyi +class _SupportsTrunc(Protocol): + def __trunc__(self) -> int: ... + + +# Based on https://github.com/python/typeshed/blob/main/stdlib/_typeshed/__init__.pyi +ConvertibleToInt = Union[str, Buffer, SupportsInt, SupportsIndex, _SupportsTrunc] +ConvertibleToFloat = Union[str, Buffer, SupportsFloat, SupportsIndex] + + +class _ParamPrimitive(ParamData, Generic[_T]): + """Base class for parameter primitives.""" + + @property + @abstractmethod + def value(self) -> _T: + """Primitive value stored by this parameter primitive.""" + + def _to_json(self) -> _T: + return self.value + + @classmethod + def _from_json(cls, json_data: _T) -> Self: + return cls(json_data) # type: ignore # pylint: disable=too-many-function-args + + def __repr__(self) -> str: + return f"{type(self).__name__}({self.value!r})" + + +class ParamInt(int, _ParamPrimitive[int]): + """ + Subclass of :py:class:`ParamData` and ``int``. + + Parameter data integer. All ``int`` methods and operations are available; however, + note that ordinary ``int`` objects will be returned. + """ + + # Based on https://github.com/python/typeshed/blob/main/stdlib/builtins.pyi + @overload + def __init__(self, x: ConvertibleToInt = 0, /): ... + + # Based on https://github.com/python/typeshed/blob/main/stdlib/builtins.pyi + @overload + def __init__(self, x: str | bytes | bytearray, /, base: SupportsIndex = 10): ... + + def __init__(self, *_args: Any, **_kwargs: Any) -> None: + super().__init__() + + @property + def value(self) -> int: + return int(self) + + def __repr__(self) -> str: + # Override __repr__() from base class int + return _ParamPrimitive.__repr__(self) + + +class ParamBool(ParamInt): + """ + Subclass of :py:class:`ParamInt`. + + Parameter data boolean. All ``int`` and ``bool`` methods and operations are + available; however, note that ordinary ``int`` objects (``0`` or ``1``) will be + returned. + """ + + # Based on https://github.com/python/typeshed/blob/main/stdlib/builtins.pyi + def __new__(cls, x: object = False, /) -> Self: + # Convert any object to a boolean to emulate bool() + return super().__new__(cls, bool(x)) + + # Based on https://github.com/python/typeshed/blob/main/stdlib/builtins.pyi + # pylint: disable-next=unused-argument + def __init__(self, o: object = False, /) -> None: + super().__init__() + + @property + def value(self) -> bool: + return bool(self) + + +class ParamFloat(float, _ParamPrimitive[float]): + """ + Subclass of :py:class:`ParamData` and ``float``. + + Parameter data float. All ``float`` methods and operations are available; however, + note that ordinary ``float`` objects will be returned. + """ + + # Based on https://github.com/python/typeshed/blob/main/stdlib/builtins.pyi + # pylint: disable-next=unused-argument + def __init__(self, x: ConvertibleToFloat = 0.0, /) -> None: + super().__init__() + + @property + def value(self) -> float: + return float(self) + + def __repr__(self) -> str: + # Override __repr__() from base class float + return _ParamPrimitive.__repr__(self) + + +class ParamStr(str, _ParamPrimitive[str]): + """ + Subclass of :py:class:`ParamData` and ``str``. + + Parameter data string. All ``str`` methods and operations are available; however, + note that ordinary ``str`` objects will be returned. + """ + + # Based on https://github.com/python/typeshed/blob/main/stdlib/builtins.pyi + @overload + def __init__( + self, + object: object = "", # pylint: disable=redefined-builtin + /, + ) -> None: ... + + # Based on https://github.com/python/typeshed/blob/main/stdlib/builtins.pyi + @overload + def __init__( + self, + object: Buffer = b"", # pylint: disable=redefined-builtin + encoding: str = "utf-8", + errors: str = "strict", + ) -> None: ... + + def __init__(self, *_args: Any, **_kwargs: Any) -> None: + super().__init__() + + @property + def value(self) -> str: + return str(self) + + def __repr__(self) -> str: + # Override __repr__() from base class str + return _ParamPrimitive.__repr__(self) + + +class ParamNone(_ParamPrimitive[None]): + """ + Subclass of :py:class:`ParamData`. + + Parameter data ``None``. Just like ``None``, its truth value is false. + """ + + @property + def value(self) -> None: + return None + + @classmethod + def _from_json(cls, json_data: None) -> Self: + return cls() + + def __bool__(self) -> bool: + return False + + def __eq__(self, other: object) -> bool: + return other is None or isinstance(other, ParamNone) + + def __hash__(self) -> int: + return hash(self.value) + + def __repr__(self) -> str: + # Show empty parentheses + return f"{type(self).__name__}()" diff --git a/tests/_param_data/test_collections.py b/tests/_param_data/test_collections.py index feb3ea7..62adb04 100644 --- a/tests/_param_data/test_collections.py +++ b/tests/_param_data/test_collections.py @@ -1,6 +1,6 @@ """Tests for the paramdb._param_data._collections module.""" -from typing import Union, Any +from typing import Union, Any, cast from copy import deepcopy import pytest from tests.helpers import ( @@ -21,8 +21,7 @@ ) def fixture_param_collection(request: pytest.FixtureRequest) -> ParamCollection: """Parameter collection.""" - param_collection: ParamCollection = deepcopy(request.getfixturevalue(request.param)) - return param_collection + return cast(ParamCollection, deepcopy(request.getfixturevalue(request.param))) @pytest.fixture(name="param_collection_type") @@ -59,8 +58,7 @@ def fixture_custom_param_collection_type( """Custom parameter collection subclass.""" if isinstance(param_collection, ParamList): return CustomParamList - if isinstance(param_collection, ParamDict): - return CustomParamDict + return CustomParamDict @pytest.fixture(name="custom_param_collection") @@ -341,6 +339,18 @@ def test_param_list_del_parent( _ = param_data.parent +def test_param_list_empty_last_updated() -> None: + """ + A parameter list updates its last updated time when it becomes empty. (A previous + bug only updated ``last_updated`` if the ``ParamData`` object had a truth value of + true.) + """ + param_list = ParamList([123]) + with capture_start_end_times() as times: + del param_list[0] + assert times.start < param_list.last_updated.timestamp() < times.end + + def test_param_dict_key_error(param_dict: ParamDict[Any]) -> None: """Getting or deleting a nonexistent key raises a KeyError.""" with pytest.raises(KeyError): @@ -454,6 +464,18 @@ def test_param_dict_del_parent( _ = param_data.parent +def test_param_dict_empty_last_updated() -> None: + """ + A parameter dictionary updates its last updated time when it becomes empty. (A + previous bug only updated ``last_updated`` if the ``ParamData`` object had a truth + value of true.) + """ + param_dict = ParamDict(test=123) + with capture_start_end_times() as times: + del param_dict.test + assert times.start < param_dict.last_updated.timestamp() < times.end + + def test_param_dict_iter( param_dict: ParamDict[Any], param_dict_contents: dict[str, Any] ) -> None: diff --git a/tests/_param_data/test_primitives.py b/tests/_param_data/test_primitives.py new file mode 100644 index 0000000..be2ce5c --- /dev/null +++ b/tests/_param_data/test_primitives.py @@ -0,0 +1,240 @@ +"""Tests for the paramdb._param_data._primitives module.""" + +from typing import Union, cast +from copy import deepcopy +import math +import pytest +from paramdb import ParamInt, ParamFloat, ParamBool, ParamStr, ParamNone +from tests.helpers import ( + SimpleParam, + CustomParamInt, + CustomParamFloat, + CustomParamBool, + CustomParamStr, + CustomParamNone, +) + +ParamPrimitive = Union[ParamInt, ParamFloat, ParamBool, ParamStr, ParamNone] +CustomParamPrimitive = Union[ + CustomParamInt, CustomParamFloat, CustomParamBool, CustomParamStr, CustomParamNone +] + + +@pytest.fixture( + name="param_primitive", + params=["param_int", "param_float", "param_bool", "param_str", "param_none"], +) +def fixture_param_primitive(request: pytest.FixtureRequest) -> ParamPrimitive: + """Parameter primitive.""" + return cast(ParamPrimitive, deepcopy(request.getfixturevalue(request.param))) + + +@pytest.fixture(name="custom_param_primitive") +def fixture_custom_param_primitive( + param_primitive: ParamPrimitive, +) -> CustomParamPrimitive: + """Custom parameter primitive.""" + if isinstance(param_primitive, ParamInt): + return CustomParamInt(param_primitive.value) + if isinstance(param_primitive, ParamFloat): + return CustomParamFloat(param_primitive.value) + if isinstance(param_primitive, ParamBool): + return CustomParamBool(param_primitive.value) + if isinstance(param_primitive, ParamStr): + return CustomParamStr(param_primitive.value) + return CustomParamNone() + + +def test_param_int_isinstance(param_int: ParamInt) -> None: + """Parameter integers are instances of ``int``.""" + assert isinstance(param_int, int) + assert isinstance(CustomParamInt(param_int.value), int) + + +def test_param_float_isinstance(param_float: ParamFloat) -> None: + """Parameter floats are instances of ``float``.""" + assert isinstance(param_float, float) + assert isinstance(CustomParamFloat(param_float.value), float) + + +def test_param_bool_isinstance(param_bool: ParamBool) -> None: + """Parameter booleans are instances of ``int``.""" + assert isinstance(param_bool, int) + assert isinstance(CustomParamBool(param_bool.value), int) + + +def test_param_str_isinstance(param_str: ParamStr) -> None: + """Parameter strings are instances of ``str``.""" + assert isinstance(param_str, str) + assert isinstance(CustomParamStr(param_str.value), str) + + +def test_param_int_constructor() -> None: + """The ``ParamInt`` constructor behaves like the ``int`` constructor.""" + assert ParamInt() == 0 + assert ParamInt(123) == 123 + assert ParamInt(123.0) == 123.0 + assert ParamInt(123.1) == 123 + assert ParamInt(123.9) == 123 + assert ParamInt("123") == 123 + assert ParamInt("0x42", base=16) == 66 + with pytest.raises(ValueError): + ParamInt("hello") + + +def test_param_float_constructor() -> None: + """The ``ParamFloat`` constructor behaves like the ``float`` constructor.""" + assert ParamFloat() == 0.0 + assert ParamFloat(123) == 123.0 + assert ParamFloat(1.23) == 1.23 + assert ParamFloat("1.23") == 1.23 + assert ParamFloat("inf") == float("inf") + assert math.isnan(ParamFloat("nan")) + with pytest.raises(ValueError): + ParamFloat("hello") + + +def test_param_bool_constructor() -> None: + """The ``ParamBool`` constructor behaves like the ``bool`` constructor.""" + assert ParamBool().value is False + assert ParamBool(123).value is True + assert ParamBool("").value is False + assert ParamBool("hello").value is True + + +def test_param_str_constructor() -> None: + """The ``ParamStr`` constructor behaves like the ``str`` constructor.""" + assert ParamStr() == "" + assert ParamStr("hello") == "hello" + assert ParamStr(123) == "123" + assert ParamStr(b"hello", encoding="utf-8") == "hello" + + +def test_param_int_value() -> None: + """Can access the primitive value of a parameter integer.""" + param_int_value = ParamInt(123).value + assert type(param_int_value) is int # pylint: disable=unidiomatic-typecheck + assert param_int_value == 123 + + +def test_param_float_value(number: float) -> None: + """Can access the primitive value of a parameter float.""" + param_float_value = ParamFloat(number).value + assert type(param_float_value) is float # pylint: disable=unidiomatic-typecheck + assert param_float_value == number + + +def test_param_bool_value() -> None: + """Can access the primitive value of a parameter boolean.""" + assert ParamBool(True).value is True + assert ParamBool(False).value is False + assert ParamBool(0).value is False + assert ParamBool(123).value is True + + +def test_param_str_value(string: str) -> None: + """Can access the primitive value of a parameter string.""" + param_str_value = ParamStr(string).value + assert type(param_str_value) is str # pylint: disable=unidiomatic-typecheck + assert param_str_value == string + + +def test_param_none_value() -> None: + """Can access the primitive value of a parameter ``None``.""" + assert ParamNone().value is None + + +def test_param_primitive_repr( + param_primitive: ParamPrimitive, custom_param_primitive: CustomParamPrimitive +) -> None: + """Can represent a parameter primitive as a string using ``repr()``.""" + if isinstance(param_primitive, ParamNone): + assert repr(param_primitive) == f"{ParamNone.__name__}()" + assert repr(custom_param_primitive) == f"{CustomParamNone.__name__}()" + else: + assert ( + repr(param_primitive) + == f"{type(param_primitive).__name__}({param_primitive.value!r})" + ) + assert ( + repr(custom_param_primitive) == f"{type(custom_param_primitive).__name__}" + f"({custom_param_primitive.value!r})" + ) + + +def test_param_primitive_bool( + param_primitive: ParamPrimitive, custom_param_primitive: CustomParamPrimitive +) -> None: + """Parameter primitive objects have the correct truth values.""" + assert bool(param_primitive) is bool(param_primitive.value) + assert bool(custom_param_primitive) is bool(custom_param_primitive.value) + + +def test_param_primitive_eq( + param_primitive: ParamPrimitive, custom_param_primitive: CustomParamPrimitive +) -> None: + """ + Parameter primitive objects are equal to themselves, their vaues, and custom + parameter primitive objects. + """ + # pylint: disable=comparison-with-itself + assert param_primitive == param_primitive + assert param_primitive == deepcopy(param_primitive) + assert param_primitive == custom_param_primitive + assert param_primitive == param_primitive.value + + +def test_param_primitive_ne( + simple_param: SimpleParam, + param_primitive: ParamPrimitive, + custom_param_primitive: CustomParamPrimitive, +) -> None: + """ + Parameter primitive objects are not equal to other objects or parameter primitives + with different values. + """ + assert param_primitive != simple_param + assert custom_param_primitive != simple_param + if not isinstance(param_primitive, ParamNone): + assert param_primitive != type(param_primitive)() + assert custom_param_primitive != type(custom_param_primitive)() + + +def test_param_primitive_hash( + param_primitive: ParamPrimitive, custom_param_primitive: CustomParamPrimitive +) -> None: + """Parameter primitive objects has the same hash as objects they are equal to.""" + assert hash(param_primitive) == hash(deepcopy(param_primitive)) + assert hash(param_primitive) == hash(custom_param_primitive) + assert hash(param_primitive) == hash(param_primitive.value) + + +def test_param_int_methods_return_int(param_int: ParamInt) -> None: + """``ParamInt`` methods inherited from ``int`` return ``int`` objects.""" + # pylint: disable=unidiomatic-typecheck + assert type(param_int + 123) is int + assert type(-param_int) is int + assert type(param_int.real) is int + + +def test_param_float_methods_return_float(param_float: ParamFloat) -> None: + """``ParamFloat`` methods inherited from ``float`` return ``float`` objects.""" + # pylint: disable=unidiomatic-typecheck + assert type(param_float + 123) is float + assert type(-param_float) is float + assert type(param_float.real) is float + + +def test_param_bool_methods_return_int() -> None: + """``ParamBool`` methods inherited from ``int`` return ``int`` objects.""" + # pylint: disable=unidiomatic-typecheck + assert type(ParamBool(True) ^ ParamBool(False)) is int + assert type(ParamBool(True) | ParamBool(False)) is int + + +def test_param_str_methods_return_str(param_str: ParamStr) -> None: + """``ParamStr`` methods inherited from ``str`` return ``str`` objects.""" + # pylint: disable=unidiomatic-typecheck + assert type(param_str + "") is str + assert type(param_str[::-1]) is str + assert type(param_str.capitalize()) is str diff --git a/tests/conftest.py b/tests/conftest.py index c9cf7e1..5f463ce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,16 @@ from typing import Any from copy import deepcopy import pytest -from paramdb import ParamData, ParamList, ParamDict +from paramdb import ( + ParamData, + ParamInt, + ParamFloat, + ParamBool, + ParamStr, + ParamNone, + ParamList, + ParamDict, +) from tests.helpers import ( DEFAULT_NUMBER, DEFAULT_STRING, @@ -32,6 +41,36 @@ def fixture_string() -> str: return DEFAULT_STRING +@pytest.fixture(name="param_int") +def fixture_param_int() -> ParamInt: + """Parameter integer object.""" + return ParamInt(123) + + +@pytest.fixture(name="param_float") +def fixture_param_float(number: float) -> ParamFloat: + """Parameter float object.""" + return ParamFloat(number) + + +@pytest.fixture(name="param_bool") +def fixture_param_bool() -> ParamBool: + """Parameter boolean object.""" + return ParamBool(True) + + +@pytest.fixture(name="param_str") +def fixture_param_str(string: str) -> ParamStr: + """Parameter string object.""" + return ParamStr(string) + + +@pytest.fixture(name="param_none") +def fixture_param_none() -> ParamNone: + """Parameter ``None`` object.""" + return ParamNone() + + @pytest.fixture(name="empty_param") def fixture_empty_param() -> EmptyParam: """Empty parameter data class object.""" @@ -109,6 +148,11 @@ def fixture_param_list_contents(number: float, string: str) -> list[Any]: return [ number, string, + ParamInt(), + ParamFloat(number), + ParamBool(), + ParamStr(string), + ParamNone(), EmptyParam(), SimpleParam(), NoTypeValidationParam(), @@ -127,6 +171,11 @@ def fixture_param_list_contents(number: float, string: str) -> list[Any]: def fixture_param_dict_contents( number: float, string: str, + param_int: ParamInt, + param_float: ParamFloat, + param_bool: ParamBool, + param_str: ParamStr, + param_none: ParamNone, empty_param: EmptyParam, simple_param: SimpleParam, no_type_validation_param: NoTypeValidationParam, @@ -140,6 +189,11 @@ def fixture_param_dict_contents( return { "number": number, "string": string, + "param_int": param_int, + "param_float": deepcopy(param_float), + "param_bool": deepcopy(param_bool), + "param_str": deepcopy(param_str), + "param_none": deepcopy(param_none), "empty_param": deepcopy(empty_param), "simple_param": deepcopy(simple_param), "no_type_validation_param": deepcopy(no_type_validation_param), @@ -180,6 +234,11 @@ def fixture_param_dict(param_dict_contents: dict[str, Any]) -> ParamDict[Any]: @pytest.fixture( name="param_data", params=[ + "param_int", + "param_float", + "param_bool", + "param_str", + "param_none", "empty_param", "simple_param", "no_type_validation_param", @@ -208,9 +267,13 @@ def fixture_updated_param_data_and_times( """ updated_param_data = deepcopy(param_data) with capture_start_end_times() as times: - if isinstance(updated_param_data, EmptyParam): - # pylint: disable-next=protected-access - updated_param_data._update_last_updated() + if isinstance( + updated_param_data, + (ParamInt, ParamFloat, ParamBool, ParamStr), + ): + updated_param_data = type(updated_param_data)(updated_param_data.value) + elif isinstance(updated_param_data, (ParamNone, EmptyParam)): + updated_param_data = type(updated_param_data)() elif isinstance(updated_param_data, SimpleParam): updated_param_data.number += 1 elif isinstance(updated_param_data, SubclassParam): @@ -222,7 +285,7 @@ def fixture_updated_param_data_and_times( if len(updated_param_data) == 0: updated_param_data.append(number) else: - updated_param_data[3].number += 1 + updated_param_data[8].number += 1 elif isinstance(updated_param_data, ParamDict): if len(updated_param_data) == 0: updated_param_data["number"] = number diff --git a/tests/helpers.py b/tests/helpers.py index 49403fa..0768263 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -12,7 +12,17 @@ import time import pydantic from astropy.units import Quantity # type: ignore # pylint: disable=import-error -from paramdb import ParamData, ParamDataclass, ParamList, ParamDict +from paramdb import ( + ParamData, + ParamInt, + ParamBool, + ParamFloat, + ParamStr, + ParamNone, + ParamDataclass, + ParamList, + ParamDict, +) DEFAULT_NUMBER = 1.23 DEFAULT_STRING = "test" @@ -29,6 +39,11 @@ class SimpleParam(ParamDataclass): number_init_false: float = field(init=False, default=DEFAULT_NUMBER) number_with_units: Quantity = Quantity(12, "m") string: str = DEFAULT_STRING + param_int: ParamInt = ParamInt(123) + param_float: ParamFloat = ParamFloat(DEFAULT_NUMBER) + param_bool: ParamBool = ParamBool(False) + param_str: ParamStr = ParamStr(DEFAULT_STRING) + param_none: ParamNone = ParamNone() class NoTypeValidationParam(SimpleParam, type_validation=False): @@ -63,6 +78,11 @@ class ComplexParam(ParamDataclass): number: float = DEFAULT_NUMBER number_init_false: float = field(init=False, default=DEFAULT_NUMBER) string: str = DEFAULT_STRING + param_int: ParamInt = ParamInt(123) + param_float: ParamFloat = ParamFloat(DEFAULT_NUMBER) + param_bool: ParamBool = ParamBool(False) + param_str: ParamStr = ParamStr(DEFAULT_STRING) + param_none: ParamNone = ParamNone() list: list[Any] = field(default_factory=list) dict: dict[str, Any] = field(default_factory=dict) empty_param: EmptyParam | None = None @@ -86,6 +106,26 @@ class CustomParamDict(ParamDict[Any]): """Custom parameter dictionary subclass.""" +class CustomParamInt(ParamInt): + """Custom parameter integer subclass.""" + + +class CustomParamFloat(ParamFloat): + """Custom parameter float subclass.""" + + +class CustomParamBool(ParamBool): + """Custom parameter boolean subclass.""" + + +class CustomParamStr(ParamStr): + """Custom parameter string subclass.""" + + +class CustomParamNone(ParamNone): + """Custom parameter ``None`` subclass.""" + + @dataclass class Times: """Start and end times captured by ``capture_start_end_times()``.""" diff --git a/tests/test_database.py b/tests/test_database.py index 0c5fa40..51e88c0 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -15,6 +15,11 @@ SimpleParam, SubclassParam, ComplexParam, + CustomParamInt, + CustomParamFloat, + CustomParamBool, + CustomParamStr, + CustomParamNone, CustomParamList, CustomParamDict, Times, @@ -22,6 +27,11 @@ ) from paramdb import ( ParamData, + ParamInt, + ParamFloat, + ParamBool, + ParamStr, + ParamNone, ParamDataclass, ParamList, ParamDict, @@ -248,6 +258,11 @@ def test_commit_and_load_complex( string: str, param_list_contents: list[Any], param_dict_contents: dict[str, Any], + param_int: ParamInt, + param_float: ParamFloat, + param_bool: ParamBool, + param_str: ParamStr, + param_none: ParamNone, empty_param: EmptyParam, simple_param: SimpleParam, subclass_param: SubclassParam, @@ -264,12 +279,22 @@ class Root(ParamDataclass): string: str list: list[Any] dict: dict[str, Any] + param_int: ParamInt + param_float: ParamFloat + param_bool: ParamBool + param_str: ParamStr + param_none: ParamNone empty_param: EmptyParam simple_param: SimpleParam subclass_param: SubclassParam complex_param: ComplexParam param_list: ParamList[Any] param_dict: ParamDict[Any] + custom_param_int: CustomParamInt + custom_param_float: CustomParamFloat + custom_param_bool: CustomParamBool + custom_param_str: CustomParamStr + custom_param_none: CustomParamNone custom_param_list: CustomParamList custom_param_dict: CustomParamDict @@ -278,12 +303,22 @@ class Root(ParamDataclass): string=string, list=param_list_contents, dict=param_dict_contents, + param_int=param_int, + param_float=param_float, + param_bool=param_bool, + param_str=param_str, + param_none=param_none, empty_param=empty_param, simple_param=simple_param, subclass_param=subclass_param, complex_param=complex_param, param_list=param_list, param_dict=param_dict, + custom_param_int=CustomParamInt(param_int.value), + custom_param_float=CustomParamFloat(param_float.value), + custom_param_bool=CustomParamBool(param_bool.value), + custom_param_str=CustomParamStr(param_str.value), + custom_param_none=CustomParamNone(), custom_param_list=CustomParamList(deepcopy(param_list_contents)), custom_param_dict=CustomParamDict(deepcopy(param_dict_contents)), )