Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Enums #29 #44

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
48 changes: 42 additions & 6 deletions src/dataclass_binder/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from dataclasses import MISSING, Field, asdict, dataclass, fields, is_dataclass, replace
from datetime import date, datetime, time, timedelta
from enum import Enum
from functools import reduce
from importlib import import_module
from inspect import cleandoc, get_annotations, getmodule, getsource, isabstract
Expand All @@ -28,10 +29,21 @@
from typing import TYPE_CHECKING, Any, BinaryIO, ClassVar, Generic, TypeVar, Union, cast, get_args, get_origin, overload
from weakref import WeakKeyDictionary

if sys.version_info < (3, 11):
import tomli as tomllib # pragma: no cover
else:
import tomllib # pragma: no cover
if sys.version_info < (3, 11): # pragma: no cover
import tomli as tomllib

if TYPE_CHECKING:

class ReprEnum(Enum):
...

else:
from enum import IntEnum, IntFlag

ReprEnum = IntEnum | IntFlag
else: # pragma: no cover
import tomllib # noqa: I001
from enum import ReprEnum


def _collect_type(field_type: type, context: str) -> type | Binder[Any]:
Expand All @@ -49,7 +61,7 @@ def _collect_type(field_type: type, context: str) -> type | Binder[Any]:
return object
elif not isinstance(field_type, type):
raise TypeError(f"Annotation for field '{context}' is not a type")
elif issubclass(field_type, str | int | float | date | time | timedelta | ModuleType | Path):
elif issubclass(field_type, str | int | float | date | time | timedelta | ModuleType | Path | Enum):
return field_type
elif field_type is type:
# https://github.com/python/mypy/issues/13026
Expand Down Expand Up @@ -209,7 +221,6 @@ def _check_field(field: Field, field_type: type, context: str) -> None:

@dataclass(slots=True)
class _ClassInfo(Generic[T]):

_cache: ClassVar[MutableMapping[type[Any], _ClassInfo[Any]]] = WeakKeyDictionary()

dataclass: type[T]
Expand Down Expand Up @@ -314,6 +325,24 @@ def _bind_to_single_type(self, value: object, field_type: type, context: str) ->
if not isinstance(value, str):
raise TypeError(f"Expected TOML string for path '{context}', got '{type(value).__name__}'")
return field_type(value)
elif issubclass(field_type, ReprEnum):
if issubclass(field_type, int) and not isinstance(value, int):
raise TypeError(f"Value for '{context}': '{value}' is not of type int")
if issubclass(field_type, str) and not isinstance(value, str):
raise TypeError(f"Value for '{context}': '{value}' is not of type str")
return field_type(value)
atomicptr marked this conversation as resolved.
Show resolved Hide resolved
elif issubclass(field_type, Enum):
if not isinstance(value, str):
raise TypeError(
f"Value for '{context}': '{value}' is not a valid key for enum '{field_type}', "
f"must be of type str"
)
for enum_value in field_type:
if enum_value.name.lower() == value.lower():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small optimization: you could store value.lower() in a local variable, to avoid converting the same string multiple times.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since there is no other case where this is used in this function I'm not sure this would help 🤔

return enum_value
raise TypeError(
f"Value for '{context}': '{value}' is not a valid key for enum '{field_type}', could not be found"
)
elif isinstance(value, field_type) and (
type(value) is not bool or field_type is bool or field_type is object
):
Expand Down Expand Up @@ -668,6 +697,13 @@ def format_toml_pair(key: str, value: object) -> str:
def _to_toml_pair(value: object) -> tuple[str | None, Any]:
"""Return a TOML-compatible suffix and value pair with the data from the given rich value object."""
match value:
# enums have to be checked before basic types because for instance
# IntEnum is also of type int
case Enum():
if isinstance(value, ReprEnum):
return None, value.value
else:
return None, value.name.lower()
case str() | int() | float() | date() | time() | Path(): # note: 'bool' is a subclass of 'int'
return None, value
case timedelta():
Expand Down
33 changes: 33 additions & 0 deletions tests/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Sequence
from dataclasses import dataclass, field
from datetime import date, datetime, time, timedelta
from enum import Enum, IntEnum, auto
from io import BytesIO
from pathlib import Path
from types import ModuleType, NoneType, UnionType
Expand Down Expand Up @@ -859,3 +860,35 @@ def test_format_template_no_module(sourceless_class: type[Any]) -> None:
value = 0
""".strip()
)


class Verbosity(Enum):
QUIET = auto()
NORMAL = auto()
DETAILED = auto()


class IntVerbosity(IntEnum):
QUIET = 0
NORMAL = 1
DETAILED = 2


def test_format_with_enums() -> None:
@dataclass
class Log:
message: str
verbosity: Verbosity
verbosity_level: IntVerbosity

log = Log("Hello, World", Verbosity.DETAILED, IntVerbosity.DETAILED)

template = "\n".join(Binder(log).format_toml())

assert template == (
"""
message = 'Hello, World'
verbosity = 'detailed'
verbosity-level = 2
atomicptr marked this conversation as resolved.
Show resolved Hide resolved
""".strip()
)
142 changes: 142 additions & 0 deletions tests/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from contextlib import contextmanager
from dataclasses import FrozenInstanceError, dataclass, field
from datetime import date, datetime, time, timedelta
from enum import Enum, IntEnum
from io import BytesIO
from pathlib import Path
from types import ModuleType
Expand Down Expand Up @@ -1084,3 +1085,144 @@ def test_bind_merge() -> None:
assert merged_config.flag is True
assert merged_config.nested1.value == "sun"
assert merged_config.nested2.value == "cheese"


class Color(Enum):
RED = "#FF0000"
GREEN = "#00FF00"
BLUE = "#0000FF"


class Number(IntEnum):
ONE = 1
TWO = 2
THREE = 3


class Weekday(Enum):
MONDAY = 0
TUESDAY = 1
WEDNESDAY = 2
THURSDAY = 3
FRIDAY = 4
SATURDAY = 5
SUNDAY = 6


@dataclass
class EnumEntry:
name: str
color: Color
number: Number


def test_enums() -> None:
@dataclass
class Config:
best_colors: list[Color]
best_numbers: list[Number]
entries: list[EnumEntry]

with stream_text(
"""
best-colors = ["red", "green", "blue"]
best-numbers = [1, 2, 3]

[[entries]]
name = "Entry 1"
color = "blue"
number = 2

[[entries]]
name = "Entry 2"
color = "red"
number = 1
"""
) as stream:
config = Binder(Config).parse_toml(stream)

assert len(config.best_colors) == 3
assert len(config.best_numbers) == 3
assert config.best_colors.index(Color.RED) == 0
assert config.best_colors.index(Color.GREEN) == 1
assert config.best_colors.index(Color.BLUE) == 2
assert all(num in config.best_numbers for num in Number)
assert len(config.entries) == 2
assert config.entries[0].color is Color.BLUE
assert config.entries[0].number is Number.TWO
assert config.entries[1].color is Color.RED
assert config.entries[1].number is Number.ONE


def test_enum_with_invalid_value() -> None:
@dataclass
class UserFavorites:
favorite_number: Number
favorite_color: Color

with stream_text(
"""
favorite-number = "one"
favorite-color = "red"
"""
) as stream, pytest.raises(TypeError):
Binder(UserFavorites).parse_toml(stream)


def test_enum_keys_being_case_insensitive() -> None:
@dataclass
class Theme:
primary: Color
secondary: Color
accent: Color

with stream_text(
"""
primary = "RED"
secondary = "green"
accent = "blUE"
"""
) as stream:
theme = Binder(Theme).parse_toml(stream)

assert theme.primary is Color.RED
assert theme.secondary is Color.GREEN
assert theme.accent is Color.BLUE


def test_key_based_enum_while_using_value_ident() -> None:
@dataclass
class UserColorPreference:
primary: Color
secondary: Color

with stream_text(
"""
primary = "#FF0000"
seconadry = "blue"
"""
) as stream, pytest.raises(TypeError):
Binder(UserColorPreference).parse_toml(stream)


def test_enum_parsing_with_invalid_key_type() -> None:
@dataclass
class UserPrefs:
name: str
start_of_the_week: Weekday

with stream_text(
"""
name = "Peter Testuser"
start-of-the-week = "sunday"
"""
) as stream:
Binder(UserPrefs).parse_toml(stream)

with stream_text(
"""
name = "Peter Testuser"
start-of-the-week = 1
"""
) as stream, pytest.raises(TypeError):
Binder(UserPrefs).parse_toml(stream)