From b48298fded1da49e21ae7f0239493a5a768fe4ac Mon Sep 17 00:00:00 2001 From: Jeremy Silver Date: Sun, 14 Apr 2024 12:47:20 -0400 Subject: [PATCH] refactor: type magic to improve TOMLDataclass tests: TOML tests in test_serializable.py, code cleanup --- TODO.md | 9 +- docs/CHANGELOG.md | 7 +- docs/gen_ref_pages.py | 3 +- fancy_dataclass/cli.py | 4 +- fancy_dataclass/config.py | 32 +- fancy_dataclass/dict.py | 33 +- fancy_dataclass/json.py | 29 +- fancy_dataclass/serialize.py | 43 ++- fancy_dataclass/toml.py | 29 +- fancy_dataclass/utils.py | 66 +++- pyproject.toml | 11 +- tests/test_json.py | 588 ----------------------------- tests/test_serializable.py | 701 +++++++++++++++++++++++++++++++++++ 13 files changed, 894 insertions(+), 661 deletions(-) delete mode 100644 tests/test_json.py create mode 100644 tests/test_serializable.py diff --git a/TODO.md b/TODO.md index 7e6cd23..9ecf507 100644 --- a/TODO.md +++ b/TODO.md @@ -3,10 +3,14 @@ ## v0.3.0 - TOMLDataclass - - Value conversions (borrow from JSON) - - Unit tests (also for `ConfigDataclass.load_config`) + - Unit tests + - Clean up JSON/dict tests (organize into classes?) + - Conflict with multiple inheritance from JSON/TOMLDataclass with `_to_file`? + - Can't do list of optionals with None included + - `ConfigDataclass.load_config` from TOML - Basic usage examples in docs - Host on GH Pages or Readthedocs + - Link to actual page from README - Make PyPI page link to the hosted docs as well as Github - Github Actions for automated testing - Configure as much as possible via `hatch` @@ -46,6 +50,7 @@ - Test `None` when it's not the default value (can break round-trip fidelity) - `TabularDataclass`? CSV/TSV/parquet/feather - Make `SQLDataclass` inherit from it + - Convert to/from `pandas` `Series` and `DataFrame`? - Support subparsers in `ArgparseDataclass` - Field metadata - Be strict about unknown field metadata keys? (Maybe issue warning?) 
diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 8b45cad..1bc49de 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -19,10 +19,13 @@ Types of changes: ### Added - `TOMLDataclass` for saving/loading TOML via [`tomlkit`](https://tomlkit.readthedocs.io/en/latest/) -- `FileSerializable` and `DictFileSerializableDataclass` mixins to factor shared behavior between JSON/TOML serialization + - Support for loading TOML configurations in `ConfigDataclass` +- `FileSerializable` and `DictFileSerializableDataclass` mixins to factor out shared functionality between JSON/TOML serialization ## [0.2.0] +2024-04-13 + ### Added - `ConfigDataclass` mixin for global configurations @@ -49,6 +52,8 @@ Types of changes: ## [0.1.0] +2022-06-06 + ### Added - `DataclassMixin` class providing extra dataclass features diff --git a/docs/gen_ref_pages.py b/docs/gen_ref_pages.py index 60a086d..7613e54 100644 --- a/docs/gen_ref_pages.py +++ b/docs/gen_ref_pages.py @@ -14,8 +14,7 @@ PKG_DIR = Path(PKG_NAME) REF_DIR = Path('reference') -# TODO: add config, toml, etc. 
-REF_MODULES = ['cli', 'dict', 'json', 'mixin', 'sql', 'subprocess', 'utils'] +REF_MODULES = ['cli', 'config', 'dict', 'json', 'mixin', 'serialize', 'sql', 'subprocess', 'toml', 'utils'] for mod_name in REF_MODULES: path = PKG_DIR / f'{mod_name}.py' diff --git a/fancy_dataclass/cli.py b/fancy_dataclass/cli.py index 962cff3..e7a53eb 100644 --- a/fancy_dataclass/cli.py +++ b/fancy_dataclass/cli.py @@ -8,7 +8,7 @@ from typing_extensions import Self from fancy_dataclass.mixin import DataclassMixin, FieldSettings -from fancy_dataclass.utils import check_dataclass, issubclass_safe +from fancy_dataclass.utils import check_dataclass, issubclass_safe, type_is_optional T = TypeVar('T') @@ -137,7 +137,7 @@ def configure_argument(cls, parser: ArgumentParser, name: str) -> None: action = field.metadata.get('action', 'store') origin_type = get_origin(tp) if origin_type is not None: # compound type - if (origin_type == Union) and (getattr(tp, '_name', None) == 'Optional'): + if type_is_optional(tp): kwargs['default'] = None if origin_type == ClassVar: # by default, exclude ClassVars from the parser return diff --git a/fancy_dataclass/config.py b/fancy_dataclass/config.py index a9ee476..3e4c3f9 100644 --- a/fancy_dataclass/config.py +++ b/fancy_dataclass/config.py @@ -1,5 +1,4 @@ from contextlib import contextmanager -from copy import copy from dataclasses import make_dataclass from pathlib import Path from typing import ClassVar, Iterator, Optional, Type @@ -9,7 +8,7 @@ from fancy_dataclass.dict import DictDataclass from fancy_dataclass.mixin import DataclassMixin from fancy_dataclass.serialize import FileSerializable -from fancy_dataclass.utils import AnyPath, coerce_to_dataclass, get_dataclass_fields +from fancy_dataclass.utils import AnyPath, coerce_to_dataclass, dataclass_type_map, get_dataclass_fields class Config: @@ -21,7 +20,10 @@ class Config: @classmethod def get_config(cls) -> Optional[Self]: - """Gets the current global configuration.""" + """Gets the current global 
configuration. + + Returns: + Global configuration object (`None` if not set)""" return cls._config # type: ignore[return-value] @classmethod @@ -61,18 +63,13 @@ class ConfigDataclass(Config, DictDataclass, suppress_defaults=False): @staticmethod def _wrap_config_dataclass(mixin_cls: Type[DataclassMixin], cls: Type['ConfigDataclass']) -> Type[DataclassMixin]: """Recursively wraps a DataclassMixin class around a ConfigDataclass so that nested ConfigDataclass fields inherit from the same mixin.""" - wrapped_cls = mixin_cls.wrap_dataclass(cls) - field_data = [] - for fld in get_dataclass_fields(cls, include_classvars=True): - if issubclass(fld.type, ConfigDataclass): - tp = ConfigDataclass._wrap_config_dataclass(mixin_cls, fld.type) - new_fld = copy(fld) - new_fld.type = tp - else: - tp = fld.type - new_fld = fld - field_data.append((fld.name, tp, new_fld)) - return make_dataclass(cls.__name__, field_data, bases=wrapped_cls.__bases__) + def _wrap(tp: type) -> type: + if issubclass(tp, ConfigDataclass): + wrapped_cls = mixin_cls.wrap_dataclass(tp) + field_data = [(fld.name, fld.type, fld) for fld in get_dataclass_fields(tp, include_classvars=True)] + return make_dataclass(tp.__name__, field_data, bases=wrapped_cls.__bases__) + return tp + return _wrap(dataclass_type_map(cls, _wrap)) # type: ignore[arg-type] @classmethod def _get_dataclass_type_for_extension(cls, ext: str) -> Type[FileSerializable]: @@ -90,8 +87,11 @@ def _get_dataclass_type_for_extension(cls, ext: str) -> Type[FileSerializable]: def load_config(cls, path: AnyPath) -> Self: """Loads configurations from a file and sets them to be the global configurations for this class. 
+ Args: + path: File from which to load configurations + Returns: - The newly loaded global configuration""" + The newly loaded global configurations""" p = Path(path) ext = p.suffix if not ext: diff --git a/fancy_dataclass/dict.py b/fancy_dataclass/dict.py index c3ce546..f00f5d6 100644 --- a/fancy_dataclass/dict.py +++ b/fancy_dataclass/dict.py @@ -1,14 +1,14 @@ from abc import ABC, abstractmethod from copy import copy import dataclasses -from dataclasses import dataclass +from dataclasses import Field, dataclass from functools import partial from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, Literal, Optional, Type, TypeVar, Union, _TypedDictMeta, get_args, get_origin # type: ignore[attr-defined] from typing_extensions import Self, _AnnotatedAlias from fancy_dataclass.mixin import DataclassMixin, DataclassMixinSettings, FieldSettings -from fancy_dataclass.utils import TypeConversionError, _flatten_dataclass, check_dataclass, fully_qualified_class_name, issubclass_safe, obj_class_name, safe_dict_insert +from fancy_dataclass.utils import TypeConversionError, _flatten_dataclass, check_dataclass, fully_qualified_class_name, issubclass_safe, obj_class_name, safe_dict_insert, type_is_optional if TYPE_CHECKING: @@ -250,7 +250,7 @@ def err() -> TypeConversionError: return tuple(convert_val(subtype, elt) for elt in val) return tuple(convert_val(subtype, elt) for (subtype, elt) in zip(args, val)) elif origin_type == Union: - if getattr(tp, '_name', None) == 'Optional': + if type_is_optional(tp): assert len(args) == 2 assert args[1] is type(None) args = args[::-1] # check None first @@ -271,6 +271,10 @@ def err() -> TypeConversionError: return type(val)(convert_val(subtype, elt) for elt in val) raise err() + @classmethod + def _get_missing_value(cls, fld: Field) -> Any: # type: ignore[type-arg] + raise ValueError(f'{fld.name!r} field is required') + @classmethod def _dataclass_args_from_dict(cls, d: AnyDict, strict: bool = False) -> AnyDict: """Given a dict of 
arguments, performs type conversion and/or validity checking, then returns a new dict that can be passed to the class's constructor.""" @@ -283,24 +287,27 @@ def _dataclass_args_from_dict(cls, d: AnyDict, strict: bool = False) -> AnyDict: for key in d: if (key not in field_names): raise ValueError(f'{key!r} is not a valid field for {cls.__name__}') - for field in fields: - if not field.init: # suppress fields where init=False + for fld in fields: + if not fld.init: # suppress fields where init=False continue - if field.name in d: + if fld.name in d: # field may be defined in the dataclass itself or one of its ancestor dataclasses for base in bases: try: - field_type = base.__annotations__[field.name] - kwargs[field.name] = cls._from_dict_value(field_type, d[field.name], strict=strict) + field_type = base.__annotations__[fld.name] + kwargs[fld.name] = cls._from_dict_value(field_type, d[fld.name], strict=strict) break except (AttributeError, KeyError): pass else: - raise ValueError(f'could not locate field {field.name!r}') - elif field.default == dataclasses.MISSING: - if field.default_factory == dataclasses.MISSING: - raise ValueError(f'{field.name!r} field is required') - kwargs[field.name] = field.default_factory() + raise ValueError(f'could not locate field {fld.name!r}') + elif fld.default == dataclasses.MISSING: + if fld.default_factory == dataclasses.MISSING: + val = cls._get_missing_value(fld) + else: + val = fld.default_factory() + # raise ValueError(f'{fld.name!r} field is required') + kwargs[fld.name] = val return kwargs @classmethod diff --git a/fancy_dataclass/json.py b/fancy_dataclass/json.py index 41bf963..2112773 100644 --- a/fancy_dataclass/json.py +++ b/fancy_dataclass/json.py @@ -1,5 +1,4 @@ from datetime import datetime -from enum import Enum import json from json import JSONEncoder from typing import Any, BinaryIO, TextIO, Type, cast, get_args, get_origin @@ -7,7 +6,7 @@ from typing_extensions import Self from fancy_dataclass.dict import AnyDict 
-from fancy_dataclass.serialize import DictFileSerializableDataclass, FileSerializable +from fancy_dataclass.serialize import DictFileSerializableDataclass, FileSerializable, from_dict_value_basic, to_dict_value_basic from fancy_dataclass.utils import AnyIO, TypeConversionError @@ -111,20 +110,9 @@ def _text_file_to_dict(cls, fp: TextIO, **kwargs: Any) -> AnyDict: @classmethod def _to_dict_value_basic(cls, val: Any) -> Any: - if isinstance(val, Enum): - return val.value - elif isinstance(val, range): # store the range bounds - bounds = [val.start, val.stop] - if val.step != 1: - bounds.append(val.step) - return bounds - elif isinstance(val, datetime): + if isinstance(val, datetime): return val.isoformat() - elif isinstance(val, (int, float)): # handles numpy numeric types - return val - elif hasattr(val, 'dtype'): # assume it's a numpy array of numbers - return [float(elt) for elt in val] - return val + return to_dict_value_basic(val) @classmethod def _to_dict_value(cls, val: Any, full: bool) -> Any: @@ -135,18 +123,9 @@ def _to_dict_value(cls, val: Any, full: bool) -> Any: @classmethod def _from_dict_value_basic(cls, tp: type, val: Any) -> Any: - if issubclass(tp, float): - return tp(val) - if issubclass(tp, range): - return tp(*val) if issubclass(tp, datetime): return tp.fromisoformat(val) - if issubclass(tp, Enum): - try: - return tp(val) - except ValueError as e: - raise TypeConversionError(tp, val) from e - return super()._from_dict_value_basic(tp, val) + return super()._from_dict_value_basic(tp, from_dict_value_basic(tp, val)) @classmethod def _from_dict_value(cls, tp: type, val: Any, strict: bool = False) -> Any: diff --git a/fancy_dataclass/serialize.py b/fancy_dataclass/serialize.py index d55226b..90e9300 100644 --- a/fancy_dataclass/serialize.py +++ b/fancy_dataclass/serialize.py @@ -1,11 +1,52 @@ from abc import ABC, abstractmethod +from enum import Enum from io import StringIO, TextIOBase from typing import Any, BinaryIO, TextIO from typing_extensions 
import Self from fancy_dataclass.dict import AnyDict, DictDataclass -from fancy_dataclass.utils import AnyIO +from fancy_dataclass.utils import AnyIO, TypeConversionError + + +def to_dict_value_basic(val: Any) -> Any: + """Converts an arbitrary value with a basic data type to an appropriate form for serializing to typical file formats (JSON, TOML). + + Args: + val: Value with basic data type + + Returns: + A version of that value suitable for serialization""" + if isinstance(val, Enum): + return val.value + elif isinstance(val, range): # store the range bounds + bounds = [val.start, val.stop] + if val.step != 1: + bounds.append(val.step) + return bounds + elif hasattr(val, 'dtype'): # assume it's a numpy array of numbers + return [float(elt) for elt in val] + return val + +def from_dict_value_basic(tp: type, val: Any) -> Any: + """Converts a deserialized value to the given type. + + Args: + tp: Target type to convert to + val: Deserialized value + + Returns: + Converted value""" + if issubclass(tp, float): + return tp(val) + if issubclass(tp, range): + return tp(*val) + if issubclass(tp, Enum): + try: + return tp(val) + except ValueError as e: + raise TypeConversionError(tp, val) from e + return val class FileSerializable(ABC): diff --git a/fancy_dataclass/toml.py b/fancy_dataclass/toml.py index aa5e858..b0a7c0a 100644 --- a/fancy_dataclass/toml.py +++ b/fancy_dataclass/toml.py @@ -1,13 +1,25 @@ +from dataclasses import Field from typing import Any, TextIO import tomlkit from typing_extensions import Self from fancy_dataclass.dict import AnyDict -from fancy_dataclass.serialize import DictFileSerializableDataclass, FileSerializable +from fancy_dataclass.serialize import DictFileSerializableDataclass, FileSerializable, from_dict_value_basic, to_dict_value_basic from fancy_dataclass.utils import AnyIO +def _remove_null_dict_values(val: Any) -> Any: + """Removes all null (None) values from a dict. 
+ + Does this recursively to any nested dicts or lists.""" + if isinstance(val, (list, tuple)): + return type(val)(_remove_null_dict_values(elt) for elt in val) + if isinstance(val, dict): + return type(val)({key: _remove_null_dict_values(elt) for (key, elt) in val.items() if (elt is not None)}) + return val + + class TOMLSerializable(FileSerializable): """Mixin class enabling conversion of an object to/from TOML.""" @@ -59,11 +71,24 @@ class TOMLDataclass(DictFileSerializableDataclass, TOMLSerializable): """Dataclass mixin enabling default serialization of dataclass objects to and from TOML.""" # TODO: require subclass to set qualified_type=True, like JSONDataclass? - @classmethod def _dict_to_text_file(cls, d: AnyDict, fp: TextIO, **kwargs: Any) -> None: + d = _remove_null_dict_values(d) tomlkit.dump(d, fp, **kwargs) @classmethod def _text_file_to_dict(cls, fp: TextIO, **kwargs: Any) -> AnyDict: return tomlkit.load(fp) + + @classmethod + def _to_dict_value_basic(cls, val: Any) -> Any: + return to_dict_value_basic(val) + + @classmethod + def _from_dict_value_basic(cls, tp: type, val: Any) -> Any: + return super()._from_dict_value_basic(tp, from_dict_value_basic(tp, val)) + + @classmethod + def _get_missing_value(cls, fld: Field) -> Any: # type: ignore[type-arg] + # replace any missing required fields with a default of None + return None diff --git a/fancy_dataclass/utils.py b/fancy_dataclass/utils.py index 030cdbe..3d254b6 100644 --- a/fancy_dataclass/utils.py +++ b/fancy_dataclass/utils.py @@ -9,7 +9,7 @@ import importlib from pathlib import Path import re -from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Dict, ForwardRef, Generic, Iterator, List, Optional, Sequence, Set, TextIO, Tuple, Type, TypeVar, Union, get_args, get_origin, get_type_hints +from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Dict, ForwardRef, Generic, Iterable, Iterator, List, Optional, Sequence, Set, TextIO, Tuple, Type, TypeVar, Union, get_args, get_origin, get_type_hints 
from typing_extensions import TypeGuard @@ -45,6 +45,17 @@ def __init__(self, tp: type, val: Any) -> None: super().__init__(f'could not convert {val!r} to type {tp_name!r}') +def type_is_optional(tp: type) -> bool: + """Determines if a type is an Optional type. + + Args: + tp: Type to check + + Returns: + True if the type is Optional""" + origin_type = get_origin(tp) + return (origin_type == Union) and (getattr(tp, '_name', None) == 'Optional') + def safe_dict_insert(d: Dict[Any, Any], key: str, val: Any) -> None: """Inserts a (key, value) pair into a dict, if the key is not already present. @@ -233,8 +244,56 @@ def coerce_to_dataclass(cls: Type[T], obj: object) -> T: Returns: A new object of the desired type, coerced from the input object""" - d = {fld.name: getattr(obj, fld.name) for fld in dataclasses.fields(cls) if hasattr(obj, fld.name)} # type: ignore[arg-type] - return cls(**d) + kwargs = {} + for fld in dataclasses.fields(cls): # type: ignore[arg-type] + if hasattr(obj, fld.name): + val = getattr(obj, fld.name) + if is_dataclass(fld.type): + val = coerce_to_dataclass(fld.type, val) + else: + origin_type = get_origin(fld.type) + if origin_type and issubclass_safe(origin_type, Iterable): + if issubclass(origin_type, dict): + (_, val_type) = get_args(origin_type) + if is_dataclass(val_type): + val = type(val)({key: coerce_to_dataclass(val_type, elt) for (key, elt) in val.items()}) + elif issubclass(origin_type, tuple): + val = type(val)(coerce_to_dataclass(tp, elt) if is_dataclass(tp) else elt for (tp, elt) in zip(get_args(fld.type), val)) + else: + (elt_type,) = get_args(fld.type) + if is_dataclass(elt_type): + val = type(val)(coerce_to_dataclass(elt_type, elt) for elt in val) + kwargs[fld.name] = val + return cls(**kwargs) + +def dataclass_type_map(cls: Type['DataclassInstance'], func: Callable[[type], type]) -> Type['DataclassInstance']: + """Applies a type function to all dataclass field types, recursively through container types. 
+ + Args: + cls: Target dataclass type to manipulate + func: Function to map onto basic (non-container) field types + + Returns: + A new dataclass type whose field types have been mapped by the function""" + def _map_func(tp: type) -> type: + return func(dataclass_type_map(tp, func)) if is_dataclass(tp) else func(tp) + field_data = [] + for fld in get_dataclass_fields(cls, include_classvars=True): + new_fld = copy(fld) + origin_type = get_origin(fld.type) + if origin_type and issubclass_safe(origin_type, Iterable): + if issubclass(origin_type, dict): + (key_type, val_type) = get_args(origin_type) + tp = origin_type[key_type, _map_func(val_type)] + elif issubclass(origin_type, tuple): + tp = origin_type[*[_map_func(elt_type) for elt_type in get_args(fld.type)]] + else: + (elt_type,) = get_args(fld.type) + tp = origin_type[_map_func(elt_type)] + else: + tp = _map_func(fld.type) + field_data.append((fld.name, tp, new_fld)) + return make_dataclass(cls.__name__, field_data, bases=cls.__bases__) ############## @@ -257,7 +316,6 @@ def traverse_dataclass(cls: type) -> Iterator[Tuple[RecordPath, Field]]: # type Raises: TypeError: if the type cannot be traversed""" def _make_optional(fld: Field) -> Field: # type: ignore[type-arg] - new_fld = Field new_fld = copy(fld) # type: ignore[assignment] new_fld.type = Optional[fld.type] # type: ignore new_fld.default = None # type: ignore diff --git a/pyproject.toml b/pyproject.toml index 42ad56e..0ba417d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ classifiers = [ ] dependencies = [ "sqlalchemy >= 2.0", + "tomlkit >= 0.11", "typing_extensions" ] @@ -57,7 +58,7 @@ dependencies = [ [tool.hatch.envs.lint.scripts] run-mypy = "mypy --install-types --non-interactive {args:fancy_dataclass tests}" run-ruff = "ruff check" -run-vermin = "vermin {args:--eval-annotations --no-tips --exclude 'tests.test_json.StrEnum' fancy_dataclass}" +run-vermin = "vermin {args:--eval-annotations --no-tips --exclude 
'tests.test_serializable.StrEnum' fancy_dataclass}" all = ["run-ruff", "run-vermin", "run-mypy"] [tool.hatch.envs.test] @@ -116,10 +117,6 @@ disable_error_code = ["assignment"] module = "tests.test_config" disable_error_code = ["misc", "union-attr"] -[[tool.mypy.overrides]] -module = "tests.test_json" -disable_error_code = ["assignment", "misc"] - [[tool.mypy.overrides]] module = "tests.test_inheritance" disable_error_code = ["assignment", "misc"] @@ -128,6 +125,10 @@ disable_error_code = ["assignment", "misc"] module = "tests.test_mixin" disable_error_code = ["assignment", "call-overload", "has-type", "misc", "union-attr"] +[[tool.mypy.overrides]] +module = "tests.test_serializable" +disable_error_code = ["assignment", "misc"] + [[tool.mypy.overrides]] module = "tests.test_subprocess" disable_error_code = ["assignment", "misc"] diff --git a/tests/test_json.py b/tests/test_json.py deleted file mode 100644 index 933a5c4..0000000 --- a/tests/test_json.py +++ /dev/null @@ -1,588 +0,0 @@ -from collections import namedtuple -from dataclasses import asdict, dataclass, field, fields -from datetime import datetime -from enum import Enum, Flag, auto -import json -import math -import re -import sys -from typing import Any, ClassVar, List, Literal, NamedTuple, Optional, TypedDict, Union - -import pytest -from typing_extensions import Annotated, Doc - -from fancy_dataclass.dict import DictDataclass -from fancy_dataclass.json import JSONBaseDataclass, JSONDataclass - - -NOW = datetime.now() - - -@dataclass -class DCEmpty(JSONDataclass): - ... - -@dataclass -class DC1(JSONBaseDataclass): - x: int - y: float - z: str - -@dataclass -class DC2(JSONBaseDataclass): - x: int - y: float - z: str - -@dataclass -class DC1Sub(DC1): - ... - -@dataclass -class DC2Sub(DC2): - ... 
- -@dataclass -class DC3(JSONDataclass): - list: List[int] - -class MyObject: - """This object is not JSON-serializable.""" - def __eq__(self, other): - return isinstance(other, MyObject) - -@dataclass -class DCNonJSONSerializable(JSONDataclass): - x: int - obj: MyObject - -@dataclass -class DCOptionalInt(JSONDataclass): - x: int - y: Optional[int] - -@dataclass -class DCOptionalStr(JSONDataclass): - x: str - y: Optional[str] - -@dataclass -class DCUnion(JSONDataclass): - x: Union[int, str] - -@dataclass -class DCLiteral(JSONDataclass): - lit: Literal['a', 1] - -@dataclass -class DCDatetime(JSONDataclass): - dt: datetime - -class MyEnum(Enum): - a = auto() - b = auto() - -@dataclass -class DCEnum(JSONDataclass): - enum: MyEnum - -if sys.version_info[:2] < (3, 11): - class StrEnum(str, Enum): - pass -else: - from enum import StrEnum - -class MyStrEnum(StrEnum): - a = 'a' - b = 'b' - -@dataclass -class DCStrEnum(JSONDataclass): - enum: MyStrEnum - -class Color(Flag): - RED = auto() - GREEN = auto() - BLUE = auto() - -@dataclass -class DCColors(JSONDataclass): - colors: List[Color] - -@dataclass -class DCRange(JSONDataclass): - range: range - -@dataclass -class DCAnnotated(JSONDataclass): - x: Annotated[int, 'an integer'] - y: Annotated[float, Doc('a float')] - -class MyTypedDict(TypedDict): - x: int - y: str - -@dataclass -class DCTypedDict(JSONDataclass): - d: MyTypedDict - -MyUntypedNamedTuple = namedtuple('MyUntypedNamedTuple', ['x', 'y']) - -class MyTypedNamedTuple(NamedTuple): - x: int - y: str - -@dataclass -class DCUntypedNamedTuple(JSONDataclass): - t: MyUntypedNamedTuple - -@dataclass -class DCTypedNamedTuple(JSONDataclass): - t: MyTypedNamedTuple - -@dataclass -class DCAny(JSONDataclass): - val: Any - -@dataclass -class DCFloat(JSONDataclass): - x: float - -@dataclass -class DCSuppress(JSONDataclass, suppress_defaults=False): - cv1: ClassVar[int] = field(default=0) - x: int = field(default=1) - y: int = field(default=2, metadata={'suppress': True}) - z: int 
= field(default=3, metadata={'suppress': False}) - -@dataclass -class DCList(JSONDataclass): - vals: List[DCAny] - - -TEST_JSON = [ - DCEmpty(), - DC1(3, 4.7, 'abc'), - DC2(3, 4.7, 'abc'), - DC1Sub(3, 4.7, 'abc'), - DC2Sub(3, 4.7, 'abc'), - DC3([1, 2, 3]), - DCOptionalInt(1, 2), - DCOptionalInt(1, None), - DCOptionalStr('a', 'b'), - DCOptionalStr('a', 'None'), - DCUnion(1), - DCUnion('a'), - DCUnion('1'), - DCLiteral('a'), - DCLiteral(1), - DCDatetime(NOW), - DCEnum(MyEnum.a), - DCStrEnum(MyStrEnum.a), - DCColors(list(Color)), - DCRange(range(1, 10, 3)), - DCAnnotated(3, 4.7), - DCTypedDict({'x': 3, 'y': 'a'}), - DCUntypedNamedTuple(MyUntypedNamedTuple(3, 'a')), - DCTypedNamedTuple(MyTypedNamedTuple(3, 'a')), - DCAny(3), - DCAny('a'), - DCAny({}), - DCAny(None), - DCSuppress(), - DCList([DCAny(None), DCAny(1), DCAny([1]), DCAny(None), DCAny({})]), -] - -def _make_dict_dataclass(cls: type) -> type: - """Converts a JSONDataclass into a plain DictDataclass.""" - bases = [] - for base in cls.__bases__: - bases.append(DictDataclass if (base is JSONDataclass) else _make_dict_dataclass(base)) - return type(cls.__name__, tuple(bases), dict(cls.__dict__)) - -@pytest.mark.parametrize('obj', TEST_JSON) -def test_dict_convert(obj): - """Tests conversion to/from dict.""" - if isinstance(obj, JSONBaseDataclass): - assert 'type' in obj.to_dict() - else: - assert 'type' not in obj.to_dict() - assert type(obj).from_dict(obj.to_dict()) == obj - # test dict round-trip for regular DictDataclass - cls = _make_dict_dataclass(type(obj)) - obj2 = cls(**{fld.name: getattr(obj, fld.name) for fld in fields(obj)}) - assert cls.from_dict(obj2.to_dict()) == obj2 - -def test_special_type_convert(): - """Tests that DictDataclass does not do special type conversion for certain types, while JSONDataclass does.""" - @dataclass - class DCDatetimePlain(DictDataclass): - dt: datetime - dt = datetime.now() - obj1 = DCDatetimePlain(dt) - assert obj1.to_dict() == {'dt': dt} - obj2 = DCDatetime(dt) - 
assert obj2.to_dict() == {'dt': dt.isoformat()} - @dataclass - class DCEnumPlain(DictDataclass): - enum: MyEnum - obj1 = DCEnumPlain(MyEnum.a) - assert obj1.to_dict() == {'enum': MyEnum.a} - obj2 = DCEnum(MyEnum.a) - assert obj2.to_dict() == {'enum': 1} - @dataclass - class DCRangePlain(DictDataclass): - range: range - r = range(1, 10, 3) - obj1 = DCRangePlain(r) - assert obj1.to_dict() == {'range': r} - obj2 = DCRange(r) - assert obj2.to_dict() == {'range': [1, 10, 3]} - r = range(1, 10) - obj2 = DCRange(r) - assert obj2.to_dict() == {'range': [1, 10]} - -@pytest.mark.parametrize('obj', TEST_JSON) -def test_json_convert(obj, tmp_path): - """Tests conversion to/from JSON.""" - # write to JSON text file - json_path = tmp_path / 'test.json' - with open(json_path, 'w') as f: - obj.to_json(f) - with open(json_path) as f: - obj1 = type(obj).from_json(f) - assert obj1 == obj - # write to JSON binary file - with open(json_path, 'wb') as f: - obj.to_json(f) - with open(json_path, 'rb') as f: - obj2 = type(obj).from_json(f) - assert obj2 == obj - # convert to JSON string - s = obj.to_json_string() - assert s == json.dumps(obj.to_dict()) - obj3 = type(obj).from_json_string(s) - assert obj3 == obj - -def test_optional(): - obj: Any = DCOptionalInt(1, 2) - d = obj.to_dict() - assert d == {'x': 1, 'y': 2} - assert DCOptionalInt.from_dict(d) == obj - obj = DCOptionalInt(1, None) - d = obj.to_dict() - assert d == {'x': 1, 'y': None} - assert DCOptionalInt.from_dict(d) == obj - obj = DCOptionalInt(None, 1) # type: ignore[arg-type] - # validation does not occur when converting to dict, only the reverse - d = obj.to_dict() - assert d == {'x': None, 'y': 1} - with pytest.raises(ValueError, match="could not convert None to type 'int'"): - _ = DCOptionalInt.from_dict(d) - obj = DCOptionalStr('a', 'b') - d = obj.to_dict() - assert d == {'x': 'a', 'y': 'b'} - assert DCOptionalStr.from_dict(d) == obj - obj = DCOptionalStr('a', None) - d = obj.to_dict() - assert d == {'x': 'a', 'y': None} 
- assert DCOptionalStr.from_dict(d) == obj - obj = DCOptionalStr(None, 'b') # type: ignore[arg-type] - d = obj.to_dict() - assert d == {'x': None, 'y': 'b'} - with pytest.raises(ValueError, match="could not convert None to type 'str'"): - _ = DCOptionalStr.from_dict(d) - -def test_literal(): - obj = DCLiteral(1) - assert obj.to_dict() == {'lit': 1} - obj = DCLiteral('b') # type: ignore[arg-type] - d = obj.to_dict() - assert d == {'lit': 'b'} - with pytest.raises(ValueError, match=re.escape("could not convert 'b' to type \"typing.Literal['a', 1]\"")): - _ = DCLiteral.from_dict(d) - -def test_datetime(): - obj = DCDatetime(NOW) - d = obj.to_dict() - s = NOW.isoformat() - assert d == {'dt': s} - # compare with dataclasses.asdict - assert asdict(obj) == {'dt': NOW} - assert DCDatetime.from_dict(d) == obj - # some prefixes of full isoformat are valid - assert DCDatetime.from_dict({'dt': NOW.strftime('%Y-%m-%dT%H:%M:%S')}).dt.isoformat() == s[:19] - assert DCDatetime.from_dict({'dt': NOW.strftime('%Y-%m-%d')}).dt.isoformat()[:10] == s[:10] - # other datetime formats are invalid - with pytest.raises(ValueError, match='Invalid isoformat string'): - DCDatetime.from_dict({'dt': NOW.strftime('%m/%d/%Y %H:%M:%S')}) - with pytest.raises(ValueError, match='Invalid isoformat string'): - DCDatetime.from_dict({'dt': NOW.strftime('%d/%m/%Y')}) - -def test_enum(): - obj1 = DCEnum(MyEnum.a) - assert obj1.to_dict() == {'enum': 1} - assert asdict(obj1) == {'enum': MyEnum.a} - obj2 = DCStrEnum(MyStrEnum.a) - assert obj2.to_dict() == {'enum': 'a'} - assert asdict(obj2) == {'enum': MyStrEnum.a} - obj3 = DCColors(list(Color)) - assert obj3.to_dict() == {'colors': [1, 2, 4]} - assert asdict(obj3) == {'colors': list(Color)} - -def test_annotated(): - obj = DCAnnotated(3, 4.7) - assert obj.to_dict() == {'x': 3, 'y': 4.7} - -def test_typed_dict(): - td: MyTypedDict = {'x': 3, 'y': 'a'} - obj = DCTypedDict(td) - assert obj.to_dict() == {'d': td} - # invalid TypedDicts - for d in [{'x': 3}, {'x': 
3, 'y': 'a', 'z': 1}, {'x': 3, 'y': 4}, {'x': 3, 'y': None}]: - with pytest.raises(ValueError, match="could not convert .* to type .*"): - _ = DCTypedDict.from_dict({'d': d}) - -def test_namedtuple(): - nt1 = MyUntypedNamedTuple(3, 'a') - obj1 = DCUntypedNamedTuple(nt1) - assert obj1.to_dict() == {'t': {'x': 3, 'y': 'a'}} - nt2 = MyTypedNamedTuple(3, 'a') - obj2 = DCTypedNamedTuple(nt2) - assert obj2.to_dict() == {'t': {'x': 3, 'y': 'a'}} - # invalid NamedTuple field (validation occurs on from_dict) - nt2 = MyTypedNamedTuple(3, 4) # type: ignore[arg-type] - obj2 = DCTypedNamedTuple(nt2) - d = {'t': {'x': 3, 'y': 4}} - assert obj2.to_dict() == d - with pytest.raises(ValueError, match="could not convert 4 to type 'str'"): - _ = DCTypedNamedTuple.from_dict(d) - -def test_subclass_json_dataclass(): - def _remove_type(d): - return {key: val for (key, val) in d.items() if (key != 'type')} - obj = DC1Sub(3, 4.7, 'abc') - obj1 = DC1Sub.from_dict(obj.to_dict()) - assert obj1 == obj - assert isinstance(obj1, DC1Sub) - d = obj.to_dict() - assert d['type'] == 'tests.test_json.DC1Sub' - obj2 = DC1.from_dict(d) - # fully qualified type is resolved to the subclass - assert obj2 == obj - assert isinstance(obj2, DC1Sub) - obj3 = DC1.from_dict(_remove_type(d)) - assert isinstance(obj3, DC1) - assert not isinstance(obj3, DC1Sub) - d3 = obj3.to_dict() - # objects have the same dict other than the type - assert _remove_type(d3) == _remove_type(d) - assert d3['type'] == 'tests.test_json.DC1' - # test behavior of inheriting from JSONDataclass - @dataclass - class MyDC(JSONDataclass): - pass - assert MyDC().to_dict() == {} - with pytest.raises(TypeError, match='you must set qualified_type=True'): - @dataclass - class MyDC1(MyDC): - pass - @dataclass - class MyDC2(MyDC, qualified_type=True): - pass - # TODO: forbid local types? 
- assert MyDC2().to_dict() == {'type': 'tests.test_json.test_subclass_json_dataclass..MyDC2'} - @dataclass - class MyBaseDC(JSONBaseDataclass): - pass - @dataclass - class MyDC3(MyBaseDC): - pass - assert MyDC3().to_dict() == {'type': 'tests.test_json.test_subclass_json_dataclass..MyDC3'} - with pytest.raises(TypeError, match='you must set qualified_type=True'): - @dataclass - class MyDC4(MyBaseDC, qualified_type=False): - pass - with pytest.raises(TypeError, match='you must set qualified_type=True'): - @dataclass - class MyDC5(MyDC, JSONBaseDataclass): - pass - @dataclass - class MyDC6(JSONBaseDataclass, MyDC): - pass - @dataclass - class MyDC7(MyDC, JSONBaseDataclass, qualified_type=True): - pass - -def test_subclass_json_base_dataclass(): - """Tests JSONBaseDataclass.""" - obj = DC2Sub(3, 4.7, 'abc') - d = obj.to_dict() - assert d['type'] == 'tests.test_json.DC2Sub' - obj1 = DC2Sub.from_dict(d) - assert obj1 == obj - obj2 = DC2.from_dict(d) - assert isinstance(obj2, DC2Sub) - assert obj2 == obj - -def test_invalid_json_obj(): - """Attempts to convert an object to JSON that is not JSONSerializable.""" - obj = MyObject() - njs = DCNonJSONSerializable(3, obj) - d = {'x': 3, 'obj': obj} - assert njs.to_dict() == d - # conversion from dict works OK - assert DCNonJSONSerializable.from_dict(d) == njs - with pytest.raises(TypeError, match='Object of type MyObject is not JSON serializable'): - _ = njs.to_json_string() - -def test_suppress(): - """Tests behavior of setting the 'suppress' option on a field.""" - obj = DCSuppress() - d = {'x': 1, 'z': 3} - assert obj.to_dict() == d - assert obj.to_dict(full=True) == d - assert DCSuppress.from_dict(d) == obj - obj = DCSuppress(y=100) - assert obj.to_dict() == d - assert obj.to_dict(full=True) == d - assert DCSuppress.from_dict(d).y == 2 - -def test_suppress_required_field(): - """Tests that a required field with suppress=True cannot create a valid dict.""" - @dataclass - class DCSuppressRequired(JSONDataclass): - x: int = 
field(metadata={'suppress': True}) - with pytest.raises(TypeError, match='missing 1 required positional argument'): - _ = DCSuppressRequired() - obj = DCSuppressRequired(1) - assert obj.to_dict() == {} - with pytest.raises(ValueError, match="'x' field is required"): - _ = DCSuppressRequired.from_dict({}) - _ = DCSuppressRequired.from_dict({'x': 1}) - -def test_suppress_defaults(): - """Tests behavior of the suppress_defaults option, both at the class level and the field level.""" - @dataclass - class MyDC(JSONDataclass): - x: int = 1 - assert MyDC.__settings__.suppress_defaults is True - obj = MyDC() - assert obj.to_dict() == {} - assert obj.to_dict(full=True) == {'x': 1} - obj = MyDC(2) - assert obj.to_dict() == {'x': 2} - assert obj.to_dict(full=True) == {'x': 2} - @dataclass - class MyDC(JSONDataclass, suppress_defaults=False): - x: int = 1 - obj = MyDC() - assert obj.to_dict() == {'x': 1} - assert obj.to_dict(full=True) == {'x': 1} - @dataclass - class MyDC(JSONDataclass): - x: int = field(default=1, metadata={'suppress_default': False}) - obj = MyDC() - assert obj.to_dict() == {'x': 1} - assert obj.to_dict(full=True) == {'x': 1} - @dataclass - class MyDC(JSONDataclass, suppress_defaults=False): - x: int = field(default=1, metadata={'suppress_default': True}) - obj = MyDC() - assert obj.to_dict() == {} - assert obj.to_dict(full=True) == {'x': 1} - -def test_class_var(): - """Tests the behavior of ClassVars.""" - @dataclass - class MyDC1(JSONDataclass): - x: ClassVar[int] - obj = MyDC1() - assert obj.to_dict() == {} - assert obj.to_dict(full=True) == {} - assert MyDC1.from_dict({}) == obj - with pytest.raises(AttributeError, match='object has no attribute'): - _ = obj.x - @dataclass - class MyDC2(JSONDataclass): - x: ClassVar[int] = field(metadata={'suppress': False}) - obj = MyDC2() - with pytest.raises(AttributeError, match='object has no attribute'): - _ = obj.to_dict() - assert MyDC2.from_dict({}) == obj - @dataclass - class MyDC3(JSONDataclass): - x: 
ClassVar[int] = 1 - obj = MyDC3() - assert obj.to_dict() == {} - assert obj.to_dict(full=True) == {} - obj0 = MyDC3.from_dict({}) - assert obj0 == obj - assert obj0.x == 1 - # ClassVar gets ignored when loading from dict - obj1 = MyDC3.from_dict({'x': 1}) - assert obj1 == obj - assert obj1.x == 1 - obj2 = MyDC3.from_dict({'x': 2}) - assert obj2 == obj - assert obj2.x == 1 - MyDC3.x = 2 - obj = MyDC3() - assert obj.to_dict() == {} - # ClassVar field has to override with suppress=False to include it - assert obj.to_dict(full=True) == {} - @dataclass - class MyDC4(JSONDataclass): - x: ClassVar[int] = field(default=1, metadata={'suppress': False}) - obj = MyDC4() - assert obj.to_dict() == {} # equals default, so suppress it - assert obj.to_dict(full=True) == {'x': 1} - obj0 = MyDC4.from_dict({}) - assert obj0 == obj - obj2 = MyDC4.from_dict({'x': 2}) - assert obj2 == obj - assert obj2.x == 1 - MyDC4.x = 2 - obj = MyDC4() - assert obj.to_dict() == {'x': 2} # no longer equals default - assert obj.to_dict(full=True) == {'x': 2} - -def test_from_dict_kwargs(): - """Tests behavior of from_json_string with respect to partitioning kwargs into from_dict and json.loads.""" - @dataclass - class MyDC(JSONDataclass): - x: int = 1 - s = '{"x": 1}' - assert MyDC.from_json_string(s) == MyDC() - assert MyDC.from_json_string(s, strict=True) == MyDC() - with pytest.raises(ValueError, match="'y' is not a valid field for MyDC"): - _ = MyDC.from_json_string('{"x": 1, "y": 2}', strict=True) - parse_int = lambda val: int(val) + 1 - assert MyDC.from_json_string(s, parse_int=parse_int) == MyDC(2) - assert MyDC.from_json_string(s, strict=True, parse_int=parse_int) == MyDC(2) - with pytest.raises(TypeError, match="unexpected keyword argument 'fake_kwarg'"): - _ = MyDC.from_json_string(s, fake_kwarg=True) - -@pytest.mark.parametrize(['obj', 'd', 'obj2'], [ - (DCEmpty(), {}, None), - (DCFloat(1), {'x': 1}, None), - (DCFloat(math.inf), {'x': math.inf}, None), - (DCFloat(math.nan), {'x': math.nan}, 
None), - (DCColors([Color.RED, Color.BLUE]), {'colors': [1, 4]}, None), - (DCStrEnum(MyStrEnum.a), {'enum': 'a'}, None), - (DCNonJSONSerializable(1, MyObject()), {'x': 1, 'obj': MyObject()}, None), - (DCList([]), {'vals': []}, None), - (DCList([DCAny(None)]), {'vals': [{'val': None}]}, None), - (DCList([DCAny(1)]), {'vals': [{'val': 1}]}, None), - (DCList([DCAny([])]), {'vals': [{'val': []}]}, None), - (DCList([DCAny({})]), {'vals': [{'val': {}}]}, None), - (DCList([DCAny(DCAny(1))]), {'vals': [{'val': {'val': 1}}]}, DCList([DCAny({'val': 1})])), -]) -def test_round_trips(obj, d, obj2): - """Tests round-trip fidelity to/from dict.""" - assert obj.to_dict() == d - obj2 = type(obj).from_dict(d) - if obj2 is None: # round-trip is valid - assert obj == obj2 - d2 = obj2.to_dict() - assert d == d2 diff --git a/tests/test_serializable.py b/tests/test_serializable.py new file mode 100644 index 0000000..cff590b --- /dev/null +++ b/tests/test_serializable.py @@ -0,0 +1,701 @@ +from collections import namedtuple +from dataclasses import asdict, dataclass, field +from datetime import datetime +from enum import Enum, Flag, auto +import json +import math +import re +import sys +from typing import Any, ClassVar, List, Literal, NamedTuple, Optional, TypedDict, Union + +import pytest +from typing_extensions import Annotated, Doc + +from fancy_dataclass.dict import DictDataclass +from fancy_dataclass.json import JSONBaseDataclass, JSONDataclass +from fancy_dataclass.toml import TOMLDataclass +from fancy_dataclass.utils import coerce_to_dataclass, dataclass_type_map, issubclass_safe + + +NOW = datetime.now() + + +def _convert_json_dataclass(cls, new_cls): + """Converts JSONDataclass base classes with the given class, recursively within the input class's fields.""" + # TODO: this is very hacky; can we clean it up? 
+ def _convert(tp): + if issubclass_safe(tp, JSONDataclass): + bases = [] + for base in tp.__bases__: + if base in (JSONDataclass, JSONBaseDataclass): + base = new_cls + elif issubclass_safe(base, JSONDataclass): + base = _convert_json_dataclass(base, new_cls) + bases.append(base) + return type(tp.__name__, tuple(bases), dict(tp.__dict__)) + return tp + tp = _convert(dataclass_type_map(cls, _convert)) + tp.__eq__ = cls.__eq__ + return tp + +################ +# TEST CLASSES # +################ + +@dataclass +class DCEmpty(JSONDataclass): + ... + +@dataclass +class DC1(JSONBaseDataclass): + x: int + y: float + z: str + +@dataclass +class DC2(JSONBaseDataclass): + x: int + y: float + z: str + +@dataclass +class DC1Sub(DC1): + ... + +@dataclass +class DC2Sub(DC2): + ... + +@dataclass +class DC3(JSONDataclass): + list: List[int] + +class MyObject: + """This object is not JSON-serializable.""" + def __eq__(self, other): + return isinstance(other, MyObject) + +@dataclass +class DCNonJSONSerializable(JSONDataclass): + x: int + obj: MyObject + +@dataclass +class DCOptional(JSONDataclass): + x: Optional[int] + +@dataclass +class DCOptionalInt(JSONDataclass): + x: int + y: Optional[int] + +@dataclass +class DCOptionalStr(JSONDataclass): + x: str + y: Optional[str] + +@dataclass +class DCUnion(JSONDataclass): + x: Union[int, str] + +@dataclass +class DCLiteral(JSONDataclass): + lit: Literal['a', 1] + +@dataclass +class DCDatetime(JSONDataclass): + dt: datetime + +class MyEnum(Enum): + a = auto() + b = auto() + +@dataclass +class DCEnum(JSONDataclass): + enum: MyEnum + +if sys.version_info[:2] < (3, 11): + class StrEnum(str, Enum): + pass +else: + from enum import StrEnum + +class MyStrEnum(StrEnum): + a = 'a' + b = 'b' + +@dataclass +class DCStrEnum(JSONDataclass): + enum: MyStrEnum + +class Color(Flag): + RED = auto() + GREEN = auto() + BLUE = auto() + +@dataclass +class DCColors(JSONDataclass): + colors: List[Color] + +@dataclass +class DCRange(JSONDataclass): + range: range 
+ +@dataclass +class DCAnnotated(JSONDataclass): + x: Annotated[int, 'an integer'] + y: Annotated[float, Doc('a float')] + +class MyTypedDict(TypedDict): + x: int + y: str + +@dataclass +class DCTypedDict(JSONDataclass): + d: MyTypedDict + +MyUntypedNamedTuple = namedtuple('MyUntypedNamedTuple', ['x', 'y']) + +class MyTypedNamedTuple(NamedTuple): + x: int + y: str + +@dataclass +class DCUntypedNamedTuple(JSONDataclass): + t: MyUntypedNamedTuple + +@dataclass +class DCTypedNamedTuple(JSONDataclass): + t: MyTypedNamedTuple + +@dataclass +class DCAny(JSONDataclass): + val: Any + +@dataclass +class DCFloat(JSONDataclass): + x: float + def __eq__(self, other): + # make nan equal, for comparison testing + return (math.isnan(self.x) and math.isnan(other.x)) or (self.x == other.x) + +@dataclass +class DCSuppress(JSONDataclass, suppress_defaults=False): + cv1: ClassVar[int] = field(default=0) + x: int = field(default=1) + y: int = field(default=2, metadata={'suppress': True}) + z: int = field(default=3, metadata={'suppress': False}) + +@dataclass +class DCList(JSONDataclass): + vals: List[DCAny] + + +TEST_JSON = [ + DCEmpty(), + DC1(3, 4.7, 'abc'), + DC2(3, 4.7, 'abc'), + DC1Sub(3, 4.7, 'abc'), + DC2Sub(3, 4.7, 'abc'), + DC3([1, 2, 3]), + DCOptionalInt(1, 2), + DCOptionalInt(1, None), + DCOptionalStr('a', 'b'), + DCOptionalStr('a', 'None'), + DCUnion(1), + DCUnion('a'), + DCUnion('1'), + DCLiteral('a'), + DCLiteral(1), + DCDatetime(NOW), + DCEnum(MyEnum.a), + DCStrEnum(MyStrEnum.a), + DCColors(list(Color)), + DCRange(range(1, 10, 3)), + DCAnnotated(3, 4.7), + DCTypedDict({'x': 3, 'y': 'a'}), + DCUntypedNamedTuple(MyUntypedNamedTuple(3, 'a')), + DCTypedNamedTuple(MyTypedNamedTuple(3, 'a')), + DCAny(3), + DCAny('a'), + DCAny({}), + DCAny(None), + DCSuppress(), + DCList([DCAny(None), DCAny(1), DCAny([1]), DCAny(None), DCAny({})]), +] + +class TestDict: + """Unit tests for DictDataclass.""" + + base_cls = DictDataclass + + def _test_dict_convert(self, obj): + # convert object 
to the desired base class + tp = _convert_json_dataclass(type(obj), self.base_cls) + assert issubclass(tp, self.base_cls) + obj = coerce_to_dataclass(tp, obj) + assert isinstance(obj, self.base_cls) + if obj.__settings__.qualified_type: + assert 'type' in obj.to_dict() + else: + assert 'type' not in obj.to_dict() + assert tp.from_dict(obj.to_dict()) == obj + + # TODO: file round trips + + @pytest.mark.parametrize('obj', TEST_JSON) + def test_dict_convert(self, obj): + """Tests conversion to/from dict.""" + self._test_dict_convert(obj) + + +class TestJSON(TestDict): + """Unit tests for JSONDataclass.""" + + base_cls = JSONDataclass + + @pytest.mark.parametrize('obj', TEST_JSON) + def test_dict_convert(self, obj): + """Tests conversion to/from dict.""" + self._test_dict_convert(obj) + + def test_special_type_convert(self): + """Tests that DictDataclass does not do special type conversion for certain types, while JSONDataclass does.""" + DCDatetimePlain = _convert_json_dataclass(DCDatetime, DictDataclass) + DCEnumPlain = _convert_json_dataclass(DCEnum, DictDataclass) + DCRangePlain = _convert_json_dataclass(DCRange, DictDataclass) + dt = NOW + obj1 = DCDatetimePlain(dt) + assert obj1.to_dict() == {'dt': dt} + obj2 = DCDatetime(dt) + assert obj2.to_dict() == {'dt': dt.isoformat()} + obj1 = DCEnumPlain(MyEnum.a) + assert obj1.to_dict() == {'enum': MyEnum.a} + obj2 = DCEnum(MyEnum.a) + assert obj2.to_dict() == {'enum': 1} + r = range(1, 10, 3) + obj1 = DCRangePlain(r) + assert obj1.to_dict() == {'range': r} + obj2 = DCRange(r) + assert obj2.to_dict() == {'range': [1, 10, 3]} + r = range(1, 10) + obj2 = DCRange(r) + assert obj2.to_dict() == {'range': [1, 10]} + + @pytest.mark.parametrize('obj', TEST_JSON) + def test_json_convert(self, obj, tmp_path): + """Tests conversion to/from JSON.""" + # write to JSON text file + json_path = tmp_path / 'test.json' + with open(json_path, 'w') as f: + obj.to_json(f) + with open(json_path) as f: + obj1 = type(obj).from_json(f) + assert 
obj1 == obj + # write to JSON binary file + with open(json_path, 'wb') as f: + obj.to_json(f) + with open(json_path, 'rb') as f: + obj2 = type(obj).from_json(f) + assert obj2 == obj + # convert to JSON string + s = obj.to_json_string() + assert s == json.dumps(obj.to_dict()) + obj3 = type(obj).from_json_string(s) + assert obj3 == obj + + # TODO: parametrize + def test_optional(self): + obj: Any = DCOptionalInt(1, 2) + d = obj.to_dict() + assert d == {'x': 1, 'y': 2} + assert DCOptionalInt.from_dict(d) == obj + obj = DCOptionalInt(1, None) + d = obj.to_dict() + assert d == {'x': 1, 'y': None} + assert DCOptionalInt.from_dict(d) == obj + obj = DCOptionalInt(None, 1) # type: ignore[arg-type] + # validation does not occur when converting to dict, only the reverse + d = obj.to_dict() + assert d == {'x': None, 'y': 1} + with pytest.raises(ValueError, match="could not convert None to type 'int'"): + _ = DCOptionalInt.from_dict(d) + obj = DCOptionalStr('a', 'b') + d = obj.to_dict() + assert d == {'x': 'a', 'y': 'b'} + assert DCOptionalStr.from_dict(d) == obj + obj = DCOptionalStr('a', None) + d = obj.to_dict() + assert d == {'x': 'a', 'y': None} + assert DCOptionalStr.from_dict(d) == obj + obj = DCOptionalStr(None, 'b') # type: ignore[arg-type] + d = obj.to_dict() + assert d == {'x': None, 'y': 'b'} + with pytest.raises(ValueError, match="could not convert None to type 'str'"): + _ = DCOptionalStr.from_dict(d) + + # TODO: parametrize + def test_literal(self): + obj = DCLiteral(1) + assert obj.to_dict() == {'lit': 1} + obj = DCLiteral('b') # type: ignore[arg-type] + d = obj.to_dict() + assert d == {'lit': 'b'} + with pytest.raises(ValueError, match=re.escape("could not convert 'b' to type \"typing.Literal['a', 1]\"")): + _ = DCLiteral.from_dict(d) + + # TODO: parametrize + def test_datetime(self): + obj = DCDatetime(NOW) + d = obj.to_dict() + s = NOW.isoformat() + assert d == {'dt': s} + # compare with dataclasses.asdict + assert asdict(obj) == {'dt': NOW} + assert 
DCDatetime.from_dict(d) == obj + # some prefixes of full isoformat are valid + assert DCDatetime.from_dict({'dt': NOW.strftime('%Y-%m-%dT%H:%M:%S')}).dt.isoformat() == s[:19] + assert DCDatetime.from_dict({'dt': NOW.strftime('%Y-%m-%d')}).dt.isoformat()[:10] == s[:10] + # other datetime formats are invalid + with pytest.raises(ValueError, match='Invalid isoformat string'): + DCDatetime.from_dict({'dt': NOW.strftime('%m/%d/%Y %H:%M:%S')}) + with pytest.raises(ValueError, match='Invalid isoformat string'): + DCDatetime.from_dict({'dt': NOW.strftime('%d/%m/%Y')}) + + def test_enum(self): + obj1 = DCEnum(MyEnum.a) + assert obj1.to_dict() == {'enum': 1} + assert asdict(obj1) == {'enum': MyEnum.a} + obj2 = DCStrEnum(MyStrEnum.a) + assert obj2.to_dict() == {'enum': 'a'} + assert asdict(obj2) == {'enum': MyStrEnum.a} + obj3 = DCColors(list(Color)) + assert obj3.to_dict() == {'colors': [1, 2, 4]} + assert asdict(obj3) == {'colors': list(Color)} + + def test_annotated(self): + obj = DCAnnotated(3, 4.7) + assert obj.to_dict() == {'x': 3, 'y': 4.7} + + def test_typed_dict(self): + td: MyTypedDict = {'x': 3, 'y': 'a'} + obj = DCTypedDict(td) + assert obj.to_dict() == {'d': td} + # invalid TypedDicts + for d in [{'x': 3}, {'x': 3, 'y': 'a', 'z': 1}, {'x': 3, 'y': 4}, {'x': 3, 'y': None}]: + with pytest.raises(ValueError, match="could not convert .* to type .*"): + _ = DCTypedDict.from_dict({'d': d}) + + def test_namedtuple(self): + nt1 = MyUntypedNamedTuple(3, 'a') + obj1 = DCUntypedNamedTuple(nt1) + assert obj1.to_dict() == {'t': {'x': 3, 'y': 'a'}} + nt2 = MyTypedNamedTuple(3, 'a') + obj2 = DCTypedNamedTuple(nt2) + assert obj2.to_dict() == {'t': {'x': 3, 'y': 'a'}} + # invalid NamedTuple field (validation occurs on from_dict) + nt2 = MyTypedNamedTuple(3, 4) # type: ignore[arg-type] + obj2 = DCTypedNamedTuple(nt2) + d = {'t': {'x': 3, 'y': 4}} + assert obj2.to_dict() == d + with pytest.raises(ValueError, match="could not convert 4 to type 'str'"): + _ = 
DCTypedNamedTuple.from_dict(d) + + def test_subclass_json_dataclass(self): + def _remove_type(d): + return {key: val for (key, val) in d.items() if (key != 'type')} + obj = DC1Sub(3, 4.7, 'abc') + obj1 = DC1Sub.from_dict(obj.to_dict()) + assert obj1 == obj + assert isinstance(obj1, DC1Sub) + d = obj.to_dict() + assert d['type'] == 'tests.test_serializable.DC1Sub' + obj2 = DC1.from_dict(d) + # fully qualified type is resolved to the subclass + assert obj2 == obj + assert isinstance(obj2, DC1Sub) + obj3 = DC1.from_dict(_remove_type(d)) + assert isinstance(obj3, DC1) + assert not isinstance(obj3, DC1Sub) + d3 = obj3.to_dict() + # objects have the same dict other than the type + assert _remove_type(d3) == _remove_type(d) + assert d3['type'] == 'tests.test_serializable.DC1' + # test behavior of inheriting from JSONDataclass + @dataclass + class MyDC(JSONDataclass): + pass + assert MyDC().to_dict() == {} + with pytest.raises(TypeError, match='you must set qualified_type=True'): + @dataclass + class MyDC1(MyDC): + pass + @dataclass + class MyDC2(MyDC, qualified_type=True): + pass + # TODO: forbid local types? 
+ assert MyDC2().to_dict() == {'type': 'tests.test_serializable.TestJSON.test_subclass_json_dataclass..MyDC2'} + @dataclass + class MyBaseDC(JSONBaseDataclass): + pass + @dataclass + class MyDC3(MyBaseDC): + pass + assert MyDC3().to_dict() == {'type': 'tests.test_serializable.TestJSON.test_subclass_json_dataclass..MyDC3'} + with pytest.raises(TypeError, match='you must set qualified_type=True'): + @dataclass + class MyDC4(MyBaseDC, qualified_type=False): + pass + with pytest.raises(TypeError, match='you must set qualified_type=True'): + @dataclass + class MyDC5(MyDC, JSONBaseDataclass): + pass + @dataclass + class MyDC6(JSONBaseDataclass, MyDC): + pass + @dataclass + class MyDC7(MyDC, JSONBaseDataclass, qualified_type=True): + pass + + def test_subclass_json_base_dataclass(self): + """Tests JSONBaseDataclass.""" + obj = DC2Sub(3, 4.7, 'abc') + d = obj.to_dict() + assert d['type'] == 'tests.test_serializable.DC2Sub' + obj1 = DC2Sub.from_dict(d) + assert obj1 == obj + obj2 = DC2.from_dict(d) + assert isinstance(obj2, DC2Sub) + assert obj2 == obj + + def test_invalid_json_obj(self): + """Attempts to convert an object to JSON that is not JSONSerializable.""" + obj = MyObject() + njs = DCNonJSONSerializable(3, obj) + d = {'x': 3, 'obj': obj} + assert njs.to_dict() == d + # conversion from dict works OK + assert DCNonJSONSerializable.from_dict(d) == njs + with pytest.raises(TypeError, match='Object of type MyObject is not JSON serializable'): + _ = njs.to_json_string() + + def test_suppress_field(self): + """Tests behavior of setting the 'suppress' option on a field.""" + obj = DCSuppress() + d = {'x': 1, 'z': 3} + assert obj.to_dict() == d + assert obj.to_dict(full=True) == d + assert DCSuppress.from_dict(d) == obj + obj = DCSuppress(y=100) + assert obj.to_dict() == d + assert obj.to_dict(full=True) == d + assert DCSuppress.from_dict(d).y == 2 + + def test_suppress_required_field(self): + """Tests that a required field with suppress=True cannot create a valid dict.""" + 
@dataclass + class DCSuppressRequired(JSONDataclass): + x: int = field(metadata={'suppress': True}) + with pytest.raises(TypeError, match='missing 1 required positional argument'): + _ = DCSuppressRequired() + obj = DCSuppressRequired(1) + assert obj.to_dict() == {} + with pytest.raises(ValueError, match="'x' field is required"): + _ = DCSuppressRequired.from_dict({}) + _ = DCSuppressRequired.from_dict({'x': 1}) + + def test_suppress_defaults(self): + """Tests behavior of the suppress_defaults option, both at the class level and the field level.""" + @dataclass + class MyDC(JSONDataclass): + x: int = 1 + assert MyDC.__settings__.suppress_defaults is True + obj = MyDC() + assert obj.to_dict() == {} + assert obj.to_dict(full=True) == {'x': 1} + obj = MyDC(2) + assert obj.to_dict() == {'x': 2} + assert obj.to_dict(full=True) == {'x': 2} + @dataclass + class MyDC(JSONDataclass, suppress_defaults=False): + x: int = 1 + obj = MyDC() + assert obj.to_dict() == {'x': 1} + assert obj.to_dict(full=True) == {'x': 1} + @dataclass + class MyDC(JSONDataclass): + x: int = field(default=1, metadata={'suppress_default': False}) + obj = MyDC() + assert obj.to_dict() == {'x': 1} + assert obj.to_dict(full=True) == {'x': 1} + @dataclass + class MyDC(JSONDataclass, suppress_defaults=False): + x: int = field(default=1, metadata={'suppress_default': True}) + obj = MyDC() + assert obj.to_dict() == {} + assert obj.to_dict(full=True) == {'x': 1} + + def test_class_var(self): + """Tests the behavior of ClassVars.""" + @dataclass + class MyDC1(JSONDataclass): + x: ClassVar[int] + obj = MyDC1() + assert obj.to_dict() == {} + assert obj.to_dict(full=True) == {} + assert MyDC1.from_dict({}) == obj + with pytest.raises(AttributeError, match='object has no attribute'): + _ = obj.x + @dataclass + class MyDC2(JSONDataclass): + x: ClassVar[int] = field(metadata={'suppress': False}) + obj = MyDC2() + with pytest.raises(AttributeError, match='object has no attribute'): + _ = obj.to_dict() + assert 
MyDC2.from_dict({}) == obj + @dataclass + class MyDC3(JSONDataclass): + x: ClassVar[int] = 1 + obj = MyDC3() + assert obj.to_dict() == {} + assert obj.to_dict(full=True) == {} + obj0 = MyDC3.from_dict({}) + assert obj0 == obj + assert obj0.x == 1 + # ClassVar gets ignored when loading from dict + obj1 = MyDC3.from_dict({'x': 1}) + assert obj1 == obj + assert obj1.x == 1 + obj2 = MyDC3.from_dict({'x': 2}) + assert obj2 == obj + assert obj2.x == 1 + MyDC3.x = 2 + obj = MyDC3() + assert obj.to_dict() == {} + # ClassVar field has to override with suppress=False to include it + assert obj.to_dict(full=True) == {} + @dataclass + class MyDC4(JSONDataclass): + x: ClassVar[int] = field(default=1, metadata={'suppress': False}) + obj = MyDC4() + assert obj.to_dict() == {} # equals default, so suppress it + assert obj.to_dict(full=True) == {'x': 1} + obj0 = MyDC4.from_dict({}) + assert obj0 == obj + obj2 = MyDC4.from_dict({'x': 2}) + assert obj2 == obj + assert obj2.x == 1 + MyDC4.x = 2 + obj = MyDC4() + assert obj.to_dict() == {'x': 2} # no longer equals default + assert obj.to_dict(full=True) == {'x': 2} + + def test_from_dict_kwargs(self): + """Tests behavior of from_json_string with respect to partitioning kwargs into from_dict and json.loads.""" + @dataclass + class MyDC(JSONDataclass): + x: int = 1 + s = '{"x": 1}' + assert MyDC.from_json_string(s) == MyDC() + assert MyDC.from_json_string(s, strict=True) == MyDC() + with pytest.raises(ValueError, match="'y' is not a valid field for MyDC"): + _ = MyDC.from_json_string('{"x": 1, "y": 2}', strict=True) + parse_int = lambda val: int(val) + 1 + assert MyDC.from_json_string(s, parse_int=parse_int) == MyDC(2) + assert MyDC.from_json_string(s, strict=True, parse_int=parse_int) == MyDC(2) + with pytest.raises(TypeError, match="unexpected keyword argument 'fake_kwarg'"): + _ = MyDC.from_json_string(s, fake_kwarg=True) + + @pytest.mark.parametrize(['obj', 'd', 'obj2'], [ + (DCEmpty(), {}, None), + (DCFloat(1), {'x': 1}, None), + 
(DCFloat(math.inf), {'x': math.inf}, None),
+        (DCFloat(math.nan), {'x': math.nan}, None),
+        (DCColors([Color.RED, Color.BLUE]), {'colors': [1, 4]}, None),
+        (DCStrEnum(MyStrEnum.a), {'enum': 'a'}, None),
+        (DCNonJSONSerializable(1, MyObject()), {'x': 1, 'obj': MyObject()}, None),
+        (DCList([]), {'vals': []}, None),
+        (DCList([DCAny(None)]), {'vals': [{'val': None}]}, None),
+        (DCList([DCAny(1)]), {'vals': [{'val': 1}]}, None),
+        (DCList([DCAny([])]), {'vals': [{'val': []}]}, None),
+        (DCList([DCAny({})]), {'vals': [{'val': {}}]}, None),
+        (DCList([DCAny(DCAny(1))]), {'vals': [{'val': {'val': 1}}]}, DCList([DCAny({'val': 1})])),
+    ])
+    def test_dict_round_trips(self, obj, d, obj2):
+        """Tests round-trip fidelity to/from dict."""
+        assert obj.to_dict() == d
+        obj3 = type(obj).from_dict(d)
+        if obj2 is None:  # round trip should reproduce the original object
+            obj2 = obj
+        assert obj3 == obj2
+        assert obj3.to_dict() == d
+
+
+class TestTOML(TestDict):
+    """Unit tests for TOMLDataclass."""
+
+    base_cls = TOMLDataclass
+
+    @pytest.mark.parametrize('obj', TEST_JSON)
+    def test_dict_convert(self, obj):
+        """Tests conversion to/from dict."""
+        self._test_dict_convert(obj)
+
+    @staticmethod
+    def _json_to_toml_dataclass(obj):
+        return coerce_to_dataclass(_convert_json_dataclass(type(obj), TOMLDataclass), obj)
+
+    def _test_toml_string(self, obj, s):
+        cls = _convert_json_dataclass(type(obj), TOMLDataclass)
+        obj = coerce_to_dataclass(cls, obj)
+        assert obj.to_toml_string() == s
+        assert cls.from_toml_string(s) == obj
+
+    @pytest.mark.parametrize(['obj', 's'], [
+        (DCFloat(1), 'x = 1\n'),
+        (DCFloat(1.0), 'x = 1.0\n'),
+        (DCFloat(1e-12), 'x = 1e-12\n'),
+        (DCFloat(math.inf), 'x = inf\n'),
+        (DCFloat(math.nan), 'x = nan\n'),
+    ])
+    def test_float(self, obj, s):
+        """Tests floating-point support."""
+        self._test_toml_string(obj, s)
+
+    @pytest.mark.parametrize(['obj', 's'], [
+        (DCOptional(1), 'x = 1\n'),
+        (DCOptional(None), ''),
+    ])
+    def test_optional(self, obj, s):
+        """Tests behavior of Optional types 
with TOML conversion.""" + self._test_toml_string(obj, s) + + @pytest.mark.parametrize(['obj', 's'], [ + (DCDatetime(datetime.strptime('2024-01-01', '%Y-%m-%d')), 'dt = 2024-01-01T00:00:00\n'), + ]) + def test_datetime(self, obj, s): + """Tests support for the datetime type.""" + self._test_toml_string(obj, s) + + @pytest.mark.parametrize('obj', TEST_JSON) + def test_toml_convert(self, obj, tmp_path): + """Tests conversion to/from TOML.""" + obj = TestTOML._json_to_toml_dataclass(obj) + assert isinstance(obj, TOMLDataclass) + # write to TOML text file + toml_path = tmp_path / 'test.toml' + with open(toml_path, 'w') as f: + obj.to_toml(f) + with open(toml_path) as f: + obj1 = type(obj).from_toml(f) + assert obj1 == obj + # write to TOML binary file + with open(toml_path, 'wb') as f: + obj.to_toml(f) + with open(toml_path, 'rb') as f: + obj2 = type(obj).from_toml(f) + assert obj2 == obj + # convert to TOML string + s = obj.to_toml_string() + obj3 = type(obj).from_toml_string(s) + assert obj3 == obj