Merge pull request #180 from PainterQubits/#177-add-paramdata-primtives
#177 Add ParamData Primitives
alexhad6 authored May 3, 2024
2 parents 8807f0a + 8cb21c5 commit 7007087
Showing 15 changed files with 732 additions and 75 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -11,6 +11,8 @@ project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

- If Pydantic is installed, parameter data classes automatically have Pydantic type
validation enabled.
- Parameter primitive classes: `ParamInt`, `ParamFloat`, `ParamBool`, `ParamStr`, and
`ParamNone`.

### Changed

9 changes: 9 additions & 0 deletions docs/api-reference.md
Expand Up @@ -10,11 +10,20 @@ All of the following can be imported from `paramdb`.

```{eval-rst}
.. autoclass:: ParamData
.. autoclass:: ParamInt
.. autoclass:: ParamFloat
.. autoclass:: ParamBool
.. autoclass:: ParamStr
.. autoclass:: ParamNone
.. autoclass:: ParamDataclass
.. autoclass:: ParamList
   :no-members:
.. autoclass:: ParamDict
   :no-members:
.. autoclass:: ParentType
   :no-members:
.. autoclass:: RootType
   :no-members:
```

## Database
6 changes: 5 additions & 1 deletion docs/conf.py
Expand Up @@ -25,7 +25,11 @@
# Autodoc options
# See https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#configuration
autodoc_default_options = {"members": True, "member-order": "bysource"}
-autodoc_inherit_docstrings = False
+autodoc_type_aliases = {
+    "ConvertibleToInt": "ConvertibleToInt",
+    "ConvertibleToFloat": "ConvertibleToFloat",
+}
+# autodoc_inherit_docstrings = False
add_module_names = False


90 changes: 63 additions & 27 deletions docs/parameter-data.md
Expand Up @@ -19,7 +19,7 @@ defines some core functionality for this data, including the
{py:class}`ParamData` are automatically registered with ParamDB so that they can be
loaded to and from JSON, which is how they are stored in the database.

-All of the classes described on this page are subclasses of {py:class}`ParamData`.
+All of the "Param" classes described on this page are subclasses of {py:class}`ParamData`.

```{important}
Any data that is going to be stored in a ParamDB database must be a JSON serializable
Expand All @@ -28,6 +28,46 @@ type (`str`, `int`, `float`, `bool`, `None`, `dict`, or `list`), a [`datetime`],
a `TypeError` will be raised when they are committed to the database.
```

## Primitives

Primitives are the building blocks of parameter data. While built-in primitive types
(`int`, `float`, `str`, `bool`, and `None`) can be stored in a ParamDB database, they do
not store a {py:attr}`~ParamData.last_updated` time and do not have
{py:attr}`~ParamData.parent` or {py:attr}`~ParamData.root` properties. When these
features are desired, we can wrap primitive values in the following types:

- {py:class}`ParamInt` for integers
- {py:class}`ParamFloat` for floats
- {py:class}`ParamBool` for booleans
- {py:class}`ParamStr` for strings
- {py:class}`ParamNone` for `None`

For example:

```{jupyter-execute}
from paramdb import ParamInt
param_int = ParamInt(123)
param_int
```

```{jupyter-execute}
print(param_int.last_updated)
```

````{tip}
Methods from the builtin primitive types work on parameter primitives, with the caveat
that they return the builtin type. For example:
```{jupyter-execute}
param_int + 123
```
```{jupyter-execute}
type(param_int + 123)
```
````
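
The caveat in the tip follows from how Python treats subclasses of built-in types. The following is a minimal sketch, assuming a hypothetical `TimestampedInt` class (not part of ParamDB, whose actual implementation may differ): arithmetic inherited from `int` returns a plain `int`, so any extra attributes on the wrapper are absent from the result.

```python
from datetime import datetime, timezone


class TimestampedInt(int):
    """Hypothetical int subclass that records when it was created."""

    def __new__(cls, value: int) -> "TimestampedInt":
        instance = super().__new__(cls, value)
        instance.created = datetime.now(timezone.utc)
        return instance


wrapped = TimestampedInt(123)
result = wrapped + 123  # int.__add__ is inherited and returns a plain int

print(type(wrapped).__name__)  # TimestampedInt
print(type(result).__name__)  # int
print(hasattr(result, "created"))  # False
```

To keep the extra metadata after an operation, the result would have to be wrapped again explicitly, which mirrors the `ParamFloat(...)` reassignments used elsewhere on this page.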

## Data Classes

A parameter data class is defined from the base class {py:class}`ParamDataclass`. This
Expand All @@ -37,18 +77,18 @@ function is generated. An example of a defining a custom parameter Data Class is
below.

```{jupyter-execute}
-from paramdb import ParamDataclass
+from paramdb import ParamFloat, ParamDataclass

 class CustomParam(ParamDataclass):
-    value: float
+    value: ParamFloat

-custom_param = CustomParam(value=1.23)
+custom_param = CustomParam(value=ParamFloat(1.23))
```

These properties can then be accessed and updated.

```{jupyter-execute}
-custom_param.value += 0.004
+custom_param.value = ParamFloat(1.234)
custom_param.value
```

Expand Down Expand Up @@ -85,13 +125,13 @@ decorator. For example:

```{jupyter-execute}
 class ParamWithProperty(ParamDataclass):
-    value: int
+    value: ParamInt

     @property
     def value_cubed(self) -> int:
         return self.value ** 3

-param_with_property = ParamWithProperty(value=16)
+param_with_property = ParamWithProperty(value=ParamInt(16))
 param_with_property.value_cubed
```

Expand All @@ -115,15 +155,15 @@ Parameter data track when any of their properties were last updated, and this va
accessed by the read-only {py:attr}`~ParamData.last_updated` property. For example:

```{jupyter-execute}
-custom_param.last_updated
+print(custom_param.last_updated)
```

```{jupyter-execute}
 import time

 time.sleep(1)
-custom_param.value += 1
-custom_param.last_updated
+custom_param.value = ParamFloat(4.56)
+print(custom_param.last_updated)
```

Parameter dataclasses can also be nested, in which case the
Expand All @@ -136,14 +176,14 @@ class NestedParam(ParamDataclass):
     value: float
     child_param: CustomParam

-nested_param = NestedParam(value=1.23, child_param=CustomParam(value=4.56))
-nested_param.last_updated
+nested_param = NestedParam(value=1.23, child_param=CustomParam(value=ParamFloat(4.56)))
+print(nested_param.last_updated)
```

```{jupyter-execute}
 time.sleep(1)
-nested_param.child_param.value += 1
-nested_param.last_updated
+nested_param.child_param.value = ParamFloat(2)
+print(nested_param.last_updated)
```
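
The nested example depends on a parent reflecting updates made to its children. One way this could work is sketched below with a hypothetical `Node` class. This is an assumption for illustration, not ParamDB's implementation: each node reports the most recent timestamp among its own and those of its descendants.

```python
from __future__ import annotations

from datetime import datetime, timezone


class Node:
    """Hypothetical container that tracks update times (not a ParamDB class)."""

    def __init__(self, **children: Node) -> None:
        self._children = dict(children)
        self._own_updated = datetime.now(timezone.utc)

    def touch(self) -> None:
        """Record that this node itself was just updated."""
        self._own_updated = datetime.now(timezone.utc)

    @property
    def last_updated(self) -> datetime:
        """Most recent update time among this node and all of its descendants."""
        child_times = [child.last_updated for child in self._children.values()]
        return max([self._own_updated, *child_times])


child = Node()
parent = Node(child=child)
before = parent.last_updated

child.touch()  # updating the child is visible through the parent
print(parent.last_updated >= before)
```

Computing the value on access keeps the sketch short; a real implementation might instead cache timestamps or push updates upward through parent references.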

You can access the parent of any parameter data using the {py:attr}`ParamData.parent`
Expand Down Expand Up @@ -207,18 +247,18 @@ properly. For example:
```{jupyter-execute}
 from paramdb import ParamList

-param_list = ParamList([CustomParam(value=1), CustomParam(value=2), CustomParam(value=3)])
+param_list = ParamList([ParamInt(1), ParamInt(2), ParamInt(3)])
 param_list[1].parent is param_list
```

```{jupyter-execute}
-param_list.last_updated
+print(param_list.last_updated)
```

```{jupyter-execute}
 time.sleep(1)
-param_list[1].value += 1
-param_list.last_updated
+param_list[1] = ParamInt(4)
+print(param_list.last_updated)
```

### Parameter Dictionaries
Expand All @@ -231,28 +271,24 @@ example:
```{jupyter-execute}
 from paramdb import ParamDict

-param_dict = ParamDict(
-    p1=CustomParam(value=1.23),
-    p2=CustomParam(value=4.56),
-    p3=CustomParam(value=7.89),
-)
+param_dict = ParamDict(p1=ParamFloat(1.23), p2=ParamFloat(4.56), p3=ParamFloat(7.89))
 param_dict.p2.root == param_dict
```

```{jupyter-execute}
-param_list.last_updated
+print(param_dict.last_updated)
```

```{jupyter-execute}
 time.sleep(1)
-param_list[1].value += 1
-param_list.last_updated
+param_dict.p2 = ParamFloat(0)
+print(param_dict.last_updated)
```

Parameter collections can also be subclassed to provide custom functionality. For example:

```{jupyter-execute}
-class CustomDict(ParamDict[CustomParam]):
+class CustomDict(ParamDict[ParamFloat]):
     @property
     def total(self) -> float:
         return sum(param.value for param in self.values())
12 changes: 12 additions & 0 deletions paramdb/__init__.py
@@ -1,13 +1,25 @@
"""Python package for storing and retrieving experiment parameters."""

from paramdb._param_data._param_data import ParamData
from paramdb._param_data._primitives import (
    ParamInt,
    ParamBool,
    ParamFloat,
    ParamStr,
    ParamNone,
)
from paramdb._param_data._dataclasses import ParamDataclass
from paramdb._param_data._collections import ParamList, ParamDict
from paramdb._param_data._type_mixins import ParentType, RootType
from paramdb._database import CLASS_NAME_KEY, ParamDB, CommitEntry, CommitEntryWithData

__all__ = [
    "ParamData",
    "ParamInt",
    "ParamBool",
    "ParamFloat",
    "ParamStr",
    "ParamNone",
    "ParamDataclass",
    "ParamList",
    "ParamDict",
58 changes: 32 additions & 26 deletions paramdb/_database.py
Expand Up @@ -24,7 +24,7 @@
_ASTROPY_INSTALLED = False

T = TypeVar("T")
-SelectT = TypeVar("SelectT", bound=Select[Any])
+_SelectT = TypeVar("_SelectT", bound=Select[Any])

CLASS_NAME_KEY = "__type"
"""
Expand All @@ -51,27 +51,6 @@ def _full_class_name(cls: type) -> str:
return f"{cls.__module__}.{cls.__name__}"


-def _to_dict(obj: Any) -> Any:
-    """
-    Convert the given object into a dictionary to be passed to ``json.dumps()``.
-    Note that objects within the dictionary do not need to be JSON serializable,
-    since they will be recursively processed by ``json.dumps()``.
-    """
-    class_full_name = _full_class_name(type(obj))
-    class_full_name_dict = {CLASS_NAME_KEY: class_full_name}
-    if isinstance(obj, datetime):
-        return class_full_name_dict | {"isoformat": obj.isoformat()}
-    if _ASTROPY_INSTALLED and isinstance(obj, Quantity):
-        return class_full_name_dict | {"value": obj.value, "unit": str(obj.unit)}
-    if isinstance(obj, ParamData):
-        return {CLASS_NAME_KEY: type(obj).__name__} | obj.to_dict()
-    raise TypeError(
-        f"'{class_full_name}' object {repr(obj)} is not JSON serializable, so the"
-        " commit failed"
-    )


def _from_dict(json_dict: dict[str, Any]) -> Any:
"""
If the given dictionary created by ``json.loads()`` has the key ``CLASS_NAME_KEY``,
Expand All @@ -96,9 +75,36 @@ def _from_dict(json_dict: dict[str, Any]) -> Any:
)


+def _preprocess_json(obj: Any) -> Any:
+    """
+    Preprocess the given object and its children into a JSON-serializable format.
+    Compared with ``json.dumps()``, this function can define custom logic for dealing
+    with subclasses of ``int``, ``float``, and ``str``.
+    """
+    if isinstance(obj, ParamData):
+        return {CLASS_NAME_KEY: type(obj).__name__} | _preprocess_json(obj.to_dict())
+    if isinstance(obj, (int, float, bool, str)) or obj is None:
+        return obj
+    if isinstance(obj, (list, tuple)):
+        return [_preprocess_json(value) for value in obj]
+    if isinstance(obj, dict):
+        return {key: _preprocess_json(value) for key, value in obj.items()}
+    class_full_name = _full_class_name(type(obj))
+    class_full_name_dict = {CLASS_NAME_KEY: class_full_name}
+    if isinstance(obj, datetime):
+        return class_full_name_dict | {"isoformat": obj.isoformat()}
+    if _ASTROPY_INSTALLED and isinstance(obj, Quantity):
+        return class_full_name_dict | {"value": obj.value, "unit": str(obj.unit)}
+    raise TypeError(
+        f"'{class_full_name}' object {repr(obj)} is not JSON serializable, so the"
+        " commit failed"
+    )


 def _encode(obj: Any) -> bytes:
     """Encode the given object into bytes that will be stored in the database."""
-    return _compress(json.dumps(obj, default=_to_dict))
+    # pylint: disable=no-member
+    return _compress(json.dumps(_preprocess_json(obj)))
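
The move from a `default=` hook to a preprocessing pass can be motivated by a standard-library behavior: `json.dumps()` only calls `default` for objects it cannot serialize natively, and subclasses of `int`, `float`, and `str` are encoded as their base types without consulting it. A minimal sketch, using a hypothetical `TaggedInt` class and `preprocess` helper (neither is from this commit; only the `"__type"` key mirrors `CLASS_NAME_KEY`):

```python
import json
from typing import Any


class TaggedInt(int):
    """Hypothetical int subclass standing in for a parameter primitive."""


calls: list[Any] = []


def default_hook(obj: Any) -> Any:
    """Record every object that json.dumps() passes to ``default``."""
    calls.append(obj)
    raise TypeError(f"{type(obj).__name__} is not JSON serializable")


# The int subclass is encoded as a plain number; the hook is never consulted.
plain = json.dumps({"x": TaggedInt(5)}, default=default_hook)


def preprocess(obj: Any) -> Any:
    """Walk the object first so that subclasses can be tagged before encoding."""
    if isinstance(obj, TaggedInt):
        return {"__type": "TaggedInt", "value": int(obj)}
    if isinstance(obj, dict):
        return {key: preprocess(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [preprocess(value) for value in obj]
    return obj


tagged = json.dumps(preprocess({"x": TaggedInt(5)}))

print(plain)
print(tagged)
print(calls)
```

Checking for the wrapper type before the built-in types matters here, just as `_preprocess_json` tests for `ParamData` before `int`, `float`, `bool`, and `str`.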


def _decode(data: bytes, load_classes: bool) -> Any:
Expand Down Expand Up @@ -194,7 +200,7 @@ def _index_error(self, commit_id: int | None) -> IndexError:
else f"commit {commit_id} does not exist in database" f" '{self._path}'"
)

-    def _select_commit(self, select_stmt: SelectT, commit_id: int | None) -> SelectT:
+    def _select_commit(self, select_stmt: _SelectT, commit_id: int | None) -> _SelectT:
"""
Modify the given ``_Snapshot`` select statement to return the commit specified
by the given commit ID, or the latest commit if the commit ID is None.
Expand All @@ -206,8 +212,8 @@ def _select_commit(self, select_stmt: SelectT, commit_id: int | None) -> SelectT
)

     def _select_slice(
-        self, select_stmt: SelectT, start: int | None, end: int | None
-    ) -> SelectT:
+        self, select_stmt: _SelectT, start: int | None, end: int | None
+    ) -> _SelectT:
"""
Modify the given Snapshot select statement to sort by commit ID and return the
slice specified by the given start and end indices.
10 changes: 5 additions & 5 deletions paramdb/_param_data/_collections.py
Expand Up @@ -18,14 +18,14 @@
from paramdb._param_data._param_data import ParamData

T = TypeVar("T")
-CollectionT = TypeVar("CollectionT", bound=Collection[Any])
+_CollectionT = TypeVar("_CollectionT", bound=Collection[Any])


 # pylint: disable-next=abstract-method
-class _ParamCollection(ParamData, Generic[CollectionT]):
+class _ParamCollection(ParamData, Generic[_CollectionT]):
     """Base class for parameter collections."""

-    _contents: CollectionT
+    _contents: _CollectionT

     def __len__(self) -> int:
         return len(self._contents)
Expand All @@ -41,12 +41,12 @@ def __eq__(self, other: Any) -> bool:
     def __repr__(self) -> str:
         return f"{type(self).__name__}({self._contents})"

-    def _to_json(self) -> CollectionT:
+    def _to_json(self) -> _CollectionT:
         return self._contents

     @classmethod
     @abstractmethod
-    def _from_json(cls, json_data: CollectionT) -> Self: ...
+    def _from_json(cls, json_data: _CollectionT) -> Self: ...


class ParamList(_ParamCollection[list[T]], MutableSequence[T], Generic[T]):