From 1c5f303d61b80076e1adca99cac6a54636203c13 Mon Sep 17 00:00:00 2001 From: Alex Hadley Date: Tue, 14 May 2024 12:31:18 -0700 Subject: [PATCH 01/12] #187 Replace primitives with child last updated tracking --- paramdb/__init__.py | 12 -- paramdb/_database.py | 2 +- paramdb/_param_data/_collections.py | 133 ++++++++++-------- paramdb/_param_data/_dataclasses.py | 70 +++++++--- paramdb/_param_data/_param_data.py | 91 +++++++++++-- paramdb/_param_data/_primitives.py | 188 -------------------------- paramdb/_param_data/_type_mixins.py | 20 +-- tests/_param_data/test_dataclasses.py | 22 +-- tests/_param_data/test_param_data.py | 5 +- tests/helpers.py | 2 +- tests/test_database.py | 16 +-- 11 files changed, 239 insertions(+), 322 deletions(-) delete mode 100644 paramdb/_param_data/_primitives.py diff --git a/paramdb/__init__.py b/paramdb/__init__.py index e70b4c9..8465303 100644 --- a/paramdb/__init__.py +++ b/paramdb/__init__.py @@ -1,13 +1,6 @@ """Python package for storing and retrieving experiment parameters.""" from paramdb._param_data._param_data import ParamData -from paramdb._param_data._primitives import ( - ParamInt, - ParamBool, - ParamFloat, - ParamStr, - ParamNone, -) from paramdb._param_data._dataclasses import ParamDataclass from paramdb._param_data._files import ParamFile from paramdb._param_data._collections import ParamList, ParamDict @@ -16,11 +9,6 @@ __all__ = [ "ParamData", - "ParamInt", - "ParamBool", - "ParamFloat", - "ParamStr", - "ParamNone", "ParamDataclass", "ParamFile", "ParamList", diff --git a/paramdb/_database.py b/paramdb/_database.py index 42af752..581bd36 100644 --- a/paramdb/_database.py +++ b/paramdb/_database.py @@ -17,7 +17,7 @@ from paramdb._param_data._param_data import ParamData, get_param_class try: - from astropy.units import Quantity # type: ignore + from astropy.units import Quantity # type: ignore[import-untyped] _ASTROPY_INSTALLED = True except ImportError: diff --git a/paramdb/_param_data/_collections.py 
b/paramdb/_param_data/_collections.py index 46142cc..8d8d733 100644 --- a/paramdb/_param_data/_collections.py +++ b/paramdb/_param_data/_collections.py @@ -1,7 +1,7 @@ """Parameter data collection classes.""" from __future__ import annotations -from typing import TypeVar, Generic, SupportsIndex, Any, overload +from typing import Union, TypeVar, Generic, SupportsIndex, Any, cast, overload from collections.abc import ( Iterator, Collection, @@ -9,20 +9,17 @@ Mapping, MutableSequence, MutableMapping, - KeysView, - ValuesView, - ItemsView, ) -from abc import abstractmethod from typing_extensions import Self -from paramdb._param_data._param_data import ParamData +from paramdb._param_data._param_data import ParamData, _ParamWrapper T = TypeVar("T") +_ChildNameT = TypeVar("_ChildNameT", str, int) _CollectionT = TypeVar("_CollectionT", bound=Collection[Any]) # pylint: disable-next=abstract-method -class _ParamCollection(ParamData, Generic[_CollectionT]): +class _ParamCollection(ParamData[_ChildNameT], Generic[_ChildNameT, _CollectionT]): """Base class for parameter collections.""" _contents: _CollectionT @@ -32,11 +29,7 @@ def __len__(self) -> int: def __eq__(self, other: Any) -> bool: # Equal if they have are of the same class and their contents are equal - return ( - isinstance(other, _ParamCollection) - and type(other) is type(self) - and self._contents == other._contents - ) + return type(other) is type(self) and self._contents == other._contents def __repr__(self) -> str: return f"{type(self).__name__}({self._contents})" @@ -44,12 +37,32 @@ def __repr__(self) -> str: def _to_json(self) -> _CollectionT: return self._contents - @classmethod - @abstractmethod - def _from_json(cls, json_data: _CollectionT) -> Self: ... 
- + def _get_wrapped_child(self, child_name: _ChildNameT) -> ParamData[Any]: + # If a TypeError, IndexError, or KeyError occurs from _contents, raise the + # superclass ValueError from the _contents exception + try: + return cast( + ParamData[Any], self._contents[child_name] # type: ignore[index] + ) + except (TypeError, IndexError, KeyError) as contents_exc: + try: + return super()._get_wrapped_child(child_name) # type: ignore[arg-type] + except ValueError as super_exc: + raise super_exc from contents_exc -class ParamList(_ParamCollection[list[T]], MutableSequence[T], Generic[T]): + @classmethod + def _from_json(cls, json_data: _CollectionT) -> Self: + # Set contents directly since __init__() will contain child wrapping logic + new_param_collection = cls() + new_param_collection._contents = json_data + return new_param_collection + + +class ParamList( + _ParamCollection[int, list[Union[T, _ParamWrapper[T]]]], + MutableSequence[T], + Generic[T], +): """ Subclass of :py:class:`ParamData` and ``MutableSequence``. @@ -59,19 +72,26 @@ class ParamList(_ParamCollection[list[T]], MutableSequence[T], Generic[T]): def __init__(self, iterable: Iterable[T] | None = None) -> None: super().__init__() - self._contents = [] if iterable is None else list(iterable) - if iterable is not None: - for item in self._contents: - self._add_child(item) + initial_contents = iterable or [] + self._contents = [self._wrap_child(item) for item in initial_contents] + for item in initial_contents: + self._add_child(item) @overload def __getitem__(self, index: SupportsIndex) -> T: ... @overload - def __getitem__(self, index: slice) -> list[T]: ... + def __getitem__(self, index: slice) -> Self: ... 
- def __getitem__(self, index: Any) -> Any: - return self._contents[index] + def __getitem__(self, index: SupportsIndex | slice) -> T | Self: + if isinstance(index, slice): + return type(self)( + [ + self._unwrap_child(wrapped_child) + for wrapped_child in self._contents[index] + ] + ) + return self._unwrap_child(self._contents[index]) @overload def __setitem__(self, index: SupportsIndex, value: T) -> None: ... @@ -80,35 +100,45 @@ def __setitem__(self, index: SupportsIndex, value: T) -> None: ... def __setitem__(self, index: slice, value: Iterable[T]) -> None: ... def __setitem__(self, index: SupportsIndex | slice, value: Any) -> None: - old_value: Any = self._contents[index] - self._contents[index] = value - self._update_last_updated() if isinstance(index, slice): - for item in old_value: - self._remove_child(item) + old_values = self._contents[index] + self._contents[index] = [self._wrap_child(item) for item in value] + for old_item in old_values: + self._remove_child(old_item) for item in value: self._add_child(item) else: + old_value = self._contents[index] + self._contents[index] = self._wrap_child(value) self._remove_child(old_value) self._add_child(value) def __delitem__(self, index: SupportsIndex | slice) -> None: old_value = self._contents[index] del self._contents[index] - self._update_last_updated() - self._remove_child(old_value) + if isinstance(index, slice) and isinstance(old_value, list): + for old_item in old_value: + self._remove_child(old_item) + else: + self._remove_child(old_value) def insert(self, index: SupportsIndex, value: T) -> None: - self._contents.insert(index, value) - self._update_last_updated() + self._contents.insert(index, self._wrap_child(value)) self._add_child(value) @classmethod - def _from_json(cls, json_data: list[T]) -> Self: - return cls(json_data) + def _from_json(cls, json_data: list[T | _ParamWrapper[T]]) -> Self: + + new_obj = cls() + new_obj._contents = json_data + return new_obj -class 
ParamDict(_ParamCollection[dict[str, T]], MutableMapping[str, T], Generic[T]): +class ParamDict( + _ParamCollection[str, dict[str, Union[T, _ParamWrapper[T]]]], + MutableMapping[str, T], + Generic[T], +): """ Subclass of :py:class:`ParamData` and ``MutableMapping``. @@ -121,9 +151,12 @@ class ParamDict(_ParamCollection[dict[str, T]], MutableMapping[str, T], Generic[ def __init__(self, mapping: Mapping[str, T] | None = None, /, **kwargs: T): super().__init__() - self._contents = ({} if mapping is None else dict(mapping)) | kwargs - for item in self._contents.values(): - self._add_child(item) + initial_contents = {**(mapping or {}), **kwargs} + self._contents = { + key: self._wrap_child(value) for key, value in initial_contents.items() + } + for value in initial_contents.values(): + self._add_child(value) def __dir__(self) -> Iterable[str]: # Return keys that are not attribute names (i.e. do not pass self._is_attribute) @@ -134,19 +167,17 @@ def __dir__(self) -> Iterable[str]: ] def __getitem__(self, key: str) -> T: - return self._contents[key] + return self._unwrap_child(self._contents[key]) def __setitem__(self, key: str, value: T) -> None: old_value = self._contents[key] if key in self._contents else None - self._contents[key] = value - self._update_last_updated() + self._contents[key] = self._wrap_child(value) self._remove_child(old_value) self._add_child(value) def __delitem__(self, key: str) -> None: old_value = self._contents[key] if key in self._contents else None del self._contents[key] - self._update_last_updated() self._remove_child(old_value) def __iter__(self) -> Iterator[str]: @@ -184,19 +215,3 @@ def _is_attribute(self, name: str) -> bool: (i.e. dunder variables), and to allow for true attributes to be used if needed. 
""" return len(name) > 0 and name[0] == "_" - - def keys(self) -> KeysView[str]: - # Use dict_keys so keys print nicely - return self._contents.keys() - - def values(self) -> ValuesView[T]: - # Use dict_values so values print nicely - return self._contents.values() - - def items(self) -> ItemsView[str, T]: - # Use dict_items so items print nicely - return self._contents.items() - - @classmethod - def _from_json(cls, json_data: dict[str, T]) -> Self: - return cls(json_data) diff --git a/paramdb/_param_data/_dataclasses.py b/paramdb/_param_data/_dataclasses.py index 717f7d3..33e96f3 100644 --- a/paramdb/_param_data/_dataclasses.py +++ b/paramdb/_param_data/_dataclasses.py @@ -1,7 +1,7 @@ """Base class for parameter dataclasses.""" from __future__ import annotations -from typing import Any +from typing import Any, cast from dataclasses import dataclass, is_dataclass, fields from typing_extensions import Self, dataclass_transform from paramdb._param_data._param_data import ParamData @@ -16,7 +16,7 @@ @dataclass_transform() -class ParamDataclass(ParamData): +class ParamDataclass(ParamData[str]): """ Subclass of :py:class:`ParamData`. @@ -46,7 +46,7 @@ class CustomParam(ParamDataclass): See https://docs.pydantic.dev/latest/api/config for full configuration options. """ - __field_names: set[str] # Data class field names + _field_names: set[str] # Data class field names __type_validation: bool = True # Whether to use Pydantic __pydantic_config: pydantic.ConfigDict = { "extra": "forbid", @@ -56,9 +56,16 @@ class CustomParam(ParamDataclass): "validate_default": True, } - # Set in __init_subclass__() and used to set attributes within __setattr__() # pylint: disable-next=unused-argument - def __base_setattr(self: Any, name: str, value: Any) -> None: ... 
+ def __base_setattr(self: Any, name: str, value: Any) -> None: + """ + If Pydantic is enabled and ``validate_assignment`` is True, this function will + both set and validate the attribute; otherwise, it will be an ordinary setattr + function. + + Set in ``__init_subclass__()`` and used to set attributes within + ``__setattr__()``. + """ def __init_subclass__( cls, @@ -73,7 +80,7 @@ def __init_subclass__( if pydantic_config is not None: # Merge new Pydantic config with the old one cls.__pydantic_config = cls.__pydantic_config | pydantic_config - cls.__base_setattr = super().__setattr__ # type: ignore + cls.__base_setattr = super().__setattr__ # type: ignore[assignment] if _PYDANTIC_INSTALLED and cls.__type_validation: # Transform the class into a Pydantic data class, with custom handling for # validate_assignment @@ -93,13 +100,11 @@ def __init_subclass__( def __base_setattr(self: Any, name: str, value: Any) -> None: pydantic_validator.validate_assignment(self, name, value) - cls.__base_setattr = __base_setattr # type: ignore + cls.__base_setattr = __base_setattr # type: ignore[method-assign] else: # Transform the class into a data class dataclass(**kwargs)(cls) - cls.__field_names = ( - {f.name for f in fields(cls)} if is_dataclass(cls) else set() - ) + cls._field_names = {f.name for f in fields(cls)} if is_dataclass(cls) else set() # pylint: disable-next=unused-argument def __new__(cls, *args: Any, **kwargs: Any) -> Self: @@ -110,12 +115,15 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Self: f"only subclasses of {ParamDataclass.__name__} can be instantiated" ) self = super().__new__(cls) - super().__init__(self) + super().__init__(self) # type: ignore[arg-type] return self def __post_init__(self) -> None: - for field_name in self.__field_names: - self._add_child(getattr(self, field_name)) + # Add and wrap fields as children + for field_name in self._field_names: + child = getattr(self, field_name) + super().__setattr__(field_name, self._wrap_child(child)) + 
self._add_child(child) def __getitem__(self, name: str) -> Any: # Enable getting attributes via square brackets @@ -125,16 +133,38 @@ def __setitem__(self, name: str, value: Any) -> None: # Enable setting attributes via square brackets setattr(self, name, value) + def __delitem__(self, name: str) -> None: + # Enable deleting attributes via square brackets + delattr(self, name) + + def __getattribute__(self, name: str) -> Any: + # Unwrap child if the attribute is a field + value = super().__getattribute__(name) + if name in super().__getattribute__("_field_names"): + return self._unwrap_child(value) + return value + def __setattr__(self, name: str, value: Any) -> None: - # If this attribute is a Data Class field, update last updated and children - if name in self.__field_names: - old_value = getattr(self, name) if hasattr(self, name) else None + # If this attribute is a field, process the old and new child + if name in self._field_names: + old_value = getattr(self, name) + self.__base_setattr(name, value) # May perform type validation + super().__setattr__(name, self._wrap_child(value)) + self._add_child(value) + self._remove_child(old_value) + else: self.__base_setattr(name, value) - self._update_last_updated() + + def __delattr__(self, name: str) -> None: + old_value = getattr(self, name) + super().__delattr__(name) + if name in self._field_names: self._remove_child(old_value) - self._add_child(value) - return - self.__base_setattr(name, value) + + def _get_wrapped_child(self, child_name: str) -> ParamData[Any]: + if child_name in self._field_names: + return cast(ParamData[Any], super().__getattribute__(child_name)) + return super()._get_wrapped_child(child_name) def _to_json(self) -> dict[str, Any]: if is_dataclass(self): diff --git a/paramdb/_param_data/_param_data.py b/paramdb/_param_data/_param_data.py index 3d50546..bd58f12 100644 --- a/paramdb/_param_data/_param_data.py +++ b/paramdb/_param_data/_param_data.py @@ -1,35 +1,37 @@ """Base class for all 
parameter data.""" from __future__ import annotations -from typing import Any +from typing import Union, TypeVar, Generic, Any, cast from abc import ABC, abstractmethod from weakref import WeakValueDictionary from datetime import datetime, timezone -from typing_extensions import Self +from typing_extensions import Self, Never + +_T = TypeVar("_T") +_ChildNameT = TypeVar("_ChildNameT", bound=Union[str, int]) _LAST_UPDATED_KEY = "last_updated" """Dictionary key corresponding to a ``ParamData`` object's last updated time.""" _DATA_KEY = "data" """Dictionary key corresponding to a ``ParamData`` object's data.""" -_param_classes: WeakValueDictionary[str, type[ParamData]] = WeakValueDictionary() +_param_classes: WeakValueDictionary[str, type[ParamData[Any]]] = WeakValueDictionary() """Dictionary of weak references to existing ``ParamData`` classes.""" -def get_param_class(class_name: str) -> type[ParamData] | None: +def get_param_class(class_name: str) -> type[ParamData[Any]] | None: """Get a parameter class given its name, or ``None`` if the class does not exist.""" return _param_classes[class_name] if class_name in _param_classes else None -class ParamData(ABC): +class ParamData(ABC, Generic[_ChildNameT]): """Abstract base class for all parameter data.""" - _parent: ParamData | None = None + _parent: ParamData[Any] | None = None _last_updated: datetime def __init_subclass__(cls, /, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) - # Add subclass to dictionary of ParamData classes _param_classes[cls.__name__] = cls @@ -37,20 +39,47 @@ def __init__(self) -> None: super().__setattr__("_last_updated", datetime.now(timezone.utc).astimezone()) def _add_child(self, child: Any) -> None: - """Add the given object as a child, if it is ``ParamData``.""" + """ + This method should be called to process a new child. + + If the child is ``ParamData``, its parent and last updated attributes will be + updated. 
+ """ if isinstance(child, ParamData): super(ParamData, child).__setattr__("_parent", self) + child._update_last_updated() # pylint: disable=protected-access def _remove_child(self, child: Any) -> None: - """Remove the given object as a child, if it is ``ParamData``.""" + """ + This method should be called to process a child that has just been removed. + + If the child is ``ParamData``, its parent will be reset to ``None``. + """ if isinstance(child, ParamData): super(ParamData, child).__setattr__("_parent", None) + self._update_last_updated() + + def _wrap_child(self, child: _T) -> _T | _ParamWrapper[_T]: + """ + If the given child is not ``ParamData``, it will be wrapped by + ``_ParamWrapper``; otherwise, the original object will be returned. + """ + return cast(_T, child) if isinstance(child, ParamData) else _ParamWrapper(child) + + def _unwrap_child(self, wrapped_child: _T | _ParamWrapper[_T]) -> _T: + """ + If the given child is wrapped by ``_ParamWrapper``, return the inner value. + Otherwise, return the child directly. + """ + if isinstance(wrapped_child, _ParamWrapper): + return wrapped_child.value + return wrapped_child def _update_last_updated(self) -> None: """Update last updated for this object and its chain of parents.""" # pylint: disable=protected-access,unused-private-member new_last_updated = datetime.now(timezone.utc).astimezone() - current: ParamData | None = self + current: ParamData[Any] | None = self # Continue up the chain of parents, stopping if we reach a last updated # timestamp that is more recent than the new one @@ -106,8 +135,23 @@ def last_updated(self) -> datetime: """When any parameter within this parameter data was last updated.""" return self._last_updated + def _get_wrapped_child(self, child_name: _ChildNameT) -> ParamData[Any]: + """ + Get the wrapped child corresponding to the given name. + + Subclasses with children should implement this method and call the superclass + function if the child does not exist. 
+ """ + raise ValueError( + f"'{type(self).__name__}' parameter data object has no child {child_name!r}" + ) + + def child_last_updated(self, child_name: _ChildNameT) -> datetime: + """Return the last updated time of the given child.""" + return self._get_wrapped_child(child_name).last_updated + @property - def parent(self) -> ParamData: + def parent(self) -> ParamData[Any]: """ Parent of this parameter data. The parent is defined to be the :py:class:`ParamData` object that most recently had this object added as a @@ -124,7 +168,7 @@ def parent(self) -> ParamData: return self._parent @property - def root(self) -> ParamData: + def root(self) -> ParamData[Any]: """ Root of this parameter data. The root is defined to be the first object with no parent when going up the chain of parents. @@ -134,3 +178,26 @@ def root(self) -> ParamData: while root._parent is not None: root = root._parent return root + + +class _ParamWrapper(ParamData[Never], Generic[_T]): + """ + Wrapper around a non-``ParamData`` value, mainly to track its last updated time. 
+ """ + + def __init__(self, value: _T) -> None: + super().__init__() + self.value = value + + def _to_json(self) -> _T: + return self.value + + @classmethod + def _from_json(cls, json_data: _T) -> Self: + return cls(json_data) + + def __eq__(self, other: Any) -> bool: + return isinstance(other, _ParamWrapper) and self.value == other.value + + def __repr__(self) -> str: + return repr(self.value) diff --git a/paramdb/_param_data/_primitives.py b/paramdb/_param_data/_primitives.py deleted file mode 100644 index f5c4683..0000000 --- a/paramdb/_param_data/_primitives.py +++ /dev/null @@ -1,188 +0,0 @@ -"""Parameter data primitive classes.""" - -from __future__ import annotations -from typing import ( - Union, - Protocol, - TypeVar, - Generic, - SupportsInt, - SupportsFloat, - SupportsIndex, - Any, - overload, -) -from abc import abstractmethod -from typing_extensions import Self, Buffer -from paramdb._param_data._param_data import ParamData - -_T = TypeVar("_T") - - -# Based on https://github.com/python/typeshed/blob/main/stdlib/_typeshed/__init__.pyi -class _SupportsTrunc(Protocol): - def __trunc__(self) -> int: ... - - -# Based on https://github.com/python/typeshed/blob/main/stdlib/_typeshed/__init__.pyi -ConvertibleToInt = Union[str, Buffer, SupportsInt, SupportsIndex, _SupportsTrunc] -ConvertibleToFloat = Union[str, Buffer, SupportsFloat, SupportsIndex] - - -class _ParamPrimitive(ParamData, Generic[_T]): - """Base class for parameter primitives.""" - - @property - @abstractmethod - def value(self) -> _T: - """Primitive value stored by this parameter primitive.""" - - def _to_json(self) -> _T: - return self.value - - @classmethod - def _from_json(cls, json_data: _T) -> Self: - return cls(json_data) # type: ignore # pylint: disable=too-many-function-args - - def __repr__(self) -> str: - return f"{type(self).__name__}({self.value!r})" - - -class ParamInt(int, _ParamPrimitive[int]): - """ - Subclass of :py:class:`ParamData` and ``int``. - - Parameter data integer. 
All ``int`` methods and operations are available; however, - note that ordinary ``int`` objects will be returned. - """ - - # Based on https://github.com/python/typeshed/blob/main/stdlib/builtins.pyi - @overload - def __init__(self, x: ConvertibleToInt = 0, /): ... - - # Based on https://github.com/python/typeshed/blob/main/stdlib/builtins.pyi - @overload - def __init__(self, x: str | bytes | bytearray, /, base: SupportsIndex = 10): ... - - def __init__(self, *_args: Any, **_kwargs: Any) -> None: - super().__init__() - - @property - def value(self) -> int: - return int(self) - - def __repr__(self) -> str: - # Override __repr__() from base class int - return _ParamPrimitive.__repr__(self) - - -class ParamBool(ParamInt): - """ - Subclass of :py:class:`ParamInt`. - - Parameter data boolean. All ``int`` and ``bool`` methods and operations are - available; however, note that ordinary ``int`` objects (``0`` or ``1``) will be - returned. - """ - - # Based on https://github.com/python/typeshed/blob/main/stdlib/builtins.pyi - def __new__(cls, x: object = False, /) -> Self: - # Convert any object to a boolean to emulate bool() - return super().__new__(cls, bool(x)) - - # Based on https://github.com/python/typeshed/blob/main/stdlib/builtins.pyi - # pylint: disable-next=unused-argument - def __init__(self, o: object = False, /) -> None: - super().__init__() - - @property - def value(self) -> bool: - return bool(self) - - -class ParamFloat(float, _ParamPrimitive[float]): - """ - Subclass of :py:class:`ParamData` and ``float``. - - Parameter data float. All ``float`` methods and operations are available; however, - note that ordinary ``float`` objects will be returned. 
- """ - - # Based on https://github.com/python/typeshed/blob/main/stdlib/builtins.pyi - # pylint: disable-next=unused-argument - def __init__(self, x: ConvertibleToFloat = 0.0, /) -> None: - super().__init__() - - @property - def value(self) -> float: - return float(self) - - def __repr__(self) -> str: - # Override __repr__() from base class float - return _ParamPrimitive.__repr__(self) - - -class ParamStr(str, _ParamPrimitive[str]): - """ - Subclass of :py:class:`ParamData` and ``str``. - - Parameter data string. All ``str`` methods and operations are available; however, - note that ordinary ``str`` objects will be returned. - """ - - # Based on https://github.com/python/typeshed/blob/main/stdlib/builtins.pyi - @overload - def __init__( - self, - object: object = "", # pylint: disable=redefined-builtin - /, - ) -> None: ... - - # Based on https://github.com/python/typeshed/blob/main/stdlib/builtins.pyi - @overload - def __init__( - self, - object: Buffer = b"", # pylint: disable=redefined-builtin - encoding: str = "utf-8", - errors: str = "strict", - ) -> None: ... - - def __init__(self, *_args: Any, **_kwargs: Any) -> None: - super().__init__() - - @property - def value(self) -> str: - return str(self) - - def __repr__(self) -> str: - # Override __repr__() from base class str - return _ParamPrimitive.__repr__(self) - - -class ParamNone(_ParamPrimitive[None]): - """ - Subclass of :py:class:`ParamData`. - - Parameter data ``None``. Just like ``None``, its truth value is false. 
- """ - - @property - def value(self) -> None: - return None - - @classmethod - def _from_json(cls, json_data: None) -> Self: - return cls() - - def __bool__(self) -> bool: - return False - - def __eq__(self, other: object) -> bool: - return other is None or isinstance(other, ParamNone) - - def __hash__(self) -> int: - return hash(self.value) - - def __repr__(self) -> str: - # Show empty parentheses - return f"{type(self).__name__}()" diff --git a/paramdb/_param_data/_type_mixins.py b/paramdb/_param_data/_type_mixins.py index d51ae5f..4ad41e6 100644 --- a/paramdb/_param_data/_type_mixins.py +++ b/paramdb/_param_data/_type_mixins.py @@ -1,16 +1,16 @@ """Type mixins for parameter data.""" from __future__ import annotations -from typing import TypeVar, Generic, cast +from typing import TypeVar, Generic, Any, cast from paramdb._param_data._param_data import ParamData -PT = TypeVar("PT", bound=ParamData) +T = TypeVar("T", bound=ParamData[Any]) -class ParentType(ParamData, Generic[PT]): +class ParentType(ParamData[Any], Generic[T]): """ Mixin for :py:class:`ParamData` that sets the type hint for - :py:attr:`ParamData.parent` to type parameter ``PT``. For example:: + :py:attr:`ParamData.parent` to type parameter ``T``. For example:: class CustomParam(ParentType[ParentParam], Param): ... @@ -20,14 +20,14 @@ class CustomParam(ParentType[ParentParam], Param): """ @property - def parent(self) -> PT: - return cast(PT, super().parent) + def parent(self) -> T: + return cast(T, super().parent) -class RootType(ParamData, Generic[PT]): +class RootType(ParamData[Any], Generic[T]): """ Mixin for :py:class:`ParamData` that sets the type hint for - :py:attr:`ParamData.root` to type parameter ``PT``. For example:: + :py:attr:`ParamData.root` to type parameter ``T``. For example:: class CustomParam(RootType[RootParam], Param): ... 
@@ -37,5 +37,5 @@ class CustomParam(RootType[RootParam], Param): """ @property - def root(self) -> PT: - return cast(PT, super().root) + def root(self) -> T: + return cast(T, super().root) diff --git a/tests/_param_data/test_dataclasses.py b/tests/_param_data/test_dataclasses.py index 9b3c781..7f85055 100644 --- a/tests/_param_data/test_dataclasses.py +++ b/tests/_param_data/test_dataclasses.py @@ -139,11 +139,11 @@ def test_param_dataclass_init_wrong_type( string = "123" # Use a string of a number to make sure strict mode is enabled param_dataclass_class = type(param_dataclass_object) if param_dataclass_class is NoTypeValidationParam: - param = param_dataclass_class(number=string) # type: ignore - assert param.number == string # type: ignore + param = param_dataclass_class(number=string) # type: ignore[arg-type] + assert param.number == string # type: ignore[comparison-overlap] else: with pytest.raises(pydantic.ValidationError) as exc_info: - param_dataclass_class(number=string) # type: ignore + param_dataclass_class(number=string) # type: ignore[arg-type] assert "Input should be a valid number" in str(exc_info.value) @@ -156,7 +156,7 @@ def test_param_dataclass_init_default_wrong_type() -> None: class DefaultWrongTypeParam(SimpleParam): """Parameter data class with a default value having the wrong type.""" - default_number: float = "123" # type: ignore + default_number: float = "123" # type: ignore[assignment] with pytest.raises(pydantic.ValidationError) as exc_info: DefaultWrongTypeParam() @@ -171,11 +171,11 @@ def test_param_dataclass_init_extra( exc_info: pytest.ExceptionInfo[Exception] if param_dataclass_class is NoTypeValidationParam: with pytest.raises(TypeError) as exc_info: - param_dataclass_class(extra=number) # type: ignore + param_dataclass_class(extra=number) # type: ignore[call-arg] assert "__init__() got an unexpected keyword argument" in str(exc_info.value) else: with pytest.raises(pydantic.ValidationError) as exc_info: - 
param_dataclass_class(extra=number) # type: ignore + param_dataclass_class(extra=number) # type: ignore[call-arg] assert "Unexpected keyword argument" in str(exc_info.value) @@ -190,11 +190,13 @@ def test_param_dataclass_assignment_wrong_type( if isinstance( param_dataclass_object, (NoTypeValidationParam, NoAssignmentValidationParam) ): - param_dataclass_object.number = string # type: ignore - assert param_dataclass_object.number == string # type: ignore + param_dataclass_object.number = string # type: ignore[assignment] + assert ( + param_dataclass_object.number == string # type: ignore[comparison-overlap] + ) else: with pytest.raises(pydantic.ValidationError) as exc_info: - param_dataclass_object.number = string # type: ignore + param_dataclass_object.number = string # type: ignore[assignment] assert "Input should be a valid number" in str(exc_info.value) @@ -209,7 +211,7 @@ def test_param_dataclass_assignment_extra( param_dataclass_object, (NoTypeValidationParam, NoAssignmentValidationParam) ): param_dataclass_object.extra = number - assert param_dataclass_object.extra == number # type: ignore + assert param_dataclass_object.extra == number else: with pytest.raises(pydantic.ValidationError) as exc_info: param_dataclass_object.extra = number diff --git a/tests/_param_data/test_param_data.py b/tests/_param_data/test_param_data.py index 162556c..cd2c3df 100644 --- a/tests/_param_data/test_param_data.py +++ b/tests/_param_data/test_param_data.py @@ -18,7 +18,10 @@ def test_custom_subclass_extra_kwarg_fails(param_data_type: type[ParamData]) -> """Extra keyword arugments in a custom parameter data subclass raise a TypeError.""" with pytest.raises(TypeError) as exc_info: # pylint: disable-next=unused-variable - class CustomParamData(param_data_type, extra_kwarg="test"): # type: ignore + class CustomParamData( + param_data_type, # type: ignore[valid-type, misc] + extra_kwarg="test", # type: ignore[call-arg] + ): """Custom parameter data class with an extra keyword 
arugment.""" error_message = str(exc_info.value) diff --git a/tests/helpers.py b/tests/helpers.py index 28efe01..217649e 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -11,7 +11,7 @@ from contextlib import contextmanager import time import pydantic -from astropy.units import Quantity # type: ignore # pylint: disable=import-error +from astropy.units import Quantity # type: ignore[import-untyped] from paramdb import ( ParamData, ParamInt, diff --git a/tests/test_database.py b/tests/test_database.py index 4b61ea3..f5a4490 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -111,7 +111,7 @@ def test_load_empty_fails(db_path: str) -> None: param_db = ParamDB[Any](db_path) for load_func in param_db.load, param_db.load_commit_entry: with pytest.raises(IndexError) as exc_info: - load_func() # type: ignore + load_func() # type: ignore[operator] assert ( str(exc_info.value) == f"cannot load most recent commit because database '{db_path}' has no" @@ -125,14 +125,14 @@ def test_load_nonexistent_commit_fails(db_path: str) -> None: param_db = ParamDB[Any](db_path) for load_func in param_db.load, param_db.load_commit_entry: with pytest.raises(IndexError) as exc_info: - load_func(1) # type: ignore + load_func(1) # type: ignore[operator] assert str(exc_info.value) == f"commit 1 does not exist in database '{db_path}'" # Database with one commit param_db.commit("Initial commit", {}) for load_func in param_db.load, param_db.load_commit_entry: with pytest.raises(IndexError) as exc_info: - load_func(100) # type: ignore + load_func(100) # type: ignore[operator] assert ( str(exc_info.value) == "commit 100 does not exist in database" f" '{db_path}'" @@ -421,17 +421,17 @@ def test_empty_commit_history(db_path: str) -> None: """Loads an empty commit history from an empty database.""" param_db = ParamDB[SimpleParam](db_path) for history_func in param_db.commit_history, param_db.commit_history_with_data: - assert history_func() == [] # type: ignore + assert 
history_func() == [] # type: ignore[operator] def test_empty_commit_history_slice(db_path: str) -> None: """Correctly slices an empty commit history.""" param_db = ParamDB[SimpleParam](db_path) for history_func in param_db.commit_history, param_db.commit_history_with_data: - assert history_func(0) == [] # type: ignore - assert history_func(0, 10) == [] # type: ignore - assert history_func(-10) == [] # type: ignore - assert history_func(-10, -5) == [] # type: ignore + assert history_func(0) == [] # type: ignore[operator] + assert history_func(0, 10) == [] # type: ignore[operator] + assert history_func(-10) == [] # type: ignore[operator] + assert history_func(-10, -5) == [] # type: ignore[operator] def test_commit_history(db_path: str, simple_param: SimpleParam) -> None: From 3872550b1c9419f3c649b5f68fca238ba83c2d3c Mon Sep 17 00:00:00 2001 From: Alex Hadley Date: Fri, 7 Jun 2024 12:29:33 -0700 Subject: [PATCH 02/12] #187 Finish implementation of primitive last updated tracking --- paramdb/_database.py | 16 +- paramdb/_param_data/_collections.py | 150 ++++++++-------- paramdb/_param_data/_dataclasses.py | 49 ++++-- paramdb/_param_data/_files.py | 16 +- paramdb/_param_data/_param_data.py | 86 +++++---- paramdb/_param_data/_type_mixins.py | 19 +- tests/_param_data/test_collections.py | 45 +++-- tests/_param_data/test_dataclasses.py | 4 +- tests/_param_data/test_param_data.py | 31 ++-- tests/_param_data/test_primitives.py | 240 -------------------------- tests/conftest.py | 105 ++++------- tests/helpers.py | 59 +++---- tests/test_database.py | 77 +++------ 13 files changed, 303 insertions(+), 594 deletions(-) delete mode 100644 tests/_param_data/test_primitives.py diff --git a/paramdb/_database.py b/paramdb/_database.py index 581bd36..0408870 100644 --- a/paramdb/_database.py +++ b/paramdb/_database.py @@ -23,7 +23,7 @@ except ImportError: _ASTROPY_INSTALLED = False -T = TypeVar("T") +DataT = TypeVar("DataT") _SelectT = TypeVar("_SelectT", bound=Select[Any]) 
CLASS_NAME_KEY = "__type" @@ -157,22 +157,22 @@ def __post_init__(self) -> None: @dataclass(frozen=True) -class CommitEntryWithData(CommitEntry, Generic[T]): +class CommitEntryWithData(CommitEntry, Generic[DataT]): """ Subclass of :py:class:`CommitEntry`. Entry for a commit containing the ID, message, and timestamp, as well as the data. """ - data: T + data: DataT """Data contained in this commit.""" -class ParamDB(Generic[T]): +class ParamDB(Generic[DataT]): """ Parameter database. The database is created in a file at the given path if it does not exist. To work with type checking, this class can be parameterized with a root - data type ``T``. For example:: + data type ``DataT``. For example:: from paramdb import ParamDataclass, ParamDB @@ -233,7 +233,7 @@ def path(self) -> str: return self._path def commit( - self, message: str, data: T, timestamp: datetime | None = None + self, message: str, data: DataT, timestamp: datetime | None = None ) -> CommitEntry: """ Commit the given data to the database with the given message and return a commit @@ -268,7 +268,7 @@ def num_commits(self) -> int: @overload def load( self, commit_id: int | None = None, *, load_classes: Literal[True] = True - ) -> T: ... + ) -> DataT: ... @overload def load( @@ -331,7 +331,7 @@ def commit_history_with_data( end: int | None = None, *, load_classes: Literal[True] = True, - ) -> list[CommitEntryWithData[T]]: ... + ) -> list[CommitEntryWithData[DataT]]: ... 
@overload def commit_history_with_data( diff --git a/paramdb/_param_data/_collections.py b/paramdb/_param_data/_collections.py index 8d8d733..5a3244a 100644 --- a/paramdb/_param_data/_collections.py +++ b/paramdb/_param_data/_collections.py @@ -10,16 +10,19 @@ MutableSequence, MutableMapping, ) +from copy import copy from typing_extensions import Self from paramdb._param_data._param_data import ParamData, _ParamWrapper -T = TypeVar("T") +ItemT = TypeVar("ItemT") _ChildNameT = TypeVar("_ChildNameT", str, int) -_CollectionT = TypeVar("_CollectionT", bound=Collection[Any]) +_CollectionT = TypeVar("_CollectionT", bound=Union[list[Any], dict[str, Any]]) # pylint: disable-next=abstract-method -class _ParamCollection(ParamData[_ChildNameT], Generic[_ChildNameT, _CollectionT]): +class _ParamCollection( + ParamData[_ChildNameT], Collection[Any], Generic[_ChildNameT, _CollectionT] +): """Base class for parameter collections.""" _contents: _CollectionT @@ -27,12 +30,10 @@ class _ParamCollection(ParamData[_ChildNameT], Generic[_ChildNameT, _CollectionT def __len__(self) -> int: return len(self._contents) - def __eq__(self, other: Any) -> bool: - # Equal if they have are of the same class and their contents are equal - return type(other) is type(self) and self._contents == other._contents - def __repr__(self) -> str: - return f"{type(self).__name__}({self._contents})" + # Show contents as self converted to an ordinary list or dict to hide internal + # _ParamWrapper objects + return f"{type(self).__name__}({type(self._contents)(self)})" def _to_json(self) -> _CollectionT: return self._contents @@ -42,7 +43,8 @@ def _get_wrapped_child(self, child_name: _ChildNameT) -> ParamData[Any]: # superclass ValueError from the _contents exception try: return cast( - ParamData[Any], self._contents[child_name] # type: ignore[index] + ParamData[Any], + self._contents[child_name], # type: ignore[index,call-overload] ) except (TypeError, IndexError, KeyError) as contents_exc: try: @@ -50,18 
+52,11 @@ def _get_wrapped_child(self, child_name: _ChildNameT) -> ParamData[Any]: except ValueError as super_exc: raise super_exc from contents_exc - @classmethod - def _from_json(cls, json_data: _CollectionT) -> Self: - # Set contents directly since __init__() will contain child wrapping logic - new_param_collection = cls() - new_param_collection._contents = json_data - return new_param_collection - class ParamList( - _ParamCollection[int, list[Union[T, _ParamWrapper[T]]]], - MutableSequence[T], - Generic[T], + _ParamCollection[int, list[Union[ItemT, _ParamWrapper[ItemT]]]], + MutableSequence[ItemT], + Generic[ItemT], ): """ Subclass of :py:class:`ParamData` and ``MutableSequence``. @@ -70,74 +65,73 @@ class ParamList( iterable (like builtin ``list``). """ - def __init__(self, iterable: Iterable[T] | None = None) -> None: + def __init__(self, iterable: Iterable[ItemT] | None = None) -> None: super().__init__() initial_contents = iterable or [] - self._contents = [self._wrap_child(item) for item in initial_contents] - for item in initial_contents: - self._add_child(item) + wrapped_initial_contents = [self._wrap_child(item) for item in initial_contents] + for wrapped_item in wrapped_initial_contents: + self._add_child(wrapped_item) + self._contents = wrapped_initial_contents + + def __eq__(self, other: Any) -> bool: + # Equal if the other object is also a ParamList and has the same contents + return isinstance(other, ParamList) and self._contents == other._contents @overload - def __getitem__(self, index: SupportsIndex) -> T: ... + def __getitem__(self, index: SupportsIndex) -> ItemT: ... @overload def __getitem__(self, index: slice) -> Self: ... 
- def __getitem__(self, index: SupportsIndex | slice) -> T | Self: + def __getitem__(self, index: SupportsIndex | slice) -> ItemT | Self: if isinstance(index, slice): - return type(self)( - [ - self._unwrap_child(wrapped_child) - for wrapped_child in self._contents[index] - ] - ) + # The slice has the same last updated time and item objects as the original + self_copy = copy(self) + self_copy._contents = self._contents[index] + return self_copy return self._unwrap_child(self._contents[index]) @overload - def __setitem__(self, index: SupportsIndex, value: T) -> None: ... + def __setitem__(self, index: SupportsIndex, value: ItemT) -> None: ... @overload - def __setitem__(self, index: slice, value: Iterable[T]) -> None: ... + def __setitem__(self, index: slice, value: Iterable[ItemT]) -> None: ... def __setitem__(self, index: SupportsIndex | slice, value: Any) -> None: if isinstance(index, slice): - old_values = self._contents[index] - self._contents[index] = [self._wrap_child(item) for item in value] - for old_item in old_values: - self._remove_child(old_item) - for item in value: - self._add_child(item) + old_wrapped_values = self._contents[index] + wrapped_values = [self._wrap_child(item) for item in value] + self._contents[index] = wrapped_values + for old_wrapped_item in old_wrapped_values: + self._remove_child(old_wrapped_item) + for wrapped_item in wrapped_values: + self._add_child(wrapped_item) else: - old_value = self._contents[index] - self._contents[index] = self._wrap_child(value) - self._remove_child(old_value) - self._add_child(value) + old_wrapped_value = self._contents[index] + wrapped_value = self._wrap_child(value) + self._contents[index] = wrapped_value + self._remove_child(old_wrapped_value) + self._add_child(wrapped_value) def __delitem__(self, index: SupportsIndex | slice) -> None: - old_value = self._contents[index] + old_wrapped_value = self._contents[index] del self._contents[index] - if isinstance(index, slice) and isinstance(old_value, 
list): - for old_item in old_value: - self._remove_child(old_item) + if isinstance(index, slice) and isinstance(old_wrapped_value, list): + for old_wrapped_item in old_wrapped_value: + self._remove_child(old_wrapped_item) else: - self._remove_child(old_value) - - def insert(self, index: SupportsIndex, value: T) -> None: - self._contents.insert(index, self._wrap_child(value)) - self._add_child(value) - - @classmethod - def _from_json(cls, json_data: list[T | _ParamWrapper[T]]) -> Self: + self._remove_child(old_wrapped_value) - new_obj = cls() - new_obj._contents = json_data - return new_obj + def insert(self, index: SupportsIndex, value: ItemT) -> None: + wrapped_value = self._wrap_child(value) + self._contents.insert(index, wrapped_value) + self._add_child(wrapped_value) class ParamDict( - _ParamCollection[str, dict[str, Union[T, _ParamWrapper[T]]]], - MutableMapping[str, T], - Generic[T], + _ParamCollection[str, dict[str, Union[ItemT, _ParamWrapper[ItemT]]]], + MutableMapping[str, ItemT], + Generic[ItemT], ): """ Subclass of :py:class:`ParamData` and ``MutableMapping``. @@ -149,14 +143,19 @@ class ParamDict( and items are returned as dict_keys, dict_values, and dict_items objects. 
""" - def __init__(self, mapping: Mapping[str, T] | None = None, /, **kwargs: T): + def __init__(self, mapping: Mapping[str, ItemT] | None = None, /, **kwargs: ItemT): super().__init__() initial_contents = {**(mapping or {}), **kwargs} - self._contents = { + wrapped_initial_contents = { key: self._wrap_child(value) for key, value in initial_contents.items() } - for value in initial_contents.values(): - self._add_child(value) + self._contents = wrapped_initial_contents + for wrapped_value in wrapped_initial_contents.values(): + self._add_child(wrapped_value) + + def __eq__(self, other: Any) -> bool: + # Equal if the other object is also a ParamDict and has the same contents + return isinstance(other, ParamDict) and self._contents == other._contents def __dir__(self) -> Iterable[str]: # Return keys that are not attribute names (i.e. do not pass self._is_attribute) @@ -166,24 +165,25 @@ def __dir__(self) -> Iterable[str]: *filter(lambda key: not self._is_attribute(key), self._contents.keys()), ] - def __getitem__(self, key: str) -> T: + def __getitem__(self, key: str) -> ItemT: return self._unwrap_child(self._contents[key]) - def __setitem__(self, key: str, value: T) -> None: - old_value = self._contents[key] if key in self._contents else None - self._contents[key] = self._wrap_child(value) - self._remove_child(old_value) - self._add_child(value) + def __setitem__(self, key: str, value: ItemT) -> None: + old_wrapped_value = self._contents[key] if key in self._contents else None + wrapped_value = self._wrap_child(value) + self._contents[key] = wrapped_value + self._remove_child(old_wrapped_value) + self._add_child(wrapped_value) def __delitem__(self, key: str) -> None: - old_value = self._contents[key] if key in self._contents else None + old_wrapped_value = self._contents[key] if key in self._contents else None del self._contents[key] - self._remove_child(old_value) + self._remove_child(old_wrapped_value) def __iter__(self) -> Iterator[str]: yield from self._contents 
- def __getattr__(self, name: str) -> T: + def __getattr__(self, name: str) -> ItemT: # Enable accessing items via dot notation if self._is_attribute(name): # It is important to raise an attribute error rather than a key error for @@ -194,7 +194,7 @@ def __getattr__(self, name: str) -> T: ) return self[name] - def __setattr__(self, name: str, value: T) -> None: + def __setattr__(self, name: str, value: ItemT) -> None: # Enable setting items via dot notation if self._is_attribute(name): super().__setattr__(name, value) diff --git a/paramdb/_param_data/_dataclasses.py b/paramdb/_param_data/_dataclasses.py index 33e96f3..6c61e0a 100644 --- a/paramdb/_param_data/_dataclasses.py +++ b/paramdb/_param_data/_dataclasses.py @@ -55,6 +55,7 @@ class CustomParam(ParamDataclass): "strict": True, "validate_default": True, } + _wrapped_children: dict[str, Any] | None = None # Used when initializing from json # pylint: disable-next=unused-argument def __base_setattr(self: Any, name: str, value: Any) -> None: @@ -119,11 +120,14 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Self: return self def __post_init__(self) -> None: - # Add and wrap fields as children - for field_name in self._field_names: - child = getattr(self, field_name) - super().__setattr__(field_name, self._wrap_child(child)) - self._add_child(child) + # Wrap fields as children and process them + for field in fields(self): # type: ignore[arg-type] + if self._wrapped_children is not None and field.init: + wrapped_child = self._wrapped_children[field.name] + else: + wrapped_child = self._wrap_child(super().__getattribute__(field.name)) + super().__setattr__(field.name, wrapped_child) + self._add_child(wrapped_child) def __getitem__(self, name: str) -> Any: # Enable getting attributes via square brackets @@ -147,19 +151,20 @@ def __getattribute__(self, name: str) -> Any: def __setattr__(self, name: str, value: Any) -> None: # If this attribute is a field, process the old and new child if name in self._field_names: - 
old_value = getattr(self, name) + old_wrapped_value = super().__getattribute__(name) self.__base_setattr(name, value) # May perform type validation - super().__setattr__(name, self._wrap_child(value)) - self._add_child(value) - self._remove_child(old_value) + wrapped_value = self._wrap_child(value) + super().__setattr__(name, wrapped_value) + self._remove_child(old_wrapped_value) + self._add_child(wrapped_value) else: self.__base_setattr(name, value) def __delattr__(self, name: str) -> None: - old_value = getattr(self, name) + old_wrapped_value = super().__getattribute__(name) super().__delattr__(name) if name in self._field_names: - self._remove_child(old_value) + self._remove_child(old_wrapped_value) def _get_wrapped_child(self, child_name: str) -> ParamData[Any]: if child_name in self._field_names: @@ -167,10 +172,18 @@ def _get_wrapped_child(self, child_name: str) -> ParamData[Any]: return super()._get_wrapped_child(child_name) def _to_json(self) -> dict[str, Any]: - if is_dataclass(self): - return {f.name: getattr(self, f.name) for f in fields(self) if f.init} - return {} - - @classmethod - def _from_json(cls, json_data: dict[str, Any]) -> Self: - return cls(**json_data) + return { + field.name: super(ParamData, self).__getattribute__(field.name) + for field in fields(self) # type: ignore[arg-type] + if field.init + } + + def _init_from_json(self, json_data: dict[str, Any]) -> None: + unwrapped_children = { + name: self._unwrap_child(wrapped_child) + for name, wrapped_child in json_data.items() + } + super().__setattr__("_wrapped_children", json_data) + # pylint: disable-next=unnecessary-dunder-call + self.__init__(**unwrapped_children) # type: ignore[misc] + super().__delattr__("_wrapped_children") diff --git a/paramdb/_param_data/_files.py b/paramdb/_param_data/_files.py index 5e2633f..2e06c18 100644 --- a/paramdb/_param_data/_files.py +++ b/paramdb/_param_data/_files.py @@ -13,10 +13,10 @@ except ImportError: PANDAS_INSTALLED = False -T = TypeVar("T") 
+DataT = TypeVar("DataT") -class ParamFile(ParamDataclass, Generic[T]): +class ParamFile(ParamDataclass, Generic[DataT]): """ Subclass of :py:class:`ParamDataclass`. @@ -35,28 +35,28 @@ def _load_data(self, path: str) -> str: path: str """Path to the file represented by this object.""" - initial_data: InitVar[T | None] = None + initial_data: InitVar[DataT | None] = None # pylint: disable-next=arguments-differ - def __post_init__(self, initial_data: T | None) -> None: + def __post_init__(self, initial_data: DataT | None) -> None: super().__post_init__() if initial_data is not None: self.update_data(initial_data) @abstractmethod - def _save_data(self, path: str, data: T) -> None: + def _save_data(self, path: str, data: DataT) -> None: """Save the given data in a file at the given path.""" @abstractmethod - def _load_data(self, path: str) -> T: + def _load_data(self, path: str) -> DataT: """Load data from the file at the given path.""" @property - def data(self) -> T: + def data(self) -> DataT: """Data stored in the file represented by this object.""" return self._load_data(self.path) - def update_data(self, data: T) -> None: + def update_data(self, data: DataT) -> None: """Update the data stored within the file represented by this object.""" self._save_data(self.path, data) self._update_last_updated() diff --git a/paramdb/_param_data/_param_data.py b/paramdb/_param_data/_param_data.py index bd58f12..6c82a18 100644 --- a/paramdb/_param_data/_param_data.py +++ b/paramdb/_param_data/_param_data.py @@ -8,7 +8,7 @@ from typing_extensions import Self, Never _T = TypeVar("_T") -_ChildNameT = TypeVar("_ChildNameT", bound=Union[str, int]) +ChildNameT = TypeVar("ChildNameT", bound=Union[str, int]) _LAST_UPDATED_KEY = "last_updated" """Dictionary key corresponding to a ``ParamData`` object's last updated time.""" @@ -24,11 +24,12 @@ def get_param_class(class_name: str) -> type[ParamData[Any]] | None: return _param_classes[class_name] if class_name in _param_classes else None 
-class ParamData(ABC, Generic[_ChildNameT]): +class ParamData(ABC, Generic[ChildNameT]): """Abstract base class for all parameter data.""" _parent: ParamData[Any] | None = None _last_updated: datetime + _last_updated_frozen: bool = False def __init_subclass__(cls, /, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) @@ -36,18 +37,40 @@ def __init_subclass__(cls, /, **kwargs: Any) -> None: _param_classes[cls.__name__] = cls def __init__(self) -> None: - super().__setattr__("_last_updated", datetime.now(timezone.utc).astimezone()) + if not self._last_updated_frozen: + super().__setattr__( + "_last_updated", datetime.now(timezone.utc).astimezone() + ) + + def _wrap_child(self, child: _T) -> _T | _ParamWrapper[_T]: + """ + If the given child is not ``ParamData``, it will be wrapped by + ``_ParamWrapper``; otherwise, the original object will be returned. + """ + return cast(_T, child) if isinstance(child, ParamData) else _ParamWrapper(child) + + def _unwrap_child(self, wrapped_child: _T | _ParamWrapper[_T]) -> _T: + """ + If the given child is wrapped by ``_ParamWrapper``, return the inner value. + Otherwise, return the child directly. + """ + if isinstance(wrapped_child, _ParamWrapper): + return wrapped_child.value + return wrapped_child def _add_child(self, child: Any) -> None: """ This method should be called to process a new child. If the child is ``ParamData``, its parent and last updated attributes will be - updated. + updated (unless ``self._last_updated_frozen`` is True). + + For primitive children, this method should be called on the wrapped child.
""" if isinstance(child, ParamData): super(ParamData, child).__setattr__("_parent", self) - child._update_last_updated() # pylint: disable=protected-access + if not self._last_updated_frozen: + child._update_last_updated() # pylint: disable=protected-access def _remove_child(self, child: Any) -> None: """ @@ -57,26 +80,15 @@ def _remove_child(self, child: Any) -> None: """ if isinstance(child, ParamData): super(ParamData, child).__setattr__("_parent", None) - self._update_last_updated() + if not self._last_updated_frozen: + self._update_last_updated() - def _wrap_child(self, child: _T) -> _T | _ParamWrapper[_T]: - """ - If the given child is not ``ParamData``, it will be wrapped by - ``_ParamWrapper``; otherwise, the original object will be returned. + def _update_last_updated(self) -> None: """ - return cast(_T, child) if isinstance(child, ParamData) else _ParamWrapper(child) + Update last updated for this object and its chain of parents. - def _unwrap_child(self, wrapped_child: _T | _ParamWrapper[_T]) -> _T: + If ``self._last_updated_frozen`` is True, then this function will do nothing. """ - If the given child is wrapped by ``_ParamWrapper``, return the inner value. - Otherwise, return the child directly. - """ - if isinstance(wrapped_child, _ParamWrapper): - return wrapped_child.value - return wrapped_child - - def _update_last_updated(self) -> None: - """Update last updated for this object and its chain of parents.""" # pylint: disable=protected-access,unused-private-member new_last_updated = datetime.now(timezone.utc).astimezone() current: ParamData[Any] | None = self @@ -102,16 +114,18 @@ def _to_json(self) -> Any: need to be JSON serializable, since they will be processed recursively.
""" - @classmethod - @abstractmethod - def _from_json(cls, json_data: Any) -> Self: + def _init_from_json(self, json_data: Any) -> None: """ - Construct a parameter data object from the given JSON data, usually created by - ``json.loads()`` and originally constructed by ``self._data_to_json()``. + Initialize a new parameter data object from the given JSON data, usually created + by ``json.loads()`` and originally constructed by ``self._data_to_json()``. By + default, this method will pass the JSON data to ``self.__init__()``. - The last updated timestamp is handled separately and does not need to be set - here. + The object will be generated by ``self.__new__()``, but ``self.__init__()`` has + not been called and ``self._last_updated_frozen`` is set to True. The last + updated timestamp is handled separately and does not need to be set here. """ + # pylint: disable-next=unnecessary-dunder-call + self.__init__(json_data) # type: ignore[misc] def to_dict(self) -> dict[str, Any]: """ @@ -126,8 +140,11 @@ def from_dict(cls, data_dict: dict[str, Any]) -> Self: Construct a parameter data object from the given dictionary, usually created by ``json.loads()`` and originally constructed by :py:meth:`from_dict`. """ - param_data = cls._from_json(data_dict[_DATA_KEY]) + param_data = cls.__new__(cls) + super().__setattr__(param_data, "_last_updated_frozen", True) + param_data._init_from_json(data_dict[_DATA_KEY]) super().__setattr__(param_data, "_last_updated", data_dict[_LAST_UPDATED_KEY]) + super().__setattr__(param_data, "_last_updated_frozen", False) return param_data @property @@ -135,7 +152,7 @@ def last_updated(self) -> datetime: """When any parameter within this parameter data was last updated.""" return self._last_updated - def _get_wrapped_child(self, child_name: _ChildNameT) -> ParamData[Any]: + def _get_wrapped_child(self, child_name: ChildNameT) -> ParamData[Any]: """ Get the wrapped child corresponding to the given name.
@@ -146,7 +163,7 @@ def _get_wrapped_child(self, child_name: _ChildNameT) -> ParamData[Any]: f"'{type(self).__name__}' parameter data object has no child {child_name!r}" ) - def child_last_updated(self, child_name: _ChildNameT) -> datetime: + def child_last_updated(self, child_name: ChildNameT) -> datetime: """Return the last updated time of the given child.""" return self._get_wrapped_child(child_name).last_updated @@ -192,12 +209,5 @@ def __init__(self, value: _T) -> None: def _to_json(self) -> _T: return self.value - @classmethod - def _from_json(cls, json_data: _T) -> Self: - return cls(json_data) - def __eq__(self, other: Any) -> bool: return isinstance(other, _ParamWrapper) and self.value == other.value - - def __repr__(self) -> str: - return repr(self.value) diff --git a/paramdb/_param_data/_type_mixins.py b/paramdb/_param_data/_type_mixins.py index 4ad41e6..a3a130f 100644 --- a/paramdb/_param_data/_type_mixins.py +++ b/paramdb/_param_data/_type_mixins.py @@ -4,13 +4,14 @@ from typing import TypeVar, Generic, Any, cast from paramdb._param_data._param_data import ParamData -T = TypeVar("T", bound=ParamData[Any]) +ParentT = TypeVar("ParentT", bound=ParamData[Any]) +RootT = TypeVar("RootT", bound=ParamData[Any]) -class ParentType(ParamData[Any], Generic[T]): +class ParentType(ParamData[Any], Generic[ParentT]): """ Mixin for :py:class:`ParamData` that sets the type hint for - :py:attr:`ParamData.parent` to type parameter ``T``. For example:: + :py:attr:`ParamData.parent` to type parameter ``ParentT``. For example:: class CustomParam(ParentType[ParentParam], Param): ... 
@@ -20,14 +21,14 @@ class CustomParam(ParentType[ParentParam], Param): """ @property - def parent(self) -> T: - return cast(T, super().parent) + def parent(self) -> ParentT: + return cast(ParentT, super().parent) -class RootType(ParamData[Any], Generic[T]): +class RootType(ParamData[Any], Generic[RootT]): """ Mixin for :py:class:`ParamData` that sets the type hint for - :py:attr:`ParamData.root` to type parameter ``T``. For example:: + :py:attr:`ParamData.root` to type parameter ``RootT``. For example:: class CustomParam(RootType[RootParam], Param): ... @@ -37,5 +38,5 @@ class CustomParam(RootType[RootParam], Param): """ @property - def root(self) -> T: - return cast(T, super().root) + def root(self) -> RootT: + return cast(RootT, super().root) diff --git a/tests/_param_data/test_collections.py b/tests/_param_data/test_collections.py index 62adb04..2ff5843 100644 --- a/tests/_param_data/test_collections.py +++ b/tests/_param_data/test_collections.py @@ -146,38 +146,44 @@ def test_param_collection_len_nonempty( def test_param_collection_eq( param_collection_type: type[ParamCollection], param_collection: ParamCollection, + custom_param_collection: CustomParamCollection, contents: Any, ) -> None: """ - Two parameter collections are equal if they have the same class and contents. + Parameter collections are equal to instances of the same root collection class + (ParamList or ParamDict instances) with the same contents. """ assert param_collection == param_collection_type(contents) + assert param_collection == custom_param_collection def test_param_collection_neq_contents( param_collection_type: type[ParamCollection], param_collection: ParamCollection, + custom_param_collection: CustomParamCollection, ) -> None: """ Two parameter collections are not equal if they have the same class but different contents. 
""" assert param_collection != param_collection_type() + assert param_collection != type(custom_param_collection)() def test_param_collection_neq_class( contents_type: type[Contents], param_collection: ParamCollection, custom_param_collection: CustomParamCollection, - contents: Contents, ) -> None: """ Two parameter collections are not equal if they have the same contents but different - classes. + root collection classes (ParamList or ParamDict). """ - assert contents_type(param_collection) == contents_type(custom_param_collection) - assert param_collection != custom_param_collection - assert param_collection != contents + assert param_collection != contents_type(param_collection) + assert custom_param_collection != contents_type(custom_param_collection) + if isinstance(param_collection, ParamList): + assert ParamList(["a", "b", "c"]) != ParamDict(a=1, b=2, c=3) + assert ParamList(["a", "b", "c"]) == ParamList(ParamDict(a=1, b=2, c=3)) def test_param_collection_repr( @@ -217,15 +223,20 @@ def test_param_list_get_slice( param_list: ParamList[Any], param_list_contents: list[Any] ) -> None: """Can get an item by slice from a parameter list.""" - assert isinstance(param_list[0:2], list) - assert param_list[0:2] == param_list_contents[0:2] + assert isinstance(param_list[0:2], ParamList) + assert list(param_list[0:2]) == param_list_contents[0:2] def test_param_list_get_slice_parent(param_list: ParamList[Any]) -> None: - """Items gotten from a parameter list via a slice have the correct parent.""" + """ + Slices of a parameter list have no parent, and the parent of their items is the + slice, not the original parameter list. 
+ """ sublist = param_list[2:4] - assert sublist[0].parent is param_list - assert sublist[1].parent is param_list + with pytest.raises(ValueError): + assert sublist.parent + assert sublist[0].parent is sublist + assert sublist[1].parent is sublist def test_param_list_set_index(param_list: ParamList[Any]) -> None: @@ -246,7 +257,7 @@ def test_param_list_set_index_last_updated(param_list: ParamList[Any]) -> None: def test_param_list_set_index_parent( - param_list: ParamList[Any], param_data: ParamData + param_list: ParamList[Any], param_data: ParamData[Any] ) -> None: """ A parameter data added to a parameter list via indexing has the correct parent. @@ -279,7 +290,7 @@ def test_param_list_set_slice_last_updated(param_list: ParamList[Any]) -> None: def test_param_list_set_slice_parent( - param_list: ParamList[Any], param_data: ParamData + param_list: ParamList[Any], param_data: ParamData[Any] ) -> None: """A parameter data added to a parameter list via slicing has the correct parent.""" for _ in range(2): # Run twice to check reassigning the same parameter data @@ -305,7 +316,7 @@ def test_param_list_insert_last_updated(param_list: ParamList[Any]) -> None: def test_param_list_insert_parent( - param_list: ParamList[Any], param_data: ParamData + param_list: ParamList[Any], param_data: ParamData[Any] ) -> None: """Parameter data added to a parameter list via insertion has the correct parent.""" param_list.insert(1, param_data) @@ -329,7 +340,7 @@ def test_param_list_del_last_updated(param_list: ParamList[Any]) -> None: def test_param_list_del_parent( - param_list: ParamList[Any], param_data: ParamData + param_list: ParamList[Any], param_data: ParamData[Any] ) -> None: """An item deleted from a parameter list has no parent.""" param_list.append(param_data) @@ -419,7 +430,7 @@ def test_param_dict_set_last_updated(param_dict: ParamDict[Any]) -> None: def test_param_dict_set_parent( - param_dict: ParamDict[Any], param_data: ParamData + param_dict: ParamDict[Any], 
param_data: ParamData[Any] ) -> None: """Parameter data added to a parameter dictionary has the correct parent.""" with pytest.raises(ValueError): @@ -454,7 +465,7 @@ def test_param_dict_del_last_updated(param_dict: ParamDict[Any]) -> None: def test_param_dict_del_parent( - param_dict: ParamDict[Any], param_data: ParamData + param_dict: ParamDict[Any], param_data: ParamData[Any] ) -> None: """An item deleted from a parameter dictionary has no parent.""" param_dict["param_data"] = param_data diff --git a/tests/_param_data/test_dataclasses.py b/tests/_param_data/test_dataclasses.py index 7f85055..039697c 100644 --- a/tests/_param_data/test_dataclasses.py +++ b/tests/_param_data/test_dataclasses.py @@ -1,6 +1,6 @@ """Tests for the paramdb._param_data._dataclasses module.""" -from typing import Union, cast +from typing import Union, Any, cast from copy import deepcopy import pydantic import pytest @@ -116,7 +116,7 @@ def test_param_dataclass_init_parent(complex_param: ComplexParam) -> None: def test_param_dataclass_set_parent( - complex_param: ComplexParam, param_data: ParamData + complex_param: ComplexParam, param_data: ParamData[Any] ) -> None: """Parameter data added to a structure has the correct parent.""" with pytest.raises(ValueError): diff --git a/tests/_param_data/test_param_data.py b/tests/_param_data/test_param_data.py index cd2c3df..7432568 100644 --- a/tests/_param_data/test_param_data.py +++ b/tests/_param_data/test_param_data.py @@ -1,5 +1,6 @@ """Tests for the paramdb._param_data._param_data module.""" +from typing import Any from dataclasses import is_dataclass from copy import deepcopy import pytest @@ -9,17 +10,19 @@ @pytest.fixture(name="param_data_type") -def fixture_param_data_type(param_data: ParamData) -> type[ParamData]: +def fixture_param_data_type(param_data: ParamData[Any]) -> type[ParamData[Any]]: """Parameter data type.""" return type(param_data) -def test_custom_subclass_extra_kwarg_fails(param_data_type: type[ParamData]) -> None: +def 
test_custom_subclass_extra_kwarg_fails( + param_data_type: type[ParamData[Any]], +) -> None: """Extra keyword arugments in a custom parameter data subclass raise a TypeError.""" with pytest.raises(TypeError) as exc_info: # pylint: disable-next=unused-variable class CustomParamData( - param_data_type, # type: ignore[valid-type, misc] + param_data_type, # type: ignore[valid-type,misc] extra_kwarg="test", # type: ignore[call-arg] ): """Custom parameter data class with an extra keyword arugment.""" @@ -34,21 +37,21 @@ class CustomParamData( assert "takes no keyword arguments" in error_message -def test_is_param_data(param_data: ParamData) -> None: +def test_is_param_data(param_data: ParamData[Any]) -> None: """Parameter data object is an instance of the `ParamData` class.""" assert isinstance(param_data, ParamData) -def test_get_param_class(param_data: ParamData) -> None: +def test_get_param_class(param_data: ParamData[Any]) -> None: """Parameter classes can be retrieved by name.""" param_class = type(param_data) assert get_param_class(param_class.__name__) is param_class -def test_param_data_initial_last_updated(param_data_type: type[ParamData]) -> None: +def test_param_data_initial_last_updated(param_data_type: type[ParamData[Any]]) -> None: """New parameter data objects are initialized with a last updated timestamp.""" with capture_start_end_times() as times: - new_param_data: ParamData + new_param_data: ParamData[Any] if issubclass(param_data_type, ParamDataFrame): new_param_data = param_data_type("") else: @@ -58,7 +61,7 @@ def test_param_data_initial_last_updated(param_data_type: type[ParamData]) -> No def test_param_data_updates_last_updated( - updated_param_data: ParamData, updated_times: Times + updated_param_data: ParamData[Any], updated_times: Times ) -> None: """Updating parameter data updates the last updated time.""" assert updated_param_data.last_updated is not None @@ -69,7 +72,7 @@ def test_param_data_updates_last_updated( ) -def 
test_child_does_not_change(param_data: ParamData) -> None: +def test_child_does_not_change(param_data: ParamData[Any]) -> None: """ Including a parameter data object as a child within a parent structure does not change the parameter in terms of equality comparison (i.e. public properties, @@ -81,7 +84,7 @@ def test_child_does_not_change(param_data: ParamData) -> None: assert param_data == param_data_original -def test_to_and_from_dict(param_data: ParamData) -> None: +def test_to_and_from_dict(param_data: ParamData[Any]) -> None: """Parameter data can be converted to and from a dictionary.""" param_data_dict = param_data.to_dict() assert isinstance(param_data_dict, dict) @@ -91,7 +94,7 @@ def test_to_and_from_dict(param_data: ParamData) -> None: assert param_data_from_dict.last_updated == param_data.last_updated -def test_no_parent_fails(param_data: ParamData) -> None: +def test_no_parent_fails(param_data: ParamData[Any]) -> None: """Fails to get the parent when there is no parent.""" with pytest.raises(ValueError) as exc_info: _ = param_data.parent @@ -102,12 +105,12 @@ def test_no_parent_fails(param_data: ParamData) -> None: ) -def test_self_is_root(param_data: ParamData) -> None: +def test_self_is_root(param_data: ParamData[Any]) -> None: """Parameter data object with no parent returns itself as the root.""" assert param_data.root is param_data -def test_parent_is_root(param_data: ParamData) -> None: +def test_parent_is_root(param_data: ParamData[Any]) -> None: """ Parameter data object with a parent that has no parent returns the parent as the root. @@ -116,7 +119,7 @@ def test_parent_is_root(param_data: ParamData) -> None: assert param_data.root is parent -def test_parent_of_parent_is_root(param_data: ParamData) -> None: +def test_parent_of_parent_is_root(param_data: ParamData[Any]) -> None: """ Parameter data object with a parent that has a parent returns the highest level parent as the root. 
diff --git a/tests/_param_data/test_primitives.py b/tests/_param_data/test_primitives.py deleted file mode 100644 index be2ce5c..0000000 --- a/tests/_param_data/test_primitives.py +++ /dev/null @@ -1,240 +0,0 @@ -"""Tests for the paramdb._param_data._primitives module.""" - -from typing import Union, cast -from copy import deepcopy -import math -import pytest -from paramdb import ParamInt, ParamFloat, ParamBool, ParamStr, ParamNone -from tests.helpers import ( - SimpleParam, - CustomParamInt, - CustomParamFloat, - CustomParamBool, - CustomParamStr, - CustomParamNone, -) - -ParamPrimitive = Union[ParamInt, ParamFloat, ParamBool, ParamStr, ParamNone] -CustomParamPrimitive = Union[ - CustomParamInt, CustomParamFloat, CustomParamBool, CustomParamStr, CustomParamNone -] - - -@pytest.fixture( - name="param_primitive", - params=["param_int", "param_float", "param_bool", "param_str", "param_none"], -) -def fixture_param_primitive(request: pytest.FixtureRequest) -> ParamPrimitive: - """Parameter primitive.""" - return cast(ParamPrimitive, deepcopy(request.getfixturevalue(request.param))) - - -@pytest.fixture(name="custom_param_primitive") -def fixture_custom_param_primitive( - param_primitive: ParamPrimitive, -) -> CustomParamPrimitive: - """Custom parameter primitive.""" - if isinstance(param_primitive, ParamInt): - return CustomParamInt(param_primitive.value) - if isinstance(param_primitive, ParamFloat): - return CustomParamFloat(param_primitive.value) - if isinstance(param_primitive, ParamBool): - return CustomParamBool(param_primitive.value) - if isinstance(param_primitive, ParamStr): - return CustomParamStr(param_primitive.value) - return CustomParamNone() - - -def test_param_int_isinstance(param_int: ParamInt) -> None: - """Parameter integers are instances of ``int``.""" - assert isinstance(param_int, int) - assert isinstance(CustomParamInt(param_int.value), int) - - -def test_param_float_isinstance(param_float: ParamFloat) -> None: - """Parameter floats are instances 
of ``float``.""" - assert isinstance(param_float, float) - assert isinstance(CustomParamFloat(param_float.value), float) - - -def test_param_bool_isinstance(param_bool: ParamBool) -> None: - """Parameter booleans are instances of ``int``.""" - assert isinstance(param_bool, int) - assert isinstance(CustomParamBool(param_bool.value), int) - - -def test_param_str_isinstance(param_str: ParamStr) -> None: - """Parameter strings are instances of ``str``.""" - assert isinstance(param_str, str) - assert isinstance(CustomParamStr(param_str.value), str) - - -def test_param_int_constructor() -> None: - """The ``ParamInt`` constructor behaves like the ``int`` constructor.""" - assert ParamInt() == 0 - assert ParamInt(123) == 123 - assert ParamInt(123.0) == 123.0 - assert ParamInt(123.1) == 123 - assert ParamInt(123.9) == 123 - assert ParamInt("123") == 123 - assert ParamInt("0x42", base=16) == 66 - with pytest.raises(ValueError): - ParamInt("hello") - - -def test_param_float_constructor() -> None: - """The ``ParamFloat`` constructor behaves like the ``float`` constructor.""" - assert ParamFloat() == 0.0 - assert ParamFloat(123) == 123.0 - assert ParamFloat(1.23) == 1.23 - assert ParamFloat("1.23") == 1.23 - assert ParamFloat("inf") == float("inf") - assert math.isnan(ParamFloat("nan")) - with pytest.raises(ValueError): - ParamFloat("hello") - - -def test_param_bool_constructor() -> None: - """The ``ParamBool`` constructor behaves like the ``bool`` constructor.""" - assert ParamBool().value is False - assert ParamBool(123).value is True - assert ParamBool("").value is False - assert ParamBool("hello").value is True - - -def test_param_str_constructor() -> None: - """The ``ParamStr`` constructor behaves like the ``str`` constructor.""" - assert ParamStr() == "" - assert ParamStr("hello") == "hello" - assert ParamStr(123) == "123" - assert ParamStr(b"hello", encoding="utf-8") == "hello" - - -def test_param_int_value() -> None: - """Can access the primitive value of a parameter 
integer.""" - param_int_value = ParamInt(123).value - assert type(param_int_value) is int # pylint: disable=unidiomatic-typecheck - assert param_int_value == 123 - - -def test_param_float_value(number: float) -> None: - """Can access the primitive value of a parameter float.""" - param_float_value = ParamFloat(number).value - assert type(param_float_value) is float # pylint: disable=unidiomatic-typecheck - assert param_float_value == number - - -def test_param_bool_value() -> None: - """Can access the primitive value of a parameter boolean.""" - assert ParamBool(True).value is True - assert ParamBool(False).value is False - assert ParamBool(0).value is False - assert ParamBool(123).value is True - - -def test_param_str_value(string: str) -> None: - """Can access the primitive value of a parameter string.""" - param_str_value = ParamStr(string).value - assert type(param_str_value) is str # pylint: disable=unidiomatic-typecheck - assert param_str_value == string - - -def test_param_none_value() -> None: - """Can access the primitive value of a parameter ``None``.""" - assert ParamNone().value is None - - -def test_param_primitive_repr( - param_primitive: ParamPrimitive, custom_param_primitive: CustomParamPrimitive -) -> None: - """Can represent a parameter primitive as a string using ``repr()``.""" - if isinstance(param_primitive, ParamNone): - assert repr(param_primitive) == f"{ParamNone.__name__}()" - assert repr(custom_param_primitive) == f"{CustomParamNone.__name__}()" - else: - assert ( - repr(param_primitive) - == f"{type(param_primitive).__name__}({param_primitive.value!r})" - ) - assert ( - repr(custom_param_primitive) == f"{type(custom_param_primitive).__name__}" - f"({custom_param_primitive.value!r})" - ) - - -def test_param_primitive_bool( - param_primitive: ParamPrimitive, custom_param_primitive: CustomParamPrimitive -) -> None: - """Parameter primitive objects have the correct truth values.""" - assert bool(param_primitive) is bool(param_primitive.value) 
- assert bool(custom_param_primitive) is bool(custom_param_primitive.value) - - -def test_param_primitive_eq( - param_primitive: ParamPrimitive, custom_param_primitive: CustomParamPrimitive -) -> None: - """ - Parameter primitive objects are equal to themselves, their vaues, and custom - parameter primitive objects. - """ - # pylint: disable=comparison-with-itself - assert param_primitive == param_primitive - assert param_primitive == deepcopy(param_primitive) - assert param_primitive == custom_param_primitive - assert param_primitive == param_primitive.value - - -def test_param_primitive_ne( - simple_param: SimpleParam, - param_primitive: ParamPrimitive, - custom_param_primitive: CustomParamPrimitive, -) -> None: - """ - Parameter primitive objects are not equal to other objects or parameter primitives - with different values. - """ - assert param_primitive != simple_param - assert custom_param_primitive != simple_param - if not isinstance(param_primitive, ParamNone): - assert param_primitive != type(param_primitive)() - assert custom_param_primitive != type(custom_param_primitive)() - - -def test_param_primitive_hash( - param_primitive: ParamPrimitive, custom_param_primitive: CustomParamPrimitive -) -> None: - """Parameter primitive objects has the same hash as objects they are equal to.""" - assert hash(param_primitive) == hash(deepcopy(param_primitive)) - assert hash(param_primitive) == hash(custom_param_primitive) - assert hash(param_primitive) == hash(param_primitive.value) - - -def test_param_int_methods_return_int(param_int: ParamInt) -> None: - """``ParamInt`` methods inherited from ``int`` return ``int`` objects.""" - # pylint: disable=unidiomatic-typecheck - assert type(param_int + 123) is int - assert type(-param_int) is int - assert type(param_int.real) is int - - -def test_param_float_methods_return_float(param_float: ParamFloat) -> None: - """``ParamFloat`` methods inherited from ``float`` return ``float`` objects.""" - # pylint: 
disable=unidiomatic-typecheck - assert type(param_float + 123) is float - assert type(-param_float) is float - assert type(param_float.real) is float - - -def test_param_bool_methods_return_int() -> None: - """``ParamBool`` methods inherited from ``int`` return ``int`` objects.""" - # pylint: disable=unidiomatic-typecheck - assert type(ParamBool(True) ^ ParamBool(False)) is int - assert type(ParamBool(True) | ParamBool(False)) is int - - -def test_param_str_methods_return_str(param_str: ParamStr) -> None: - """``ParamStr`` methods inherited from ``str`` return ``str`` objects.""" - # pylint: disable=unidiomatic-typecheck - assert type(param_str + "") is str - assert type(param_str[::-1]) is str - assert type(param_str.capitalize()) is str diff --git a/tests/conftest.py b/tests/conftest.py index 3000654..397ef01 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,19 +1,10 @@ """Defines global fixtures. Called automatically by Pytest before running tests.""" +from __future__ import annotations from typing import Any from copy import deepcopy import pytest -from paramdb import ( - ParamData, - ParamInt, - ParamFloat, - ParamBool, - ParamStr, - ParamNone, - ParamDataFrame, - ParamList, - ParamDict, -) +from paramdb import ParamData, ParamDataFrame, ParamList, ParamDict from tests.helpers import ( DEFAULT_NUMBER, DEFAULT_STRING, @@ -42,36 +33,6 @@ def fixture_string() -> str: return DEFAULT_STRING -@pytest.fixture(name="param_int") -def fixture_param_int() -> ParamInt: - """Parameter integer object.""" - return ParamInt(123) - - -@pytest.fixture(name="param_float") -def fixture_param_float(number: float) -> ParamFloat: - """Parameter float object.""" - return ParamFloat(number) - - -@pytest.fixture(name="param_bool") -def fixture_param_bool() -> ParamBool: - """Parameter boolean object.""" - return ParamBool(True) - - -@pytest.fixture(name="param_str") -def fixture_param_str(string: str) -> ParamStr: - """Parameter string object.""" - return ParamStr(string) - 
- -@pytest.fixture(name="param_none") -def fixture_param_none() -> ParamNone: - """Parameter ``None`` object.""" - return ParamNone() - - @pytest.fixture(name="param_data_frame") def fixture_param_data_frame(string: str) -> ParamDataFrame: """Parameter DataFrame.""" @@ -156,11 +117,6 @@ def fixture_param_list_contents(number: float, string: str) -> list[Any]: return [ number, string, - ParamInt(), - ParamFloat(number), - ParamBool(), - ParamStr(string), - ParamNone(), ParamDataFrame(string), EmptyParam(), SimpleParam(), @@ -180,11 +136,6 @@ def fixture_param_list_contents(number: float, string: str) -> list[Any]: def fixture_param_dict_contents( number: float, string: str, - param_int: ParamInt, - param_float: ParamFloat, - param_bool: ParamBool, - param_str: ParamStr, - param_none: ParamNone, param_data_frame: ParamDataFrame, empty_param: EmptyParam, simple_param: SimpleParam, @@ -199,11 +150,6 @@ def fixture_param_dict_contents( return { "number": number, "string": string, - "param_int": param_int, - "param_float": deepcopy(param_float), - "param_bool": deepcopy(param_bool), - "param_str": deepcopy(param_str), - "param_none": deepcopy(param_none), "param_data_frame": deepcopy(param_data_frame), "empty_param": deepcopy(empty_param), "simple_param": deepcopy(simple_param), @@ -245,11 +191,6 @@ def fixture_param_dict(param_dict_contents: dict[str, Any]) -> ParamDict[Any]: @pytest.fixture( name="param_data", params=[ - "param_int", - "param_float", - "param_bool", - "param_str", - "param_none", "param_data_frame", "empty_param", "simple_param", @@ -263,28 +204,42 @@ def fixture_param_dict(param_dict_contents: dict[str, Any]) -> ParamDict[Any]: "param_dict", ], ) -def fixture_param_data(request: pytest.FixtureRequest) -> ParamData: +def fixture_param_data(request: pytest.FixtureRequest) -> ParamData[Any]: """Parameter data.""" - param_data: ParamData = deepcopy(request.getfixturevalue(request.param)) + param_data: ParamData[Any] = 
deepcopy(request.getfixturevalue(request.param)) return param_data +@pytest.fixture(name="param_data_child_name") +# pylint: disable-next=too-many-return-statements +def fixture_param_data_child_name(param_data: ParamData[Any]) -> str | int | None: + """Name of a child in the parameter data.""" + if isinstance(param_data, ParamDataFrame): + return "path" + if isinstance(param_data, SimpleParam): + return "number" + if isinstance(param_data, SubclassParam): + return "second_number" + if isinstance(param_data, ComplexParam): + return "simple_param" + if isinstance(param_data, ParamList): + return None if len(param_data) == 0 else 4 + if isinstance(param_data, ParamDict): + return None if len(param_data) == 0 else "simple_param" + return None + + @pytest.fixture(name="updated_param_data_and_times") def fixture_updated_param_data_and_times( - param_data: ParamData, number: float -) -> tuple[ParamData, Times]: + param_data: ParamData[Any], number: float +) -> tuple[ParamData[Any], Times]: """ Parameter data that has been updated between the returned Times. Broken down into individual fixtures for parameter data and times below. 
""" updated_param_data = deepcopy(param_data) with capture_start_end_times() as times: - if isinstance( - updated_param_data, - (ParamInt, ParamFloat, ParamBool, ParamStr), - ): - updated_param_data = type(updated_param_data)(updated_param_data.value) - elif isinstance(updated_param_data, (ParamNone, EmptyParam)): + if isinstance(updated_param_data, EmptyParam): updated_param_data = type(updated_param_data)() elif isinstance(updated_param_data, ParamDataFrame): updated_param_data.path = "" @@ -299,7 +254,7 @@ def fixture_updated_param_data_and_times( if len(updated_param_data) == 0: updated_param_data.append(number) else: - updated_param_data[9].number += 1 + updated_param_data[4].number += 1 elif isinstance(updated_param_data, ParamDict): if len(updated_param_data) == 0: updated_param_data["number"] = number @@ -310,15 +265,15 @@ def fixture_updated_param_data_and_times( @pytest.fixture(name="updated_param_data") def fixture_updated_param_data( - updated_param_data_and_times: tuple[ParamData, Times] -) -> ParamData: + updated_param_data_and_times: tuple[ParamData[Any], Times] +) -> ParamData[Any]: """Parameter data that has been updated.""" return updated_param_data_and_times[0] @pytest.fixture(name="updated_times") def fixture_updated_times( - updated_param_data_and_times: tuple[ParamData, Times] + updated_param_data_and_times: tuple[ParamData[Any], Times] ) -> Times: """Times before and after param_data fixture was updated.""" return updated_param_data_and_times[1] diff --git a/tests/helpers.py b/tests/helpers.py index 217649e..1a5a1dc 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -14,17 +14,13 @@ from astropy.units import Quantity # type: ignore[import-untyped] from paramdb import ( ParamData, - ParamInt, - ParamBool, - ParamFloat, - ParamStr, - ParamNone, ParamDataclass, ParamFile, ParamDataFrame, ParamList, ParamDict, ) +from paramdb._param_data._param_data import _ParamWrapper DEFAULT_NUMBER = 1.23 DEFAULT_STRING = "test" @@ -53,11 +49,6 @@ class 
SimpleParam(ParamDataclass): number_init_false: float = field(init=False, default=DEFAULT_NUMBER) number_with_units: Quantity = Quantity(12, "m") string: str = DEFAULT_STRING - param_int: ParamInt = ParamInt(123) - param_float: ParamFloat = ParamFloat(DEFAULT_NUMBER) - param_bool: ParamBool = ParamBool(False) - param_str: ParamStr = ParamStr(DEFAULT_STRING) - param_none: ParamNone = ParamNone() class NoTypeValidationParam(SimpleParam, type_validation=False): @@ -94,11 +85,6 @@ class ComplexParam(ParamDataclass): string: str = DEFAULT_STRING list: list[Any] = field(default_factory=list) dict: dict[str, Any] = field(default_factory=dict) - param_int: ParamInt = ParamInt(123) - param_float: ParamFloat = ParamFloat(DEFAULT_NUMBER) - param_bool: ParamBool = ParamBool(False) - param_str: ParamStr = ParamStr(DEFAULT_STRING) - param_none: ParamNone = ParamNone() param_data_frame: ParamDataFrame | None = None empty_param: EmptyParam | None = None simple_param: SimpleParam | None = None @@ -110,7 +96,7 @@ class ComplexParam(ParamDataclass): complex_param: ComplexParam | None = None param_list: ParamList[Any] = field(default_factory=ParamList) param_dict: ParamDict[Any] = field(default_factory=ParamDict) - param_data: ParamData | None = None + param_data: ParamData[Any] | None = None class CustomParamList(ParamList[Any]): @@ -121,24 +107,29 @@ class CustomParamDict(ParamDict[Any]): """Custom parameter dictionary subclass.""" -class CustomParamInt(ParamInt): - """Custom parameter integer subclass.""" - - -class CustomParamFloat(ParamFloat): - """Custom parameter float subclass.""" - - -class CustomParamBool(ParamBool): - """Custom parameter boolean subclass.""" - - -class CustomParamStr(ParamStr): - """Custom parameter string subclass.""" - - -class CustomParamNone(ParamNone): - """Custom parameter ``None`` subclass.""" +def assert_param_data_strong_equals( + param_data: ParamData[Any], other_param_data: ParamData[Any], child_name: str +) -> None: + """ + Assert that the given 
parameter data is equal to the other parameter data based on + equality as well as stronger tests, such as last updated times and children. + """ + # pylint: disable=protected-access + assert param_data == other_param_data + assert param_data.last_updated == other_param_data.last_updated + assert param_data.to_dict() == other_param_data.to_dict() + if child_name is not None: + assert param_data.child_last_updated( + child_name + ) == other_param_data.child_last_updated(child_name) + child = param_data._get_wrapped_child(child_name) + other_child = other_param_data._get_wrapped_child(child_name) + if isinstance(other_child, _ParamWrapper): + assert isinstance(child, _ParamWrapper) + assert child.value == other_child.value + else: + assert child == other_child + assert child.parent == other_child.parent @dataclass diff --git a/tests/test_database.py b/tests/test_database.py index f5a4490..6b23b81 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -15,23 +15,14 @@ SimpleParam, SubclassParam, ComplexParam, - CustomParamInt, - CustomParamFloat, - CustomParamBool, - CustomParamStr, - CustomParamNone, CustomParamList, CustomParamDict, Times, + assert_param_data_strong_equals, capture_start_end_times, ) from paramdb import ( ParamData, - ParamInt, - ParamFloat, - ParamBool, - ParamStr, - ParamNone, ParamDataclass, ParamDataFrame, ParamList, @@ -139,9 +130,11 @@ def test_load_nonexistent_commit_fails(db_path: str) -> None: ) -def test_commit_and_load(db_path: str, param_data: ParamData) -> None: +def test_commit_and_load( + db_path: str, param_data: ParamData[Any], param_data_child_name: str +) -> None: """Can commit and load parameter data and commit entries.""" - param_db = ParamDB[ParamData](db_path) + param_db = ParamDB[ParamData[Any]](db_path) with capture_start_end_times() as times: commit_entry = param_db.commit("Initial commit", param_data) @@ -154,30 +147,31 @@ def test_commit_and_load(db_path: str, param_data: ParamData) -> None: with 
capture_start_end_times(): param_data_latest = param_db.load() commit_entry_latest = param_db.load_commit_entry() - assert param_data_latest == param_data - assert param_data_latest.last_updated == param_data.last_updated + assert_param_data_strong_equals( + param_data_latest, param_data, param_data_child_name + ) assert commit_entry_latest == commit_entry # Can load by commit ID with capture_start_end_times(): param_data_first = param_db.load(commit_entry.id) commit_entry_first = param_db.load_commit_entry(commit_entry.id) - assert param_data_first == param_data - assert param_data_first.last_updated == param_data.last_updated + assert_param_data_strong_equals(param_data_first, param_data, param_data_child_name) assert commit_entry_first == commit_entry # Can load from history with capture_start_end_times(): param_data_from_history = param_db.commit_history_with_data()[0].data commit_entry_from_history = param_db.commit_history()[0] - assert param_data_from_history == param_data - assert param_data_from_history.last_updated == param_data.last_updated + assert_param_data_strong_equals( + param_data_from_history, param_data, param_data_child_name + ) assert commit_entry_from_history == commit_entry def test_commit_and_load_timestamp(db_path: str, simple_param: SimpleParam) -> None: """Can make a commit using a specific timestamp and load it back.""" - param_db = ParamDB[ParamData](db_path) + param_db = ParamDB[ParamData[Any]](db_path) utc_timestamp = datetime.now(timezone.utc) naive_timestamp = utc_timestamp.replace(tzinfo=None) aware_timestamp = utc_timestamp.astimezone() @@ -199,9 +193,9 @@ def test_commit_and_load_timestamp(db_path: str, simple_param: SimpleParam) -> N assert commit_entry_with_data.timestamp == aware_timestamp -def test_load_classes_false(db_path: str, param_data: ParamData) -> None: +def test_load_classes_false(db_path: str, param_data: ParamData[Any]) -> None: """Can load data as dictionaries if ``load_classes`` is false.""" - param_db = 
ParamDB[ParamData](db_path) + param_db = ParamDB[ParamData[Any]](db_path) param_db.commit("Initial commit", param_data) data_loaded = param_db.load(load_classes=False) data_from_history = param_db.commit_history_with_data(load_classes=False)[0].data @@ -259,11 +253,6 @@ def test_commit_and_load_complex( string: str, param_list_contents: list[Any], param_dict_contents: dict[str, Any], - param_int: ParamInt, - param_float: ParamFloat, - param_bool: ParamBool, - param_str: ParamStr, - param_none: ParamNone, param_data_frame: ParamDataFrame, empty_param: EmptyParam, simple_param: SimpleParam, @@ -281,11 +270,6 @@ class Root(ParamDataclass): string: str list: list[Any] dict: dict[str, Any] - param_int: ParamInt - param_float: ParamFloat - param_bool: ParamBool - param_str: ParamStr - param_none: ParamNone param_data_frame: ParamDataFrame empty_param: EmptyParam simple_param: SimpleParam @@ -293,11 +277,6 @@ class Root(ParamDataclass): complex_param: ComplexParam param_list: ParamList[Any] param_dict: ParamDict[Any] - custom_param_int: CustomParamInt - custom_param_float: CustomParamFloat - custom_param_bool: CustomParamBool - custom_param_str: CustomParamStr - custom_param_none: CustomParamNone custom_param_list: CustomParamList custom_param_dict: CustomParamDict @@ -306,11 +285,6 @@ class Root(ParamDataclass): string=string, list=param_list_contents, dict=param_dict_contents, - param_int=param_int, - param_float=param_float, - param_bool=param_bool, - param_str=param_str, - param_none=param_none, param_data_frame=param_data_frame, empty_param=empty_param, simple_param=simple_param, @@ -318,22 +292,15 @@ class Root(ParamDataclass): complex_param=complex_param, param_list=param_list, param_dict=param_dict, - custom_param_int=CustomParamInt(param_int.value), - custom_param_float=CustomParamFloat(param_float.value), - custom_param_bool=CustomParamBool(param_bool.value), - custom_param_str=CustomParamStr(param_str.value), - custom_param_none=CustomParamNone(), 
custom_param_list=CustomParamList(deepcopy(param_list_contents)), custom_param_dict=CustomParamDict(deepcopy(param_dict_contents)), ) param_db = ParamDB[Root](db_path) param_db.commit("Initial commit", root) root_loaded = param_db.load() - assert root_loaded == root - assert root_loaded.last_updated == root.last_updated + assert_param_data_strong_equals(root_loaded, root, "number") root_from_history = param_db.commit_history_with_data()[0].data - assert root_from_history == root - assert root_from_history.last_updated == root.last_updated + assert_param_data_strong_equals(root_from_history, root, "number") def test_commit_load_latest(db_path: str) -> None: @@ -379,11 +346,9 @@ def test_commit_load_multiple(db_path: str) -> None: # Verify data param_loaded = param_db.load(commit_entry.id) - assert param_loaded == param - assert param_loaded.last_updated == param.last_updated + assert_param_data_strong_equals(param_loaded, param, "number") param_from_history = commit_entry_with_data_from_history.data - assert param_from_history == param - assert param_from_history.last_updated == param.last_updated + assert_param_data_strong_equals(param_from_history, param, "number") def test_separate_connections(db_path: str, simple_param: SimpleParam) -> None: @@ -399,8 +364,8 @@ def test_separate_connections(db_path: str, simple_param: SimpleParam) -> None: # Load back using another connection param_db2 = ParamDB[SimpleParam](db_path) param_loaded = param_db2.load() - assert simple_param == param_loaded - assert simple_param.last_updated == param_loaded.last_updated + + assert_param_data_strong_equals(param_loaded, simple_param, "number") def test_empty_num_commits(db_path: str) -> None: From 25efa37de2e128c7d85d2d9204b5b2c2a7ea4b93 Mon Sep 17 00:00:00 2001 From: Alex Hadley Date: Wed, 12 Jun 2024 13:15:17 -0700 Subject: [PATCH 03/12] #187 Update parameter collection tests --- tests/_param_data/test_collections.py | 47 +++++---------------------- 1 file changed, 9 insertions(+), 
38 deletions(-) diff --git a/tests/_param_data/test_collections.py b/tests/_param_data/test_collections.py index 2ff5843..2d6e287 100644 --- a/tests/_param_data/test_collections.py +++ b/tests/_param_data/test_collections.py @@ -4,6 +4,7 @@ from copy import deepcopy import pytest from tests.helpers import ( + ComplexParam, CustomParamList, CustomParamDict, capture_start_end_times, @@ -229,14 +230,14 @@ def test_param_list_get_slice( def test_param_list_get_slice_parent(param_list: ParamList[Any]) -> None: """ - Slices of a parameter list have no parent, and the parent of their items is the - slice, not the original parameter list. + Slices of a parameter list have the same parent as the original parameter list, and + the parent of their items is the original parameter list. """ + parent = ComplexParam(param_list=param_list) sublist = param_list[2:4] - with pytest.raises(ValueError): - assert sublist.parent - assert sublist[0].parent is sublist - assert sublist[1].parent is sublist + assert sublist.parent is parent + assert sublist[0].parent is param_list + assert sublist[0].parent is param_list def test_param_list_set_index(param_list: ParamList[Any]) -> None: @@ -275,9 +276,9 @@ def test_param_list_set_index_parent( def test_param_list_set_slice(param_list: ParamList[Any]) -> None: """Can set items by slice in a parameter list.""" new_numbers = [4.56, 7.89] - assert param_list[0:2] != new_numbers + assert list(param_list[0:2]) != new_numbers param_list[0:2] = new_numbers - assert param_list[0:2] == new_numbers + assert list(param_list[0:2]) == new_numbers def test_param_list_set_slice_last_updated(param_list: ParamList[Any]) -> None: @@ -493,33 +494,3 @@ def test_param_dict_iter( """A parameter dictionary correctly supports iteration.""" for key, contents_key in zip(param_dict, param_dict_contents): assert key == contents_key - - -def test_param_dict_keys( - param_dict: ParamDict[Any], param_dict_contents: dict[str, Any] -) -> None: - """A parameter dictionary 
outputs keys as a dict_keys object.""" - param_dict_keys = param_dict.keys() - contents_keys = param_dict_contents.keys() - assert isinstance(param_dict_keys, type(contents_keys)) - assert param_dict_keys == param_dict_keys - - -def test_param_dict_values( - param_dict: ParamDict[Any], param_dict_contents: dict[str, Any] -) -> None: - """A parameter dictionary outputs values as a dict_values object.""" - param_dict_values = param_dict.values() - contents_values = param_dict_contents.values() - assert isinstance(param_dict_values, type(contents_values)) - assert list(param_dict_values) == list(contents_values) - - -def test_param_dict_items( - param_dict: ParamDict[Any], param_dict_contents: dict[str, Any] -) -> None: - """A parameter dictionary outputs items as a dict_items object.""" - param_dict_items = param_dict.items() - contents_items = param_dict_contents.items() - assert isinstance(param_dict_items, type(contents_items)) - assert param_dict_items == contents_items From fc0a85ed7eef9ab718a6cdcf9302e6d24be2d4a8 Mon Sep 17 00:00:00 2001 From: Alex Hadley Date: Wed, 12 Jun 2024 14:07:04 -0700 Subject: [PATCH 04/12] #187 Update parameter collection tests --- tests/_param_data/test_collections.py | 16 +++++++++++++++- tests/helpers.py | 2 +- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/_param_data/test_collections.py b/tests/_param_data/test_collections.py index 2d6e287..9f77b64 100644 --- a/tests/_param_data/test_collections.py +++ b/tests/_param_data/test_collections.py @@ -7,6 +7,7 @@ ComplexParam, CustomParamList, CustomParamDict, + assert_param_data_strong_equals, capture_start_end_times, ) from paramdb import ParamData, ParamList, ParamDict @@ -223,9 +224,13 @@ def test_param_list_get_index_parent(param_list: ParamList[Any]) -> None: def test_param_list_get_slice( param_list: ParamList[Any], param_list_contents: list[Any] ) -> None: - """Can get an item by slice from a parameter list.""" + """ + Can get an item by slice from a 
parameter list. Also, a slice of the entire list is + strongly equal to the original list. + """ assert isinstance(param_list[0:2], ParamList) assert list(param_list[0:2]) == param_list_contents[0:2] + assert_param_data_strong_equals(param_list[:], param_list, child_name=1) def test_param_list_get_slice_parent(param_list: ParamList[Any]) -> None: @@ -240,6 +245,15 @@ def test_param_list_get_slice_parent(param_list: ParamList[Any]) -> None: assert sublist[0].parent is param_list +def test_param_list_get_slice_references(param_list: ParamList[Any]) -> None: + """ + Children of a parameter list slice are references to items in the original list. + """ + sublist = param_list[2:] + assert sublist[1] is param_list[3] + assert sublist[2] is param_list[4] + + def test_param_list_set_index(param_list: ParamList[Any]) -> None: """Can set an item by index in a parameter list.""" new_number = 4.56 diff --git a/tests/helpers.py b/tests/helpers.py index 1a5a1dc..780981e 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -108,7 +108,7 @@ class CustomParamDict(ParamDict[Any]): def assert_param_data_strong_equals( - param_data: ParamData[Any], other_param_data: ParamData[Any], child_name: str + param_data: ParamData[Any], other_param_data: ParamData[Any], child_name: str | int ) -> None: """ Assert that the given parameter data is equal to the other parameter data based on From 5cd8a1a74825cbdf3994ab7d8fc3e8f1388427c4 Mon Sep 17 00:00:00 2001 From: Alex Hadley Date: Wed, 12 Jun 2024 22:56:37 -0700 Subject: [PATCH 05/12] #187 Improve ParamDict dot notation access --- paramdb/_param_data/_collections.py | 47 +++++++++++++++++---------- paramdb/_param_data/_dataclasses.py | 5 ++- tests/_param_data/test_collections.py | 30 ++++++++--------- 3 files changed, 48 insertions(+), 34 deletions(-) diff --git a/paramdb/_param_data/_collections.py b/paramdb/_param_data/_collections.py index 5a3244a..3e205c8 100644 --- a/paramdb/_param_data/_collections.py +++ 
b/paramdb/_param_data/_collections.py @@ -1,7 +1,16 @@ """Parameter data collection classes.""" from __future__ import annotations -from typing import Union, TypeVar, Generic, SupportsIndex, Any, cast, overload +from typing import ( + Union, + TypeVar, + Generic, + SupportsIndex, + Any, + cast, + overload, + get_type_hints, +) from collections.abc import ( Iterator, Collection, @@ -158,11 +167,12 @@ def __eq__(self, other: Any) -> bool: return isinstance(other, ParamDict) and self._contents == other._contents def __dir__(self) -> Iterable[str]: - # Return keys that are not attribute names (i.e. do not pass self._is_attribute) - # in __dir__() so they are suggested by interactive prompts like IPython. + # In addition to the default __dir__(), include dictionary keys so they are + # suggested for dot notation by interactive prompts like IPython. + super_dir = super().__dir__() return [ - *super().__dir__(), - *filter(lambda key: not self._is_attribute(key), self._contents.keys()), + *super_dir, + *filter(lambda key: key not in super_dir, self.keys()), ] def __getitem__(self, key: str) -> ItemT: @@ -185,13 +195,12 @@ def __iter__(self) -> Iterator[str]: def __getattr__(self, name: str) -> ItemT: # Enable accessing items via dot notation - if self._is_attribute(name): - # It is important to raise an attribute error rather than a key error for - # names considered to be attributes. For example, this allows deepcopy to - # work properly. - raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{name}'" - ) + if self._is_attribute(name) or name not in self: + # If the name corresponds to an attribute or is not in the dictionary, we + # should raise the default AttributeError rather than a KeyError, since this + # is the expected behavior of __getattr__(). For example, this allows + # getattr() and hasattr() to work properly. 
+ self.__getattribute__(name) # Raises the default AttributeError return self[name] def __setattr__(self, name: str, value: ItemT) -> None: @@ -203,15 +212,19 @@ def __setattr__(self, name: str, value: ItemT) -> None: def __delattr__(self, name: str) -> None: # Enable deleting items via dot notation - if self._is_attribute(name): + if self._is_attribute(name) or name not in self: super().__delattr__(name) else: del self[name] def _is_attribute(self, name: str) -> bool: """ - Names beginning with underscores are considered to be attributes when accessed - via dot notation. This is both to allow internal Python variables to be set - (i.e. dunder variables), and to allow for true attributes to be used if needed. + If the given name matches an existing attribute or has a corresponding class + type hint, treat it as the name of an attribute. """ - return len(name) > 0 and name[0] == "_" + try: + self.__getattribute__(name) # pylint: disable=unnecessary-dunder-call + existing_attribute = True + except AttributeError: + existing_attribute = False + return existing_attribute or name in get_type_hints(type(self)) diff --git a/paramdb/_param_data/_dataclasses.py b/paramdb/_param_data/_dataclasses.py index 6c61e0a..8054cc8 100644 --- a/paramdb/_param_data/_dataclasses.py +++ b/paramdb/_param_data/_dataclasses.py @@ -151,7 +151,10 @@ def __getattribute__(self, name: str) -> Any: def __setattr__(self, name: str, value: Any) -> None: # If this attribute is a field, process the old and new child if name in self._field_names: - old_wrapped_value = super().__getattribute__(name) + try: + old_wrapped_value = super().__getattribute__(name) + except AttributeError: + old_wrapped_value = None self.__base_setattr(name, value) # May perform type validation wrapped_value = self._wrap_child(value) super().__setattr__(name, wrapped_value) diff --git a/tests/_param_data/test_collections.py b/tests/_param_data/test_collections.py index 9f77b64..59f3490 100644 --- 
a/tests/_param_data/test_collections.py +++ b/tests/_param_data/test_collections.py @@ -378,36 +378,34 @@ def test_param_list_empty_last_updated() -> None: def test_param_dict_key_error(param_dict: ParamDict[Any]) -> None: - """Getting or deleting a nonexistent key raises a KeyError.""" + """ + Getting or deleting a nonexistent key raises a KeyError if accessed using bracket + notation. + """ with pytest.raises(KeyError): _ = param_dict["nonexistent"] with pytest.raises(KeyError): del param_dict["nonexistent"] - with pytest.raises(KeyError): - _ = param_dict.nonexistent - with pytest.raises(KeyError): - del param_dict.nonexistent def test_param_dict_attribute_error(param_dict: ParamDict[Any]) -> None: - """Getting or deleting a nonexistent attribute raises an AttributeError.""" + """ + Getting or deleting a nonexistent attribute raises an AttributeError if accessed + using dot notation. + """ with pytest.raises(AttributeError): - _ = param_dict._nonexistent # pylint: disable=protected-access + _ = param_dict.nonexistent # pylint: disable=protected-access with pytest.raises(AttributeError): - del param_dict._nonexistent # pylint: disable=protected-access + del param_dict.nonexistent # pylint: disable=protected-access def test_param_dict_dir(param_dict: ParamDict[Any]) -> None: - """ - Keys of a parameter dictionary that are not attribute names (names that pass - ParamDict._is_attribute) are returned by dir(). 
- """ + """Keys of a parameter dictionary are included in the list returned by dir().""" param_dict["_attribute_name"] = 123 - param_dict_dir_items = set(dir(param_dict)) - assert "_attribute_name" not in param_dict_dir_items + param_dict["__attribute_name__"] = 456 + param_dict_dir = dir(param_dict) for key in param_dict.keys(): - if not param_dict._is_attribute(key): # pylint: disable=protected-access - assert key in param_dict_dir_items + assert key in param_dict_dir def test_param_dict_get( From 54b3eac2a6f33d7f74c445e2fcfda06295382dc8 Mon Sep 17 00:00:00 2001 From: Alex Hadley Date: Wed, 12 Jun 2024 23:12:18 -0700 Subject: [PATCH 06/12] #187 Fix ParamDict class hint detection in Python 3.9 --- paramdb/_param_data/_collections.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/paramdb/_param_data/_collections.py b/paramdb/_param_data/_collections.py index 3e205c8..eb0175e 100644 --- a/paramdb/_param_data/_collections.py +++ b/paramdb/_param_data/_collections.py @@ -1,16 +1,7 @@ """Parameter data collection classes.""" from __future__ import annotations -from typing import ( - Union, - TypeVar, - Generic, - SupportsIndex, - Any, - cast, - overload, - get_type_hints, -) +from typing import Union, TypeVar, Generic, SupportsIndex, Any, cast, overload from collections.abc import ( Iterator, Collection, @@ -227,4 +218,7 @@ def _is_attribute(self, name: str) -> bool: existing_attribute = True except AttributeError: existing_attribute = False - return existing_attribute or name in get_type_hints(type(self)) + class_annotations: dict[str, Any] = {} + for cls in type(self).mro(): + class_annotations |= getattr(cls, "__annotations__", {}) + return existing_attribute or name in class_annotations From ae3ebaea6a0c13abdf465d009cbc5490c3562743 Mon Sep 17 00:00:00 2001 From: Alex Hadley Date: Thu, 13 Jun 2024 11:00:14 -0700 Subject: [PATCH 07/12] #187 Update docs --- docs/api-reference.md | 5 - docs/parameter-data.md | 132 
++++++++------------------ tests/_param_data/test_dataclasses.py | 4 +- tests/_param_data/test_param_data.py | 8 +- tests/conftest.py | 24 ++--- tests/helpers.py | 2 +- 6 files changed, 61 insertions(+), 114 deletions(-) diff --git a/docs/api-reference.md b/docs/api-reference.md index edbf045..556e041 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -10,11 +10,6 @@ All of the following can be imported from `paramdb`. ```{eval-rst} .. autoclass:: ParamData -.. autoclass:: ParamInt -.. autoclass:: ParamFloat -.. autoclass:: ParamBool -.. autoclass:: ParamStr -.. autoclass:: ParamNone .. autoclass:: ParamDataclass .. autoclass:: ParamFile .. autoclass:: ParamDataFrame diff --git a/docs/parameter-data.md b/docs/parameter-data.md index b07a10f..3c7a653 100644 --- a/docs/parameter-data.md +++ b/docs/parameter-data.md @@ -32,46 +32,6 @@ type (`str`, `int`, `float`, `bool`, `None`, `dict`, or `list`), a [`datetime`], a `TypeError` will be raised when they are committed to the database. ``` -## Primitives - -Primitives are the building blocks of parameter data. While builtin primitive types can -be used in a ParamDB (`int`, `float`, `str`, `bool`, and `None`), they will not store a -{py:class}`~ParamData.last_updated` time and will not have {py:class}`~ParamData.parent` -or {py:class}`~ParamData.root` properties. When these features are desired, we can wrap -primitive values in the following types: - -- {py:class}`ParamInt` for integers -- {py:class}`ParamFloat` for float -- {py:class}`ParamBool` for booleans -- {py:class}`ParamStr` for strings -- {py:class}`ParamNone` for `None` - -For example: - -```{jupyter-execute} -from paramdb import ParamInt - -param_int = ParamInt(123) -param_int -``` - -```{jupyter-execute} -print(param_int.last_updated) -``` - -````{tip} -Methods from the builtin primitive types work on parameter primitives, with the caveat -that they return the builtin type. 
For example: - -```{jupyter-execute} -param_int + 123 -``` - -```{jupyter-execute} -type(param_int + 123) -``` -```` - ## Data Classes A parameter data class is defined from the base class {py:class}`ParamDataclass`. This @@ -81,19 +41,20 @@ function is generated. An example of a defining a custom parameter Data Class is below. ```{jupyter-execute} -from paramdb import ParamFloat, ParamDataclass +from paramdb import ParamDataclass class CustomParam(ParamDataclass): - value: ParamFloat + value: float -custom_param = CustomParam(value=ParamFloat(1.23)) +custom_param = CustomParam(value=1.23) +print(custom_param) ``` These properties can then be accessed and updated. ```{jupyter-execute} -custom_param.value = ParamFloat(1.234) -custom_param.value +custom_param.value += 0.004 +print(custom_param) ``` The data class aspects of the subclass can be customized by passing keyword arguments when @@ -112,16 +73,18 @@ when building up dataclasses through inheritance. from dataclasses import field class KeywordOnlyParam(ParamDataclass, kw_only=True): - count: int + num_values: int = 0 values: list[int] = field(default_factory=list) + type: str -keyword_only_param = KeywordOnlyParam(count=123) -keyword_only_param +keyword_only_param = KeywordOnlyParam(type="example") +print(keyword_only_param) ``` ```{warning} -For mutable default values, `default_factory` should generally be used. See the Python -data class documentation on [mutable default values] for more information. +For mutable default values, `default_factory` should generally be used (see the example +above). See the Python data class documentation on [mutable default values] for more +information. ``` Custom methods can also be added, including dynamic properties using the [`@property`] @@ -129,27 +92,27 @@ decorator. 
For example: ```{jupyter-execute} class ParamWithProperty(ParamDataclass): - value: ParamInt + value: int @property def value_cubed(self) -> int: return self.value ** 3 -param_with_property = ParamWithProperty(value=ParamInt(16)) -param_with_property.value_cubed +param_with_property = ParamWithProperty(value=16) +print(param_with_property.value_cubed) ``` ````{important} Since [`__init__`] is generated for data classes, other initialization must be done using the [`__post_init__`] function. Furthermore, since [`__post_init__`] is used internally by {py:class}`ParamDataclass` to perform initialization, always call the superclass's -[`__post_init__`]. For example: +[`__post_init__`] first. For example: ```{jupyter-execute} class ParamCustomInit(ParamDataclass): def __post_init__(self) -> None: + super().__post_init__() # Always call the superclass __post_init__() first print("Initializing...") # Replace with custom initialization code - super().__post_init__() param_custom_init = ParamCustomInit() ``` @@ -166,27 +129,33 @@ print(custom_param.last_updated) import time time.sleep(1) -custom_param.value = ParamFloat(4.56) +custom_param.value = 4.56 print(custom_param.last_updated) ``` -Parameter dataclasses can also be nested, in which case the -{py:attr}`ParamData.last_updated` property returns the most recent last updated time stamp -among its own last updated time and the last updated times of any {py:class}`ParamData` -it contains. For example: +Last updated times for properties can also be accessed by calling +{py:meth}`ParamData.child_last_updated` on the parent object. This is particularly useful +for property values which are not {py:class}`ParamData`. For example: + +```{jupyter-execute} +print(custom_param.child_last_updated("value")) +``` + +When parameter dataclasses are nested, updating a child also updates the last updated +times of its parents.
For example: ```{jupyter-execute} class NestedParam(ParamDataclass): value: float child_param: CustomParam -nested_param = NestedParam(value=1.23, child_param=CustomParam(value=ParamFloat(4.56))) +nested_param = NestedParam(value=1.23, child_param=CustomParam(value=4.56)) print(nested_param.last_updated) ``` ```{jupyter-execute} time.sleep(1) -nested_param.child_param.value = ParamFloat(2) +nested_param.child_param.value += 1 print(nested_param.last_updated) ``` @@ -273,54 +242,33 @@ properly. For example: ```{jupyter-execute} from paramdb import ParamList -param_list = ParamList([ParamInt(1), ParamInt(2), ParamInt(3)]) -param_list[1].parent is param_list -``` - -```{jupyter-execute} -print(param_list.last_updated) -``` - -```{jupyter-execute} -time.sleep(1) -param_list[1] = ParamInt(4) -print(param_list.last_updated) +param_list = ParamList([1, 2, 3]) +print(param_list.child_last_updated(1)) ``` ### Parameter Dictionaries Similarly, {py:class}`ParamDict` implements `MutableMapping` from [`collections.abc`], -so it behaves similarly to a dictionary. Additionally, its items can be accessed via -dot notation in addition to index brackets (unless they begin with an underscore). For -example: +so it behaves similarly to a dictionary. Additionally, items can be accessed via dot +notation in addition to index brackets. For example: ```{jupyter-execute} from paramdb import ParamDict -param_dict = ParamDict(p1=ParamFloat(1.23), p2=ParamFloat(4.56), p3=ParamFloat(7.89)) -param_dict.p2.root == param_dict -``` - -```{jupyter-execute} -print(param_dict.last_updated) -``` - -```{jupyter-execute} -time.sleep(1) -param_dict.p2 = ParamFloat(0) -print(param_dict.last_updated) +param_dict = ParamDict(p1=1.23, p2=4.56, p3=7.89) +print(param_dict.child_last_updated("p2")) ``` Parameter collections can also be subclassed to provide custom functionality. 
For example: ```{jupyter-execute} -class CustomDict(ParamDict[ParamFloat]): +class CustomDict(ParamDict[float]): @property def total(self) -> float: - return sum(param.value for param in self.values()) + return sum(self.values()) custom_dict = CustomDict(param_dict) -custom_dict.total +print(custom_dict.total) ``` ## Type Mixins diff --git a/tests/_param_data/test_dataclasses.py b/tests/_param_data/test_dataclasses.py index 039697c..e447a08 100644 --- a/tests/_param_data/test_dataclasses.py +++ b/tests/_param_data/test_dataclasses.py @@ -147,7 +147,7 @@ def test_param_dataclass_init_wrong_type( assert "Input should be a valid number" in str(exc_info.value) -def test_param_dataclass_init_default_wrong_type() -> None: +def test_param_dataclass_init_default_wrong_type(number: float) -> None: """ Fails or succeeds to initialize a parameter object with a default value having the wrong type @@ -159,7 +159,7 @@ class DefaultWrongTypeParam(SimpleParam): default_number: float = "123" # type: ignore[assignment] with pytest.raises(pydantic.ValidationError) as exc_info: - DefaultWrongTypeParam() + DefaultWrongTypeParam(number=number) assert "Input should be a valid number" in str(exc_info.value) diff --git a/tests/_param_data/test_param_data.py b/tests/_param_data/test_param_data.py index 7432568..f2e0fe7 100644 --- a/tests/_param_data/test_param_data.py +++ b/tests/_param_data/test_param_data.py @@ -4,7 +4,7 @@ from dataclasses import is_dataclass from copy import deepcopy import pytest -from tests.helpers import ComplexParam, Times, capture_start_end_times +from tests.helpers import SimpleParam, ComplexParam, Times, capture_start_end_times from paramdb import ParamData, ParamDataFrame from paramdb._param_data._param_data import get_param_class @@ -48,12 +48,16 @@ def test_get_param_class(param_data: ParamData[Any]) -> None: assert get_param_class(param_class.__name__) is param_class -def test_param_data_initial_last_updated(param_data_type: type[ParamData[Any]]) -> None: 
+def test_param_data_initial_last_updated( + number: float, param_data_type: type[ParamData[Any]] +) -> None: """New parameter data objects are initialized with a last updated timestamp.""" with capture_start_end_times() as times: new_param_data: ParamData[Any] if issubclass(param_data_type, ParamDataFrame): new_param_data = param_data_type("") + elif issubclass(param_data_type, SimpleParam): + new_param_data = param_data_type(number=number) else: new_param_data = param_data_type() assert new_param_data.last_updated is not None diff --git a/tests/conftest.py b/tests/conftest.py index 397ef01..6c42b91 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -99,12 +99,12 @@ def fixture_complex_param(number: float, string: str) -> ComplexParam: string=string, param_data_frame=ParamDataFrame(string), empty_param=EmptyParam(), - simple_param=SimpleParam(), - no_type_validation_param=NoTypeValidationParam(), - with_type_validation_param=WithTypeValidationParam(), - no_assignment_validation_param=NoAssignmentValidationParam(), - with_assignment_validation_param=WithAssignmentValidationParam(), - subclass_param=SubclassParam(), + simple_param=SimpleParam(number=number), + no_type_validation_param=NoTypeValidationParam(number=number), + with_type_validation_param=WithTypeValidationParam(number=number), + no_assignment_validation_param=NoAssignmentValidationParam(number=number), + with_assignment_validation_param=WithAssignmentValidationParam(number=number), + subclass_param=SubclassParam(number=number), complex_param=ComplexParam(), param_list=ParamList(), param_dict=ParamDict(), @@ -119,12 +119,12 @@ def fixture_param_list_contents(number: float, string: str) -> list[Any]: string, ParamDataFrame(string), EmptyParam(), - SimpleParam(), - NoTypeValidationParam(), - WithTypeValidationParam(), - NoAssignmentValidationParam(), - WithAssignmentValidationParam(), - SubclassParam(), + SimpleParam(number=number), + NoTypeValidationParam(number=number), + 
WithTypeValidationParam(number=number), + NoAssignmentValidationParam(number=number), + WithAssignmentValidationParam(number=number), + SubclassParam(number=number), ComplexParam(), ParamList(), ParamDict(), diff --git a/tests/helpers.py b/tests/helpers.py index 780981e..64a05cb 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -45,7 +45,7 @@ class EmptyParam(ParamDataclass): class SimpleParam(ParamDataclass): """Simple parameter data class.""" - number: float = DEFAULT_NUMBER + number: float # No default to verify that non-default properties work number_init_false: float = field(init=False, default=DEFAULT_NUMBER) number_with_units: Quantity = Quantity(12, "m") string: str = DEFAULT_STRING From 7b4de7bfbe94d13a6880dfa52795a8e52c8dc9ee Mon Sep 17 00:00:00 2001 From: Alex Hadley Date: Thu, 13 Jun 2024 13:20:28 -0700 Subject: [PATCH 08/12] #187 Update tests --- tests/_param_data/test_param_data.py | 36 +++++++--- tests/conftest.py | 98 ++++++++++++++-------------- tests/helpers.py | 12 +++- tests/test_database.py | 29 +++++++- 4 files changed, 113 insertions(+), 62 deletions(-) diff --git a/tests/_param_data/test_param_data.py b/tests/_param_data/test_param_data.py index f2e0fe7..272af51 100644 --- a/tests/_param_data/test_param_data.py +++ b/tests/_param_data/test_param_data.py @@ -1,10 +1,16 @@ """Tests for the paramdb._param_data._param_data module.""" +from __future__ import annotations from typing import Any from dataclasses import is_dataclass from copy import deepcopy import pytest -from tests.helpers import SimpleParam, ComplexParam, Times, capture_start_end_times +from tests.helpers import ( + SimpleParam, + ComplexParam, + update_child, + capture_start_end_times, +) from paramdb import ParamData, ParamDataFrame from paramdb._param_data._param_data import get_param_class @@ -64,16 +70,26 @@ def test_param_data_initial_last_updated( assert times.start < new_param_data.last_updated.timestamp() < times.end -def test_param_data_updates_last_updated( - 
updated_param_data: ParamData[Any], updated_times: Times +def test_param_data_updating_child_updates_last_updated( + param_data: ParamData[Any], param_data_child_name: str | int | None ) -> None: - """Updating parameter data updates the last updated time.""" - assert updated_param_data.last_updated is not None - assert ( - updated_times.start - < updated_param_data.last_updated.timestamp() - < updated_times.end - ) + """The last updated time is updated when a child is updated.""" + if param_data_child_name is None: + return + with capture_start_end_times() as times: + update_child(param_data, param_data_child_name) + assert times.start < param_data.last_updated.timestamp() < times.end + + +def test_param_data__updates_last_updated( + param_data: ParamData[Any], param_data_child_name: str | int | None +) -> None: + """The last updated time is updated when a child is updated.""" + if param_data_child_name is None: + return + with capture_start_end_times() as times: + update_child(param_data, param_data_child_name) + assert times.start < param_data.last_updated.timestamp() < times.end def test_child_does_not_change(param_data: ParamData[Any]) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index 6c42b91..5f6166c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,8 +16,6 @@ WithAssignmentValidationParam, SubclassParam, ComplexParam, - Times, - capture_start_end_times, ) @@ -229,51 +227,51 @@ def fixture_param_data_child_name(param_data: ParamData[Any]) -> str | int | Non return None -@pytest.fixture(name="updated_param_data_and_times") -def fixture_updated_param_data_and_times( - param_data: ParamData[Any], number: float -) -> tuple[ParamData[Any], Times]: - """ - Parameter data that has been updated between the returned Times. Broken down into - individual fixtures for parameter data and times below. 
- """ - updated_param_data = deepcopy(param_data) - with capture_start_end_times() as times: - if isinstance(updated_param_data, EmptyParam): - updated_param_data = type(updated_param_data)() - elif isinstance(updated_param_data, ParamDataFrame): - updated_param_data.path = "" - elif isinstance(updated_param_data, SimpleParam): - updated_param_data.number += 1 - elif isinstance(updated_param_data, SubclassParam): - updated_param_data.second_number += 1 - elif isinstance(updated_param_data, ComplexParam): - assert updated_param_data.simple_param is not None - updated_param_data.simple_param.number += 1 - elif isinstance(updated_param_data, ParamList): - if len(updated_param_data) == 0: - updated_param_data.append(number) - else: - updated_param_data[4].number += 1 - elif isinstance(updated_param_data, ParamDict): - if len(updated_param_data) == 0: - updated_param_data["number"] = number - else: - updated_param_data.simple_param.number += 1 - return updated_param_data, times - - -@pytest.fixture(name="updated_param_data") -def fixture_updated_param_data( - updated_param_data_and_times: tuple[ParamData[Any], Times] -) -> ParamData[Any]: - """Parameter data that has been updated.""" - return updated_param_data_and_times[0] - - -@pytest.fixture(name="updated_times") -def fixture_updated_times( - updated_param_data_and_times: tuple[ParamData[Any], Times] -) -> Times: - """Times before and after param_data fixture was updated.""" - return updated_param_data_and_times[1] +# @pytest.fixture(name="updated_param_data_and_times") +# def fixture_updated_param_data_and_times( +# param_data: ParamData[Any], number: float +# ) -> tuple[ParamData[Any], Times]: +# """ +# Parameter data that has been updated between the returned Times. Broken down into +# individual fixtures for parameter data and times below. 
+# """ +# updated_param_data = deepcopy(param_data) +# with capture_start_end_times() as times: +# if isinstance(updated_param_data, EmptyParam): +# updated_param_data = type(updated_param_data)() +# elif isinstance(updated_param_data, ParamDataFrame): +# updated_param_data.path = "" +# elif isinstance(updated_param_data, SimpleParam): +# updated_param_data.number += 1 +# elif isinstance(updated_param_data, SubclassParam): +# updated_param_data.second_number += 1 +# elif isinstance(updated_param_data, ComplexParam): +# assert updated_param_data.simple_param is not None +# updated_param_data.simple_param.number += 1 +# elif isinstance(updated_param_data, ParamList): +# if len(updated_param_data) == 0: +# updated_param_data.append(number) +# else: +# updated_param_data[4].number += 1 +# elif isinstance(updated_param_data, ParamDict): +# if len(updated_param_data) == 0: +# updated_param_data["number"] = number +# else: +# updated_param_data.simple_param.number += 1 +# return updated_param_data, times + + +# @pytest.fixture(name="updated_param_data") +# def fixture_updated_param_data( +# updated_param_data_and_times: tuple[ParamData[Any], Times] +# ) -> ParamData[Any]: +# """Parameter data that has been updated.""" +# return updated_param_data_and_times[0] + + +# @pytest.fixture(name="updated_times") +# def fixture_updated_times( +# updated_param_data_and_times: tuple[ParamData[Any], Times] +# ) -> Times: +# """Times before and after param_data fixture was updated.""" +# return updated_param_data_and_times[1] diff --git a/tests/helpers.py b/tests/helpers.py index 64a05cb..6136027 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -108,7 +108,9 @@ class CustomParamDict(ParamDict[Any]): def assert_param_data_strong_equals( - param_data: ParamData[Any], other_param_data: ParamData[Any], child_name: str | int + param_data: ParamData[Any], + other_param_data: ParamData[Any], + child_name: str | int | None, ) -> None: """ Assert that the given parameter data is equal to 
the other parameter data based on @@ -132,6 +134,14 @@ def assert_param_data_strong_equals( assert child.parent == other_child.parent +def update_child(param_data: ParamData[Any], child_name: str | int) -> None: + """Update the specified child of the given parameter data.""" + # Update the child by assignment + child = param_data[child_name] # type: ignore[index] + # pylint: disable-next=unsupported-assignment-operation + param_data[child_name] = child # type: ignore[index] + + @dataclass class Times: """Start and end times captured by ``capture_start_end_times()``.""" diff --git a/tests/test_database.py b/tests/test_database.py index 6b23b81..33fcb94 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -4,6 +4,7 @@ # subscriptable. # pylint: disable=unsubscriptable-object +from __future__ import annotations from typing import Any from copy import deepcopy import os @@ -19,6 +20,7 @@ CustomParamDict, Times, assert_param_data_strong_equals, + update_child, capture_start_end_times, ) from paramdb import ( @@ -131,7 +133,7 @@ def test_load_nonexistent_commit_fails(db_path: str) -> None: def test_commit_and_load( - db_path: str, param_data: ParamData[Any], param_data_child_name: str + db_path: str, param_data: ParamData[Any], param_data_child_name: str | int | None ) -> None: """Can commit and load parameter data and commit entries.""" param_db = ParamDB[ParamData[Any]](db_path) @@ -193,6 +195,31 @@ def test_commit_and_load_timestamp(db_path: str, simple_param: SimpleParam) -> N assert commit_entry_with_data.timestamp == aware_timestamp +def test_update_timestamp_after_load( + db_path: str, param_data: ParamData[Any], param_data_child_name: str | int | None +) -> None: + """ + Updating the child of a parameter data object that has been loaded from the database + updates the timestamps of the object and the child. 
+ + The object and child timestamps are not updated when reconstructing the object from + the database, so this tests that they are subsequently updated as usual. + """ + if param_data_child_name is None: + return + param_db = ParamDB[ParamData[Any]](db_path) + param_db.commit("Initial commit", param_data) + param_data_loaded = param_db.load() + with capture_start_end_times() as times: + update_child(param_data_loaded, param_data_child_name) + assert times.start < param_data_loaded.last_updated.timestamp() < times.end + assert ( + times.start + < param_data_loaded.child_last_updated(param_data_child_name).timestamp() + < times.end + ) + + def test_load_classes_false(db_path: str, param_data: ParamData[Any]) -> None: """Can load data as dictionaries if ``load_classes`` is false.""" param_db = ParamDB[ParamData[Any]](db_path) From ff884bdb8bf26cef0ca890af9b0be3808b845ffa Mon Sep 17 00:00:00 2001 From: Alex Hadley Date: Thu, 13 Jun 2024 13:21:11 -0700 Subject: [PATCH 09/12] #187 Update tests --- tests/_param_data/test_param_data.py | 11 ------ tests/conftest.py | 50 ---------------------------- 2 files changed, 61 deletions(-) diff --git a/tests/_param_data/test_param_data.py b/tests/_param_data/test_param_data.py index 272af51..92d19d8 100644 --- a/tests/_param_data/test_param_data.py +++ b/tests/_param_data/test_param_data.py @@ -81,17 +81,6 @@ def test_param_data_updating_child_updates_last_updated( assert times.start < param_data.last_updated.timestamp() < times.end -def test_param_data__updates_last_updated( - param_data: ParamData[Any], param_data_child_name: str | int | None -) -> None: - """The last updated time is updated when a child is updated.""" - if param_data_child_name is None: - return - with capture_start_end_times() as times: - update_child(param_data, param_data_child_name) - assert times.start < param_data.last_updated.timestamp() < times.end - - def test_child_does_not_change(param_data: ParamData[Any]) -> None: """ Including a parameter data 
object as a child within a parent structure does not diff --git a/tests/conftest.py b/tests/conftest.py index 5f6166c..cbd14a2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -225,53 +225,3 @@ def fixture_param_data_child_name(param_data: ParamData[Any]) -> str | int | Non if isinstance(param_data, ParamDict): return None if len(param_data) == 0 else "simple_param" return None - - -# @pytest.fixture(name="updated_param_data_and_times") -# def fixture_updated_param_data_and_times( -# param_data: ParamData[Any], number: float -# ) -> tuple[ParamData[Any], Times]: -# """ -# Parameter data that has been updated between the returned Times. Broken down into -# individual fixtures for parameter data and times below. -# """ -# updated_param_data = deepcopy(param_data) -# with capture_start_end_times() as times: -# if isinstance(updated_param_data, EmptyParam): -# updated_param_data = type(updated_param_data)() -# elif isinstance(updated_param_data, ParamDataFrame): -# updated_param_data.path = "" -# elif isinstance(updated_param_data, SimpleParam): -# updated_param_data.number += 1 -# elif isinstance(updated_param_data, SubclassParam): -# updated_param_data.second_number += 1 -# elif isinstance(updated_param_data, ComplexParam): -# assert updated_param_data.simple_param is not None -# updated_param_data.simple_param.number += 1 -# elif isinstance(updated_param_data, ParamList): -# if len(updated_param_data) == 0: -# updated_param_data.append(number) -# else: -# updated_param_data[4].number += 1 -# elif isinstance(updated_param_data, ParamDict): -# if len(updated_param_data) == 0: -# updated_param_data["number"] = number -# else: -# updated_param_data.simple_param.number += 1 -# return updated_param_data, times - - -# @pytest.fixture(name="updated_param_data") -# def fixture_updated_param_data( -# updated_param_data_and_times: tuple[ParamData[Any], Times] -# ) -> ParamData[Any]: -# """Parameter data that has been updated.""" -# return 
updated_param_data_and_times[0] - - -# @pytest.fixture(name="updated_times") -# def fixture_updated_times( -# updated_param_data_and_times: tuple[ParamData[Any], Times] -# ) -> Times: -# """Times before and after param_data fixture was updated.""" -# return updated_param_data_and_times[1] From 51bbedae04c36cd2ae530c457426013b84398656 Mon Sep 17 00:00:00 2001 From: Alex Hadley Date: Thu, 13 Jun 2024 13:46:02 -0700 Subject: [PATCH 10/12] #187 Update CHANGELOG and bump version from 0.12.0 to 0.13.0b1 --- CHANGELOG.md | 22 ++++++++++++++++++++-- CITATION.cff | 4 ++-- docs/conf.py | 2 +- pyproject.toml | 2 +- 4 files changed, 24 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b34622a..b27a21a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,10 +3,27 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this -project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +project adheres to clauses 1–8 of [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +## [0.13.0b1] (Jun 13 2024) + +### Added + +- The timestamps of non-`ParamData` children are now tracked internally and can be + accessed via the new method `ParamData.child_last_updated()`. + +### Changed + +- `ParamDict` dot notation now treates names of existing attributes and names of class + type annotations as attributes (rather than treating all names beginning with + underscores as attributes). + +### Removed + +- Parameter primitive classes have been replaced by the new timestamp tracking. + ## [0.12.0] (May 8 2024) ### Added @@ -169,7 +186,8 @@ project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
- Database class `ParamDB` to store parameters in a SQLite file - Ability to retrieve the commit history as `CommitEntry` objects -[unreleased]: https://github.com/PainterQubits/paramdb/compare/v0.12.0...develop +[unreleased]: https://github.com/PainterQubits/paramdb/compare/v0.13.0b1...develop +[0.13.0b1]: https://github.com/PainterQubits/paramdb/releases/tag/v0.13.0b1 [0.12.0]: https://github.com/PainterQubits/paramdb/releases/tag/v0.12.0 [0.11.0]: https://github.com/PainterQubits/paramdb/releases/tag/v0.11.0 [0.10.2]: https://github.com/PainterQubits/paramdb/releases/tag/v0.10.2 diff --git a/CITATION.cff b/CITATION.cff index 20932af..d8202dd 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -4,6 +4,6 @@ authors: - family-names: "Hadley" given-names: "Alex" title: "ParamDB" -version: 0.12.0 -date-released: 2024-05-08 +version: 0.13.0b1 +date-released: 2024-06-13 url: "https://github.com/PainterQubits/paramdb" diff --git a/docs/conf.py b/docs/conf.py index 7775dff..9d0be58 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -4,7 +4,7 @@ project = "ParamDB" copyright = "2023–2024, California Institute of Technology" author = "Alex Hadley" -release = "0.12.0" +release = "0.13.0b1" # General configuration extensions = [ diff --git a/pyproject.toml b/pyproject.toml index d99556a..0b09ddb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "paramdb" -version = "0.12.0" +version = "0.13.0b1" description = "Python package for storing and retrieving experiment parameters." 
authors = ["Alex Hadley "] license = "BSD-3-Clause" From 9ef9edb33f46998a88e9a9ac2ee859d98e388f0f Mon Sep 17 00:00:00 2001 From: Alex Hadley Date: Thu, 13 Jun 2024 13:48:57 -0700 Subject: [PATCH 11/12] #187 Update ParamDict docstring --- paramdb/_param_data/_collections.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paramdb/_param_data/_collections.py b/paramdb/_param_data/_collections.py index eb0175e..b8f31d3 100644 --- a/paramdb/_param_data/_collections.py +++ b/paramdb/_param_data/_collections.py @@ -139,8 +139,8 @@ class ParamDict( Mutable mapping that is also parameter data. It can be initialized from any mapping or using keyword arguments (like builtin ``dict``). - Keys that do not begin with an underscore can be set via dot notation. Keys, values, - and items are returned as dict_keys, dict_values, and dict_items objects. + Keys that do not refer to existing attributes or class type hints can be gotten, + set, and deleted via dot notation. """ def __init__(self, mapping: Mapping[str, ItemT] | None = None, /, **kwargs: ItemT): From ed61636d814b27e70313e3cce991be370cb9603c Mon Sep 17 00:00:00 2001 From: Alex Hadley Date: Fri, 14 Jun 2024 12:12:38 -0700 Subject: [PATCH 12/12] #187 Improve JSON format of ParamDB commits --- CHANGELOG.md | 5 + docs/api-reference.md | 12 +- paramdb/__init__.py | 4 +- paramdb/_database.py | 174 ++++++++++++++++----------- paramdb/_param_data/_collections.py | 2 +- paramdb/_param_data/_dataclasses.py | 2 +- paramdb/_param_data/_param_data.py | 91 +++++++------- tests/_param_data/test_param_data.py | 14 +-- tests/helpers.py | 2 +- tests/test_database.py | 78 ++++++------ 10 files changed, 200 insertions(+), 184 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b27a21a..39b4660 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,12 +13,17 @@ project adheres to clauses 1–8 of [Semantic Versioning](https://semver.org/spe - The timestamps of non-`ParamData` children are now tracked internally and can 
be accessed via the new method `ParamData.child_last_updated()`. +- The class `ParamDBKey` contains the keys used in the JSON representation of a commit. ### Changed - `ParamDict` dot notation now treates names of existing attributes and names of class type annotations as attributes (rather than treating all names beginning with underscores as attributes). +- The JSON format of a commit has been changed, as specified in the docstring for + `ParamDB.load()`. +- `ParamData.to_dict()` and `ParamData.from_dict()` have been replaced by + `ParamData.to_json()` and `ParamData.from_json()`. ### Removed diff --git a/docs/api-reference.md b/docs/api-reference.md index 556e041..0c81b85 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -29,15 +29,5 @@ All of the following can be imported from `paramdb`. .. autoclass:: ParamDB .. autoclass:: CommitEntry .. autoclass:: CommitEntryWithData -``` - - - -```{eval-rst} -.. py:currentmodule:: paramdb._database -.. autodata:: CLASS_NAME_KEY -.. py:currentmodule:: paramdb +.. 
autoclass:: ParamDBKey ``` diff --git a/paramdb/__init__.py b/paramdb/__init__.py index 8465303..a99303e 100644 --- a/paramdb/__init__.py +++ b/paramdb/__init__.py @@ -5,7 +5,7 @@ from paramdb._param_data._files import ParamFile from paramdb._param_data._collections import ParamList, ParamDict from paramdb._param_data._type_mixins import ParentType, RootType -from paramdb._database import CLASS_NAME_KEY, ParamDB, CommitEntry, CommitEntryWithData +from paramdb._database import ParamDB, CommitEntry, CommitEntryWithData, ParamDBKey __all__ = [ "ParamData", @@ -15,10 +15,10 @@ "ParamDict", "ParentType", "RootType", - "CLASS_NAME_KEY", "ParamDB", "CommitEntry", "CommitEntryWithData", + "ParamDBKey", ] try: diff --git a/paramdb/_database.py b/paramdb/_database.py index 0408870..756113b 100644 --- a/paramdb/_database.py +++ b/paramdb/_database.py @@ -14,7 +14,7 @@ Mapped, mapped_column, ) -from paramdb._param_data._param_data import ParamData, get_param_class +from paramdb._param_data._param_data import ParamData, _ParamWrapper, get_param_class try: from astropy.units import Quantity # type: ignore[import-untyped] @@ -26,11 +26,28 @@ DataT = TypeVar("DataT") _SelectT = TypeVar("_SelectT", bound=Select[Any]) -CLASS_NAME_KEY = "__type" -""" -Dictionary key corresponding to an object's class name in the JSON representation of a -ParamDB commit. -""" + +class ParamDBKey: + """ + Keys corresponding to different object types in the JSON representation of the data + in a ParamDB commit. + """ + + DATETIME = "t" + """Key for ``datetime.datetime`` objects.""" + QUANTITY = "q" + """Key for ``astropy.units.Quantity`` objects.""" + LIST = "l" + """Key for ordinary lists.""" + DICT = "d" + """Key for ordinary dictionaries.""" + WRAPPER = "w" + """ + Key for non-:py:class:`ParamData` children of :py:class:`ParamData` objects, since + they are wrapped with additional metadata, such as a last updated time.
+ """ + PARAM = "p" + """Key for :py:class:`ParamData` objects.""" def _compress(text: str) -> bytes: @@ -43,79 +60,80 @@ def _decompress(compressed_text: bytes) -> str: return ZstdDecompressor().decompress(compressed_text).decode() -def _full_class_name(cls: type) -> str: - """ - Return the full name of the given class, including the module. Used to convert - non-parameter-data objects to and from JSON. - """ - return f"{cls.__module__}.{cls.__name__}" - - -def _from_dict(json_dict: dict[str, Any]) -> Any: - """ - If the given dictionary created by ``json.loads()`` has the key ``CLASS_NAME_KEY``, - attempt to construct an object of the named type from it. Otherwise, return the - dictionary unchanged. - - If load_classes is False, then parameter data objects will be loaded as - dictionaries. +# pylint: disable-next=too-many-return-statements +def _encode_json(obj: Any) -> Any: """ - if CLASS_NAME_KEY not in json_dict: - return json_dict - class_name = json_dict.pop(CLASS_NAME_KEY) - if class_name == _full_class_name(datetime): - return datetime.fromisoformat(json_dict["isoformat"]).astimezone() - if _ASTROPY_INSTALLED and class_name == _full_class_name(Quantity): - return Quantity(**json_dict) - param_class = get_param_class(class_name) - if param_class is not None: - return param_class.from_dict(json_dict) - raise ValueError( - f"class '{class_name}' is not known to ParamDB, so the load failed" - ) - + Encode the given object and its children into a JSON-serializable format. -def _preprocess_json(obj: Any) -> Any: - """ - Preprocess the given object and its children into a JSON-serializable format. - Compared with ``json.dumps()``, this function can define custom logic for dealing - with subclasses of ``int``, ``float``, and ``str``. + See ``ParamDB.load()`` for the format specification. 
""" - if isinstance(obj, ParamData): - return {CLASS_NAME_KEY: type(obj).__name__} | _preprocess_json(obj.to_dict()) if isinstance(obj, (int, float, bool, str)) or obj is None: return obj - if isinstance(obj, (list, tuple)): - return [_preprocess_json(value) for value in obj] - if isinstance(obj, dict): - return {key: _preprocess_json(value) for key, value in obj.items()} - class_full_name = _full_class_name(type(obj)) - class_full_name_dict = {CLASS_NAME_KEY: class_full_name} if isinstance(obj, datetime): - return class_full_name_dict | {"isoformat": obj.isoformat()} + return [ParamDBKey.DATETIME, obj.timestamp()] if _ASTROPY_INSTALLED and isinstance(obj, Quantity): - return class_full_name_dict | {"value": obj.value, "unit": str(obj.unit)} + return [ParamDBKey.QUANTITY, obj.value, str(obj.unit)] + if isinstance(obj, (list, tuple)): + return [ParamDBKey.LIST, [_encode_json(item) for item in obj]] + if isinstance(obj, dict): + return [ + ParamDBKey.DICT, + {key: _encode_json(value) for key, value in obj.items()}, + ] + if isinstance(obj, ParamData): + timestamp_and_json = [ + obj.last_updated.timestamp(), + _encode_json(obj.to_json()), + ] + if isinstance(obj, _ParamWrapper): + return [ParamDBKey.WRAPPER, *timestamp_and_json] + return [ParamDBKey.PARAM, type(obj).__name__, *timestamp_and_json] raise TypeError( - f"'{class_full_name}' object {repr(obj)} is not JSON serializable, so the" + f"'{type(obj).__name__}' object {repr(obj)} is not JSON serializable, so the" " commit failed" ) +# pylint: disable-next=too-many-return-statements +def _decode_json(json_data: Any) -> Any: + """Reconstruct an object encoded by ``_encode_json()``.""" + if isinstance(json_data, list): + key, *data = json_data + if key == ParamDBKey.DATETIME: + return datetime.fromtimestamp(data[0], timezone.utc).astimezone() + if _ASTROPY_INSTALLED and key == ParamDBKey.QUANTITY: + return Quantity(*data) + if key == ParamDBKey.LIST: + return [_decode_json(item) for item in data[0]] + if key == 
ParamDBKey.DICT: + return {key: _decode_json(value) for key, value in data[0].items()} + if key == ParamDBKey.WRAPPER: + return _ParamWrapper.from_json(data[0], _decode_json(data[1])) + if key == ParamDBKey.PARAM: + class_name = data[0] + param_class = get_param_class(class_name) + if param_class is not None: + return param_class.from_json(data[1], _decode_json(data[2])) + raise ValueError( + f"ParamData class '{class_name}' is not known to ParamDB, so the load" + " failed" + ) + return json_data + + def _encode(obj: Any) -> bytes: """Encode the given object into bytes that will be stored in the database.""" # pylint: disable=no-member - return _compress(json.dumps(_preprocess_json(obj))) + return _compress(json.dumps(_encode_json(obj))) -def _decode(data: bytes, load_classes: bool) -> Any: +def _decode(data: bytes, decode_json: bool) -> Any: """ Decode an object from the given data from the database. Classes will be loaded in if ``load_classes`` is True; otherwise, classes will be loaded as dictionaries. """ - return json.loads( - _decompress(data), - object_hook=_from_dict if load_classes else None, - ) + json_data = json.loads(_decompress(data)) + return _decode_json(json_data) if decode_json else json_data class _Base(MappedAsDataclass, DeclarativeBase): @@ -267,32 +285,44 @@ def num_commits(self) -> int: @overload def load( - self, commit_id: int | None = None, *, load_classes: Literal[True] = True + self, commit_id: int | None = None, *, decode_json: Literal[True] = True ) -> DataT: ... @overload def load( - self, commit_id: int | None = None, *, load_classes: Literal[False] + self, commit_id: int | None = None, *, decode_json: Literal[False] ) -> Any: ... - def load(self, commit_id: int | None = None, *, load_classes: bool = True) -> Any: + def load(self, commit_id: int | None = None, *, decode_json: bool = True) -> Any: """ Load and return data from the database. If a commit ID is given, load from that commit; otherwise, load from the most recent commit. 
Raise an ``IndexError`` if the specified commit does not exist. Note that commit IDs begin at 1. - By default, parameter data, ``datetime``, and Astropy ``Quantity`` classes are - reconstructed. The relevant parameter data classes must be defined in the - current program. However, if ``load_classes`` is False, classes are loaded - directly from the database as dictionaries with the class name in the key - :py:const:`~paramdb._database.CLASS_NAME_KEY`. + By default, objects are reconstructed, which requires the relevant parameter + data classes to be defined in the current program. However, if ``decode_json`` + is False, the encoded JSON data is loaded directly from the database. The format + of the encoded data is as follows (see :py:class:`ParamDBKey` for key codes):: + + json_data: + | int + | float + | bool + | str + | None + | [ParamDBKey.DATETIME, float] + | [ParamDBKey.QUANTITY, float, str] + | [ParamDBKey.LIST, [json_data, ...]] + | [ParamDBKey.DICT, {str: json_data, ...}] + | [ParamDBKey.WRAPPER, float, json_data] + | [ParamDBKey.PARAM, str, float, json_data] """ select_stmt = self._select_commit(select(_Snapshot.data), commit_id) with self._Session() as session: data = session.scalar(select_stmt) if data is None: raise self._index_error(commit_id) - return _decode(data, load_classes) + return _decode(data, decode_json) def load_commit_entry(self, commit_id: int | None = None) -> CommitEntry: """ @@ -330,7 +360,7 @@ def commit_history_with_data( start: int | None = None, end: int | None = None, *, - load_classes: Literal[True] = True, + decode_json: Literal[True] = True, ) -> list[CommitEntryWithData[DataT]]: ... @overload @@ -339,7 +369,7 @@ def commit_history_with_data( start: int | None = None, end: int | None = None, *, - load_classes: Literal[False], + decode_json: Literal[False], ) -> list[CommitEntryWithData[Any]]: ... 
def commit_history_with_data( @@ -347,14 +377,14 @@ def commit_history_with_data( start: int | None = None, end: int | None = None, *, - load_classes: bool = True, + decode_json: bool = True, ) -> list[CommitEntryWithData[Any]]: """ Retrieve the commit history with data as a list of :py:class:`CommitEntryWithData` objects between the provided start and end indices, which work like slicing a Python list. - See :py:meth:`ParamDB.load` for the behavior of ``load_classes``. + See :py:meth:`ParamDB.load` for the behavior of ``decode_json``. """ with self._Session() as session: select_stmt = self._select_slice(select(_Snapshot), start, end) @@ -364,7 +394,7 @@ def commit_history_with_data( snapshot.id, snapshot.message, snapshot.timestamp, - _decode(snapshot.data, load_classes), + _decode(snapshot.data, decode_json), ) for snapshot in snapshots ] diff --git a/paramdb/_param_data/_collections.py b/paramdb/_param_data/_collections.py index b8f31d3..0f1d3fd 100644 --- a/paramdb/_param_data/_collections.py +++ b/paramdb/_param_data/_collections.py @@ -35,7 +35,7 @@ def __repr__(self) -> str: # _ParamWrapper objects return f"{type(self).__name__}({type(self._contents)(self)})" - def _to_json(self) -> _CollectionT: + def to_json(self) -> _CollectionT: return self._contents def _get_wrapped_child(self, child_name: _ChildNameT) -> ParamData[Any]: diff --git a/paramdb/_param_data/_dataclasses.py b/paramdb/_param_data/_dataclasses.py index 8054cc8..9c158fd 100644 --- a/paramdb/_param_data/_dataclasses.py +++ b/paramdb/_param_data/_dataclasses.py @@ -174,7 +174,7 @@ def _get_wrapped_child(self, child_name: str) -> ParamData[Any]: return cast(ParamData[Any], super().__getattribute__(child_name)) return super()._get_wrapped_child(child_name) - def _to_json(self) -> dict[str, Any]: + def to_json(self) -> dict[str, Any]: return { field.name: super(ParamData, self).__getattribute__(field.name) for field in fields(self) # type: ignore[arg-type] diff --git 
a/paramdb/_param_data/_param_data.py b/paramdb/_param_data/_param_data.py index 6c82a18..7f0137b 100644 --- a/paramdb/_param_data/_param_data.py +++ b/paramdb/_param_data/_param_data.py @@ -101,52 +101,6 @@ def _update_last_updated(self) -> None: super(ParamData, current).__setattr__("_last_updated", new_last_updated) current = current._parent - @abstractmethod - def _to_json(self) -> Any: - """ - Convert the data stored in this object into a JSON serializable format, which - will later be passed to ``self._data_from_json()`` to reconstruct this object. - - The last updated timestamp is handled separately and does not need to be saved - here. - - Note that objects within a list or dictionary returned by this function do not - need to be JSON serializable, since they will be processed recursively. - """ - - def _init_from_json(self, json_data: Any) -> None: - """ - Initialize a new parameter data object from the given JSON data, usually created - by ``json.loads()`` and originally constructed by ``self._data_to_json()``. By - default, this method will pass the JSON data to ``self.__init__()``. - - The object will be generated by ``self.__new__()``, but ``self.__init__()`` has - not been called and ``self._last_updated_frozen`` is set to False. The last - updated timestamp is handled separately and does not need to be set here. - """ - # pylint: disable-next=unnecessary-dunder-call - self.__init__(json_data) # type: ignore[misc] - - def to_dict(self) -> dict[str, Any]: - """ - Return a dictionary representation of this parameter data object, which can be - used to reconstruct the object by passing it to :py:meth:`from_dict`. - """ - return {_LAST_UPDATED_KEY: self._last_updated, _DATA_KEY: self._to_json()} - - @classmethod - def from_dict(cls, data_dict: dict[str, Any]) -> Self: - """ - Construct a parameter data object from the given dictionary, usually created by - ``json.loads()`` and originally constructed by :py:meth:`from_dict`. 
- """ - param_data = cls.__new__(cls) - super().__setattr__(param_data, "_last_updated_frozen", True) - param_data._init_from_json(data_dict[_DATA_KEY]) - super().__setattr__(param_data, "_last_updated", data_dict[_LAST_UPDATED_KEY]) - super().__setattr__(param_data, "_last_updated_frozen", False) - return param_data - @property def last_updated(self) -> datetime: """When any parameter within this parameter data was last updated.""" @@ -196,6 +150,48 @@ def root(self) -> ParamData[Any]: root = root._parent return root + @abstractmethod + def to_json(self) -> Any: + """ + Convert the data stored in this object into a JSON serializable format, which + will later be passed to :py:meth:`from_json()` to reconstruct this object. + + The last updated timestamp is handled separately and does not need to be saved + here. + + Note that objects within a list or dictionary returned by this function do not + need to be JSON serializable, since they will be processed recursively. + """ + + def _init_from_json(self, json_data: Any) -> None: + """ + Initialize a new parameter data object from the given JSON data, usually created + by ``json.loads()`` and originally constructed by ``to_json()``. By + default, this method will pass the JSON data to ``__init__()``. + + The object will be generated by ``__new__()``, but ``__init__()`` has not been + called and ``self._last_updated_frozen`` is set to True. The last updated + timestamp is handled separately and does not need to be set here. + """ + # pylint: disable-next=unnecessary-dunder-call + self.__init__(json_data) # type: ignore[misc] + + @classmethod + def from_json(cls, last_updated_timestamp: float, json_data: list[Any]) -> Self: + """ + Construct a parameter data object from the given last updated timestamp and JSON + data originally constructed by :py:meth:`to_json`.
+ """ + last_updated = datetime.fromtimestamp( + last_updated_timestamp, timezone.utc + ).astimezone() + param_data = cls.__new__(cls) + super().__setattr__(param_data, "_last_updated_frozen", True) + param_data._init_from_json(json_data) + super().__setattr__(param_data, "_last_updated", last_updated) + super().__setattr__(param_data, "_last_updated_frozen", False) + return param_data + class _ParamWrapper(ParamData[Never], Generic[_T]): """ @@ -206,7 +202,8 @@ def __init__(self, value: _T) -> None: super().__init__() self.value = value - def _to_json(self) -> _T: + # pylint: disable-next=missing-function-docstring + def to_json(self) -> _T: return self.value def __eq__(self, other: Any) -> bool: diff --git a/tests/_param_data/test_param_data.py b/tests/_param_data/test_param_data.py index 92d19d8..555df76 100644 --- a/tests/_param_data/test_param_data.py +++ b/tests/_param_data/test_param_data.py @@ -93,14 +93,14 @@ def test_child_does_not_change(param_data: ParamData[Any]) -> None: assert param_data == param_data_original -def test_to_and_from_dict(param_data: ParamData[Any]) -> None: - """Parameter data can be converted to and from a dictionary.""" - param_data_dict = param_data.to_dict() - assert isinstance(param_data_dict, dict) +def test_to_and_from_json(param_data: ParamData[Any]) -> None: + """Parameter data can be converted to and from JSON data.""" + timestamp = param_data.last_updated.timestamp() + json_data = param_data.to_json() with capture_start_end_times(): - param_data_from_dict = param_data.from_dict(param_data_dict) - assert param_data_from_dict == param_data - assert param_data_from_dict.last_updated == param_data.last_updated + param_data_from_json = param_data.from_json(timestamp, json_data) + assert param_data_from_json == param_data + assert param_data_from_json.last_updated == param_data.last_updated def test_no_parent_fails(param_data: ParamData[Any]) -> None: diff --git a/tests/helpers.py b/tests/helpers.py index 6136027..d052ca8 100644 
--- a/tests/helpers.py +++ b/tests/helpers.py @@ -119,7 +119,7 @@ def assert_param_data_strong_equals( # pylint: disable=protected-access assert param_data == other_param_data assert param_data.last_updated == other_param_data.last_updated - assert param_data.to_dict() == other_param_data.to_dict() + assert param_data.to_json() == other_param_data.to_json() if child_name is not None: assert param_data.child_last_updated( child_name diff --git a/tests/test_database.py b/tests/test_database.py index 33fcb94..40e2b72 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -32,7 +32,7 @@ ParamDB, CommitEntry, CommitEntryWithData, - CLASS_NAME_KEY, + ParamDBKey, ) from paramdb._param_data._param_data import _param_classes @@ -78,8 +78,8 @@ class NotJSONSerializable: param_db.commit("Initial commit", data) assert ( str(exc_info.value) - == f"'{NotJSONSerializable.__module__}.{NotJSONSerializable.__name__}' object" - f" {repr(data)} is not JSON serializable, so the commit failed" + == f"'{NotJSONSerializable.__name__}' object {repr(data)} is not JSON" + " serializable, so the commit failed" ) @@ -95,7 +95,8 @@ def test_load_unknown_class_fails(db_path: str) -> None: param_db.load() assert ( str(exc_info.value) - == f"class '{Unknown.__name__}' is not known to ParamDB, so the load failed" + == f"ParamData class '{Unknown.__name__}' is not known to ParamDB, so the load" + " failed" ) @@ -220,43 +221,36 @@ def test_update_timestamp_after_load( ) -def test_load_classes_false(db_path: str, param_data: ParamData[Any]) -> None: - """Can load data as dictionaries if ``load_classes`` is false.""" +def test_decode_json_false(db_path: str, param_data: ParamData[Any]) -> None: + """Can load raw JSON data if ``decode_json`` is false.""" param_db = ParamDB[ParamData[Any]](db_path) param_db.commit("Initial commit", param_data) - data_loaded = param_db.load(load_classes=False) - data_from_history = param_db.commit_history_with_data(load_classes=False)[0].data + data_loaded = 
param_db.load(decode_json=False) + data_from_history = param_db.commit_history_with_data(decode_json=False)[0].data for data in data_loaded, data_from_history: # Check that loaded dictionary has the correct type and keys - assert isinstance(data, dict) - assert data.pop(CLASS_NAME_KEY) == type(param_data).__name__ - param_data_dict = param_data.to_dict() - assert data.keys() == param_data_dict.keys() - - # Check that loaded dictionary has the correct values - for key, value in data.items(): - value_from_param_data = param_data_dict[key] - if isinstance(value_from_param_data, ParamData): - assert isinstance(value, dict) - assert value.pop(CLASS_NAME_KEY) == type(value_from_param_data).__name__ - assert value.keys() == value_from_param_data.to_dict().keys() - else: - if isinstance(value, list): - assert isinstance(value_from_param_data, list) - assert len(value) == len(value_from_param_data) - elif isinstance(value, dict): - if CLASS_NAME_KEY in value: - value_class = type(value_from_param_data) - full_class_name = ( - f"{value_class.__module__}.{value_class.__name__}" - ) - assert value[CLASS_NAME_KEY] == full_class_name - else: - assert isinstance(value_from_param_data, dict) - assert value.keys() == value_from_param_data.keys() - else: - assert value == value_from_param_data + assert isinstance(data, list) + assert len(data) == 4 + key, class_name, timestamp, json_data = data + assert key == ParamDBKey.PARAM + assert class_name == type(param_data).__name__ + assert timestamp == param_data.last_updated.timestamp() + + # Check that the loaded JSON data has the correct data + assert isinstance(json_data, list) + json_data_key, json_data_data = json_data + param_json_data = param_data.to_json() + if json_data_key == "l": + assert isinstance(json_data_data, list) + assert isinstance(param_json_data, list) + assert len(json_data_data) == len(param_json_data) + elif json_data_key == "d": + assert isinstance(json_data_data, dict) + assert isinstance(param_json_data, dict) 
+ assert json_data_data.keys() == param_json_data.keys() + else: + assert False # Currently, all param_data have list or dict data def test_load_classes_false_unknown_class(db_path: str) -> None: @@ -265,12 +259,12 @@ def test_load_classes_false_unknown_class(db_path: str) -> None: """ param_db = ParamDB[Unknown](db_path) param_db.commit("Initial commit", Unknown()) - data_loaded = param_db.load(load_classes=False) - data_from_history = param_db.commit_history_with_data(load_classes=False)[0].data - assert isinstance(data_loaded, dict) - assert data_loaded.pop(CLASS_NAME_KEY) == Unknown.__name__ - assert isinstance(data_from_history, dict) - assert data_from_history.pop(CLASS_NAME_KEY) == Unknown.__name__ + data_loaded = param_db.load(decode_json=False) + data_from_history = param_db.commit_history_with_data(decode_json=False)[0].data + assert isinstance(data_loaded, list) + assert data_loaded[1] == Unknown.__name__ + assert isinstance(data_from_history, list) + assert data_from_history[1] == Unknown.__name__ # pylint: disable-next=too-many-arguments,too-many-locals