Skip to content

Commit

Permalink
refactor: type magic to improve TOMLDataclass
Browse files Browse the repository at this point in the history
tests: TOML tests in test_serializable.py, code cleanup
  • Loading branch information
Jeremy Silver committed Apr 14, 2024
1 parent 437aa5d commit b48298f
Show file tree
Hide file tree
Showing 13 changed files with 894 additions and 661 deletions.
9 changes: 7 additions & 2 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
## v0.3.0

- TOMLDataclass
- Value conversions (borrow from JSON)
- Unit tests (also for `ConfigDataclass.load_config`)
- Unit tests
- Clean up JSON/dict tests (organize into classes?)
- Conflict with multiple inheritance from JSON/TOMLDataclass with `_to_file`?
- Can't do list of optionals with None included
- `ConfigDataclass.load_config` from TOML
- Basic usage examples in docs
- Host on GH Pages or Readthedocs
- Link to actual page from README
- Make PyPI page link to the hosted docs as well as Github
- Github Actions for automated testing
- Configure as much as possible via `hatch`
Expand Down Expand Up @@ -46,6 +50,7 @@
- Test `None` when it's not the default value (can break round-trip fidelity)
- `TabularDataclass`? CSV/TSV/parquet/feather
- Make `SQLDataclass` inherit from it
- Convert to/from `pandas` `Series` and `DataFrame`?
- Support subparsers in `ArgparseDataclass`
- Field metadata
- Be strict about unknown field metadata keys? (Maybe issue warning?)
Expand Down
7 changes: 6 additions & 1 deletion docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ Types of changes:
### Added

- `TOMLDataclass` for saving/loading TOML via [`tomlkit`](https://tomlkit.readthedocs.io/en/latest/)
- `FileSerializable` and `DictFileSerializableDataclass` mixins to factor shared behavior between JSON/TOML serialization
- Support for loading TOML configurations in `ConfigDataclass`
- `FileSerializable` and `DictFileSerializableDataclass` mixins to factor out shared functionality between JSON/TOML serialization

## [0.2.0]

2024-04-13

### Added

- `ConfigDataclass` mixin for global configurations
Expand All @@ -49,6 +52,8 @@ Types of changes:

## [0.1.0]

2022-06-06

### Added

- `DataclassMixin` class providing extra dataclass features
Expand Down
3 changes: 1 addition & 2 deletions docs/gen_ref_pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
PKG_DIR = Path(PKG_NAME)
REF_DIR = Path('reference')

# TODO: add config, toml, etc.
REF_MODULES = ['cli', 'dict', 'json', 'mixin', 'sql', 'subprocess', 'utils']
REF_MODULES = ['cli', 'config', 'dict', 'json', 'mixin', 'serialize', 'sql', 'subprocess', 'toml', 'utils']

for mod_name in REF_MODULES:
path = PKG_DIR / f'{mod_name}.py'
Expand Down
4 changes: 2 additions & 2 deletions fancy_dataclass/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing_extensions import Self

from fancy_dataclass.mixin import DataclassMixin, FieldSettings
from fancy_dataclass.utils import check_dataclass, issubclass_safe
from fancy_dataclass.utils import check_dataclass, issubclass_safe, type_is_optional


T = TypeVar('T')
Expand Down Expand Up @@ -137,7 +137,7 @@ def configure_argument(cls, parser: ArgumentParser, name: str) -> None:
action = field.metadata.get('action', 'store')
origin_type = get_origin(tp)
if origin_type is not None: # compound type
if (origin_type == Union) and (getattr(tp, '_name', None) == 'Optional'):
if type_is_optional(tp):
kwargs['default'] = None
if origin_type == ClassVar: # by default, exclude ClassVars from the parser
return
Expand Down
32 changes: 16 additions & 16 deletions fancy_dataclass/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from contextlib import contextmanager
from copy import copy
from dataclasses import make_dataclass
from pathlib import Path
from typing import ClassVar, Iterator, Optional, Type
Expand All @@ -9,7 +8,7 @@
from fancy_dataclass.dict import DictDataclass
from fancy_dataclass.mixin import DataclassMixin
from fancy_dataclass.serialize import FileSerializable
from fancy_dataclass.utils import AnyPath, coerce_to_dataclass, get_dataclass_fields
from fancy_dataclass.utils import AnyPath, coerce_to_dataclass, dataclass_type_map, get_dataclass_fields


class Config:
Expand All @@ -21,7 +20,10 @@ class Config:

@classmethod
def get_config(cls) -> Optional[Self]:
"""Gets the current global configuration."""
"""Gets the current global configuration.
Returns:
Global configuration object (`None` if not set)"""
return cls._config # type: ignore[return-value]

@classmethod
Expand Down Expand Up @@ -61,18 +63,13 @@ class ConfigDataclass(Config, DictDataclass, suppress_defaults=False):
@staticmethod
def _wrap_config_dataclass(mixin_cls: Type[DataclassMixin], cls: Type['ConfigDataclass']) -> Type[DataclassMixin]:
"""Recursively wraps a DataclassMixin class around a ConfigDataclass so that nested ConfigDataclass fields inherit from the same mixin."""
wrapped_cls = mixin_cls.wrap_dataclass(cls)
field_data = []
for fld in get_dataclass_fields(cls, include_classvars=True):
if issubclass(fld.type, ConfigDataclass):
tp = ConfigDataclass._wrap_config_dataclass(mixin_cls, fld.type)
new_fld = copy(fld)
new_fld.type = tp
else:
tp = fld.type
new_fld = fld
field_data.append((fld.name, tp, new_fld))
return make_dataclass(cls.__name__, field_data, bases=wrapped_cls.__bases__)
def _wrap(tp: type) -> type:
if issubclass(tp, ConfigDataclass):
wrapped_cls = mixin_cls.wrap_dataclass(tp)
field_data = [(fld.name, fld.type, fld) for fld in get_dataclass_fields(tp, include_classvars=True)]
return make_dataclass(tp.__name__, field_data, bases=wrapped_cls.__bases__)
return tp
return _wrap(dataclass_type_map(cls, _wrap)) # type: ignore[arg-type]

@classmethod
def _get_dataclass_type_for_extension(cls, ext: str) -> Type[FileSerializable]:
Expand All @@ -90,8 +87,11 @@ def _get_dataclass_type_for_extension(cls, ext: str) -> Type[FileSerializable]:
def load_config(cls, path: AnyPath) -> Self:
"""Loads configurations from a file and sets them to be the global configurations for this class.
Args:
path: File from which to load configurations
Returns:
The newly loaded global configuration"""
The newly loaded global configurations"""
p = Path(path)
ext = p.suffix
if not ext:
Expand Down
33 changes: 20 additions & 13 deletions fancy_dataclass/dict.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from abc import ABC, abstractmethod
from copy import copy
import dataclasses
from dataclasses import dataclass
from dataclasses import Field, dataclass
from functools import partial
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, Literal, Optional, Type, TypeVar, Union, _TypedDictMeta, get_args, get_origin # type: ignore[attr-defined]

from typing_extensions import Self, _AnnotatedAlias

from fancy_dataclass.mixin import DataclassMixin, DataclassMixinSettings, FieldSettings
from fancy_dataclass.utils import TypeConversionError, _flatten_dataclass, check_dataclass, fully_qualified_class_name, issubclass_safe, obj_class_name, safe_dict_insert
from fancy_dataclass.utils import TypeConversionError, _flatten_dataclass, check_dataclass, fully_qualified_class_name, issubclass_safe, obj_class_name, safe_dict_insert, type_is_optional


if TYPE_CHECKING:
Expand Down Expand Up @@ -250,7 +250,7 @@ def err() -> TypeConversionError:
return tuple(convert_val(subtype, elt) for elt in val)
return tuple(convert_val(subtype, elt) for (subtype, elt) in zip(args, val))
elif origin_type == Union:
if getattr(tp, '_name', None) == 'Optional':
if type_is_optional(tp):
assert len(args) == 2
assert args[1] is type(None)
args = args[::-1] # check None first
Expand All @@ -271,6 +271,10 @@ def err() -> TypeConversionError:
return type(val)(convert_val(subtype, elt) for elt in val)
raise err()

@classmethod
def _get_missing_value(cls, fld: Field) -> Any:  # type: ignore[type-arg]
    """Hook invoked when a required field (one with no default or default_factory)
    is absent from the input dict.

    Subclasses may override this to supply a fallback value instead of failing
    (e.g. the TOML dataclass elsewhere in this file returns None).

    Args:
        fld: The dataclass field missing from the input dict

    Raises:
        ValueError: always, in this base implementation"""
    raise ValueError(f'{fld.name!r} field is required')

@classmethod
def _dataclass_args_from_dict(cls, d: AnyDict, strict: bool = False) -> AnyDict:
"""Given a dict of arguments, performs type conversion and/or validity checking, then returns a new dict that can be passed to the class's constructor."""
Expand All @@ -283,24 +287,27 @@ def _dataclass_args_from_dict(cls, d: AnyDict, strict: bool = False) -> AnyDict:
for key in d:
if (key not in field_names):
raise ValueError(f'{key!r} is not a valid field for {cls.__name__}')
for field in fields:
if not field.init: # suppress fields where init=False
for fld in fields:
if not fld.init: # suppress fields where init=False
continue
if field.name in d:
if fld.name in d:
# field may be defined in the dataclass itself or one of its ancestor dataclasses
for base in bases:
try:
field_type = base.__annotations__[field.name]
kwargs[field.name] = cls._from_dict_value(field_type, d[field.name], strict=strict)
field_type = base.__annotations__[fld.name]
kwargs[fld.name] = cls._from_dict_value(field_type, d[fld.name], strict=strict)
break
except (AttributeError, KeyError):
pass
else:
raise ValueError(f'could not locate field {field.name!r}')
elif field.default == dataclasses.MISSING:
if field.default_factory == dataclasses.MISSING:
raise ValueError(f'{field.name!r} field is required')
kwargs[field.name] = field.default_factory()
raise ValueError(f'could not locate field {fld.name!r}')
elif fld.default == dataclasses.MISSING:
if fld.default_factory == dataclasses.MISSING:
val = cls._get_missing_value(fld)
else:
val = fld.default_factory()
# raise ValueError(f'{fld.name!r} field is required')
kwargs[fld.name] = val
return kwargs

@classmethod
Expand Down
29 changes: 4 additions & 25 deletions fancy_dataclass/json.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from datetime import datetime
from enum import Enum
import json
from json import JSONEncoder
from typing import Any, BinaryIO, TextIO, Type, cast, get_args, get_origin

from typing_extensions import Self

from fancy_dataclass.dict import AnyDict
from fancy_dataclass.serialize import DictFileSerializableDataclass, FileSerializable
from fancy_dataclass.serialize import DictFileSerializableDataclass, FileSerializable, from_dict_value_basic, to_dict_value_basic
from fancy_dataclass.utils import AnyIO, TypeConversionError


Expand Down Expand Up @@ -111,20 +110,9 @@ def _text_file_to_dict(cls, fp: TextIO, **kwargs: Any) -> AnyDict:

@classmethod
def _to_dict_value_basic(cls, val: Any) -> Any:
if isinstance(val, Enum):
return val.value
elif isinstance(val, range): # store the range bounds
bounds = [val.start, val.stop]
if val.step != 1:
bounds.append(val.step)
return bounds
elif isinstance(val, datetime):
if isinstance(val, datetime):
return val.isoformat()
elif isinstance(val, (int, float)): # handles numpy numeric types
return val
elif hasattr(val, 'dtype'): # assume it's a numpy array of numbers
return [float(elt) for elt in val]
return val
return to_dict_value_basic(val)

@classmethod
def _to_dict_value(cls, val: Any, full: bool) -> Any:
Expand All @@ -135,18 +123,9 @@ def _to_dict_value(cls, val: Any, full: bool) -> Any:

@classmethod
def _from_dict_value_basic(cls, tp: type, val: Any) -> Any:
if issubclass(tp, float):
return tp(val)
if issubclass(tp, range):
return tp(*val)
if issubclass(tp, datetime):
return tp.fromisoformat(val)
if issubclass(tp, Enum):
try:
return tp(val)
except ValueError as e:
raise TypeConversionError(tp, val) from e
return super()._from_dict_value_basic(tp, val)
return super()._from_dict_value_basic(tp, from_dict_value_basic(tp, val))

@classmethod
def _from_dict_value(cls, tp: type, val: Any, strict: bool = False) -> Any:
Expand Down
43 changes: 42 additions & 1 deletion fancy_dataclass/serialize.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,52 @@
from abc import ABC, abstractmethod
from enum import Enum
from io import StringIO, TextIOBase
from typing import Any, BinaryIO, TextIO

from typing_extensions import Self

from fancy_dataclass.dict import AnyDict, DictDataclass
from fancy_dataclass.utils import AnyIO
from fancy_dataclass.utils import AnyIO, TypeConversionError


def to_dict_value_basic(val: Any) -> Any:
    """Converts an arbitrary value with a basic data type to an appropriate form for serializing to typical file formats (JSON, TOML).

    Args:
        val: Value with basic data type

    Returns:
        A version of that value suitable for serialization"""
    if isinstance(val, Enum):
        # store the underlying value
        return val.value
    elif isinstance(val, range):  # store the range bounds
        bounds = [val.start, val.stop]
        if val.step != 1:
            bounds.append(val.step)
        return bounds
    elif isinstance(val, (int, float)):
        # guard numeric scalars *before* the dtype check: numpy scalar types
        # (e.g. np.float64, which subclasses float) carry a `dtype` attribute
        # and would otherwise crash in the non-iterable array branch below
        return val
    elif hasattr(val, 'dtype'):  # assume it's a numpy scalar or array of numbers
        if getattr(val, 'ndim', None) == 0:
            # 0-d scalar (e.g. np.int64): unwrap to a plain Python number
            return val.item()
        return [float(elt) for elt in val]
    return val

def from_dict_value_basic(tp: type, val: Any) -> Any:
    """Converts a deserialized value to the given type.

    Handles the basic types whose serialized form differs from the runtime type:
    floats are rebuilt via the target type's constructor, ranges from their
    stored [start, stop(, step)] bounds, and Enum members are looked up by value.
    Any other value is passed through unchanged.

    Args:
        tp: Target type to convert to
        val: Deserialized value

    Returns:
        Converted value

    Raises:
        TypeConversionError: if val is not a valid value for Enum type tp"""
    if issubclass(tp, float):
        converted = tp(val)
    elif issubclass(tp, range):
        converted = tp(*val)
    elif issubclass(tp, Enum):
        try:
            converted = tp(val)
        except ValueError as e:
            raise TypeConversionError(tp, val) from e
    else:
        converted = val
    return converted


class FileSerializable(ABC):
Expand Down
29 changes: 27 additions & 2 deletions fancy_dataclass/toml.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
from dataclasses import Field
from typing import Any, TextIO

import tomlkit
from typing_extensions import Self

from fancy_dataclass.dict import AnyDict
from fancy_dataclass.serialize import DictFileSerializableDataclass, FileSerializable
from fancy_dataclass.serialize import DictFileSerializableDataclass, FileSerializable, from_dict_value_basic, to_dict_value_basic
from fancy_dataclass.utils import AnyIO


def _remove_null_dict_values(val: Any) -> Any:
"""Removes all null (None) values from a dict.
Does this recursively to any nested dicts or lists."""
if isinstance(val, (list, tuple)):
return type(val)(_remove_null_dict_values(elt) for elt in val)
if isinstance(val, dict):
return type(val)({key: _remove_null_dict_values(elt) for (key, elt) in val.items() if (elt is not None)})
return val


class TOMLSerializable(FileSerializable):
"""Mixin class enabling conversion of an object to/from TOML."""

Expand Down Expand Up @@ -59,11 +71,24 @@ class TOMLDataclass(DictFileSerializableDataclass, TOMLSerializable):
"""Dataclass mixin enabling default serialization of dataclass objects to and from TOML."""

# TODO: require subclass to set qualified_type=True, like JSONDataclass?

@classmethod
def _dict_to_text_file(cls, d: AnyDict, fp: TextIO, **kwargs: Any) -> None:
    """Serializes the given dict to an open text file as TOML.

    None values are first stripped from all (nested) dicts, since TOML has no
    way to represent null.

    Args:
        d: Dict to serialize
        fp: Open text file handle to write to
        kwargs: Forwarded to tomlkit.dump"""
    tomlkit.dump(_remove_null_dict_values(d), fp, **kwargs)

@classmethod
def _text_file_to_dict(cls, fp: TextIO, **kwargs: Any) -> AnyDict:
    """Parses TOML from an open text file.

    Args:
        fp: Open text file handle to read from
        kwargs: Accepted for signature compatibility with the base class,
            but currently unused — NOTE(review): silently ignored; confirm
            callers never pass options here

    Returns:
        Parsed TOML content (a tomlkit document, which behaves like a dict)"""
    return tomlkit.load(fp)

@classmethod
def _to_dict_value_basic(cls, val: Any) -> Any:
    """Converts a basic-typed value to a serializable form.

    Delegates to the shared `to_dict_value_basic` helper used for JSON/TOML
    serialization (Enum values, range bounds, numpy arrays)."""
    return to_dict_value_basic(val)

@classmethod
def _from_dict_value_basic(cls, tp: type, val: Any) -> Any:
    """Converts a deserialized basic value to the given type.

    First applies the shared `from_dict_value_basic` conversion (float/range/Enum),
    then passes the result through the superclass's conversion."""
    return super()._from_dict_value_basic(tp, from_dict_value_basic(tp, val))

@classmethod
def _get_missing_value(cls, fld: Field) -> Any:  # type: ignore[type-arg]
    """Returns None as the fallback for any required field absent from the parsed TOML.

    Overrides the base behavior of raising ValueError — presumably because TOML
    cannot represent null, so an omitted key is the only way to express None;
    NOTE(review): confirm this is intended for *all* required fields, not just
    Optional ones.

    Args:
        fld: The missing dataclass field

    Returns:
        None, used as the field's value"""
    # replace any missing required fields with a default of None
    return None
Loading

0 comments on commit b48298f

Please sign in to comment.