Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Enums #29 #44

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
17 changes: 15 additions & 2 deletions src/dataclass_binder/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from dataclasses import MISSING, Field, asdict, dataclass, fields, is_dataclass, replace
from datetime import date, datetime, time, timedelta
from enum import Enum, ReprEnum
from functools import reduce
from importlib import import_module
from inspect import cleandoc, get_annotations, getmodule, getsource, isabstract
Expand Down Expand Up @@ -49,7 +50,7 @@ def _collect_type(field_type: type, context: str) -> type | Binder[Any]:
return object
elif not isinstance(field_type, type):
raise TypeError(f"Annotation for field '{context}' is not a type")
elif issubclass(field_type, str | int | float | date | time | timedelta | ModuleType | Path):
elif issubclass(field_type, str | int | float | date | time | timedelta | ModuleType | Path | Enum | ReprEnum):
atomicptr marked this conversation as resolved.
Show resolved Hide resolved
return field_type
elif field_type is type:
# https://github.com/python/mypy/issues/13026
Expand Down Expand Up @@ -209,7 +210,6 @@ def _check_field(field: Field, field_type: type, context: str) -> None:

@dataclass(slots=True)
class _ClassInfo(Generic[T]):

_cache: ClassVar[MutableMapping[type[Any], _ClassInfo[Any]]] = WeakKeyDictionary()

dataclass: type[T]
Expand Down Expand Up @@ -314,6 +314,15 @@ def _bind_to_single_type(self, value: object, field_type: type, context: str) ->
if not isinstance(value, str):
raise TypeError(f"Expected TOML string for path '{context}', got '{type(value).__name__}'")
return field_type(value)
elif issubclass(field_type, ReprEnum):
return field_type(value)
atomicptr marked this conversation as resolved.
Show resolved Hide resolved
elif issubclass(field_type, Enum):
if not isinstance(value, str):
raise TypeError(f"'{value}' is not a valid key for enum '{field_type}', must be of type str")
atomicptr marked this conversation as resolved.
Show resolved Hide resolved
atomicptr marked this conversation as resolved.
Show resolved Hide resolved
for enum_value in field_type:
if enum_value.name.lower() == value.lower():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small optimization: you could store value.lower() in a local variable, to avoid converting the same string multiple times.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since there is no other case where this is used in this function I'm not sure this would help 🤔

return enum_value
raise TypeError(f"'{value}' is not a valid key for enum '{field_type}', could not be found")
atomicptr marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(value, field_type) and (
type(value) is not bool or field_type is bool or field_type is object
):
Expand Down Expand Up @@ -697,6 +706,10 @@ def _to_toml_pair(value: object) -> tuple[str | None, Any]:
return "-weeks", days // 7
else:
return "-days", days
case Enum():
return None, value.name.lower()
case ReprEnum():
return None, value.value
atomicptr marked this conversation as resolved.
Show resolved Hide resolved
case ModuleType():
return None, value.__name__
case Mapping():
Expand Down
33 changes: 33 additions & 0 deletions tests/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Sequence
from dataclasses import dataclass, field
from datetime import date, datetime, time, timedelta
from enum import Enum, IntEnum, auto
from io import BytesIO
from pathlib import Path
from types import ModuleType, NoneType, UnionType
Expand Down Expand Up @@ -859,3 +860,35 @@ def test_format_template_no_module(sourceless_class: type[Any]) -> None:
value = 0
""".strip()
)


class Verbosity(Enum):
QUIET = auto()
NORMAL = auto()
DETAILED = auto()


class IntVerbosity(IntEnum):
QUIET = 0
NORMAL = 1
DETAILED = 2


def test_format_with_enums() -> None:
@dataclass
class Log:
message: str
verbosity: Verbosity
verbosity_level: IntVerbosity

log = Log("Hello, World", Verbosity.DETAILED, IntVerbosity.DETAILED)

template = "\n".join(Binder(log).format_toml())

assert template == (
"""
message = 'Hello, World'
verbosity = 'detailed'
verbosity-level = 2
atomicptr marked this conversation as resolved.
Show resolved Hide resolved
""".strip()
)
109 changes: 109 additions & 0 deletions tests/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from contextlib import contextmanager
from dataclasses import FrozenInstanceError, dataclass, field
from datetime import date, datetime, time, timedelta
from enum import Enum, IntEnum
from io import BytesIO
from pathlib import Path
from types import ModuleType
Expand Down Expand Up @@ -1084,3 +1085,111 @@ def test_bind_merge() -> None:
assert merged_config.flag is True
assert merged_config.nested1.value == "sun"
assert merged_config.nested2.value == "cheese"


class Color(Enum):
RED = "#FF0000"
GREEN = "#00FF00"
BLUE = "#0000FF"


class Number(IntEnum):
ONE = 1
TWO = 2
THREE = 3


@dataclass
class EnumEntry:
name: str
color: Color
number: Number


def test_enums() -> None:
@dataclass
class Config:
best_colors: list[Color]
best_numbers: list[Number]
entries: list[EnumEntry]

with stream_text(
"""
best-colors = ["red", "green", "blue"]
best-numbers = [1, 2, 3]

[[entries]]
name = "Entry 1"
color = "blue"
number = 2

[[entries]]
name = "Entry 2"
color = "red"
number = 1
"""
) as stream:
config = Binder(Config).parse_toml(stream)

assert len(config.best_colors) == 3
assert len(config.best_numbers) == 3
assert config.best_colors.index(Color.RED) == 0
assert config.best_colors.index(Color.GREEN) == 1
assert config.best_colors.index(Color.BLUE) == 2
assert all(num in config.best_numbers for num in Number)
assert len(config.entries) == 2
assert config.entries[0].color is Color.BLUE
assert config.entries[0].number is Number.TWO
assert config.entries[1].color is Color.RED
assert config.entries[1].number is Number.ONE


def test_enum_with_invalid_value() -> None:
@dataclass
class UserFavorites:
favorite_number: Number
favorite_color: Color

with stream_text(
"""
favorite-number = "one"
favorite-color = "red"
"""
) as stream, pytest.raises(ValueError): # noqa: PT011
atomicptr marked this conversation as resolved.
Show resolved Hide resolved
Binder(UserFavorites).parse_toml(stream)


def test_enum_keys_being_case_insensitive() -> None:
@dataclass
class Theme:
primary: Color
secondary: Color
accent: Color

with stream_text(
"""
primary = "RED"
secondary = "green"
accent = "blUE"
"""
) as stream:
theme = Binder(Theme).parse_toml(stream)

assert theme.primary is Color.RED
assert theme.secondary is Color.GREEN
assert theme.accent is Color.BLUE


def test_key_based_enum_while_using_value_ident() -> None:
@dataclass
class UserColorPreference:
primary: Color
secondary: Color

with stream_text(
"""
primary = "#FF0000"
seconadry = "blue"
"""
) as stream, pytest.raises(TypeError):
Binder(UserColorPreference).parse_toml(stream)