Skip to content

Commit

Permalink
enum conversions (this may be too hacky...)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmulcahey committed Oct 17, 2024
1 parent 94346b6 commit 91e6cbb
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 4 deletions.
24 changes: 24 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Tests for the ZHA model module."""

from enum import Enum

from zigpy.types import NWK
from zigpy.types.named import EUI64

from zha.model import BaseModel
from zha.zigbee.device import DeviceInfo, ZHAEvent


Expand Down Expand Up @@ -83,3 +86,24 @@ def test_ser_deser_zha_event():
'"quirk_class":"test","quirk_id":"test","manufacturer_code":0,"power_source":"test",'
'"lqi":1,"rssi":2,"last_seen":"","available":true,"device_type":"test","signature":{"foo":"bar"}}'
)


def test_raw_enum_handling():
"""Test the conversions for enums that are not StrEnums."""

class SomeEnum(Enum):
"""Some enum for testing."""

FOO = 1
BAR = 2

class JunkModel(BaseModel):
"""Model for testing raw enum handling."""

foo: SomeEnum

model = JunkModel(foo=SomeEnum.FOO.name)

assert model.foo == SomeEnum.FOO

assert model.model_dump()["foo"] == SomeEnum.FOO.name
48 changes: 44 additions & 4 deletions zha/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""Shared models for ZHA."""

from enum import Enum, StrEnum
import logging
from typing import Literal, Optional, Union
from typing import Any, Literal, Optional, Union

from pydantic import (
BaseModel as PydanticBaseModel,
ConfigDict,
SerializerFunctionWrapHandler,
ValidationInfo,
field_serializer,
field_validator,
)
Expand All @@ -29,6 +32,37 @@ def convert_ieee(cls, ieee: Optional[Union[str, EUI64]]) -> Optional[EUI64]:
return EUI64.convert(ieee)
return ieee

@field_validator("*", mode="before", check_fields=False)
@classmethod
def convert_enum(
cls: type, value: int | Enum | Any, validation_info: ValidationInfo
) -> Enum | Any:
"""Convert enum to Pydantic model."""
enum__type_info = cls.model_fields[validation_info.field_name].annotation
if hasattr(enum__type_info, "__args__"):
options = getattr(enum__type_info, "__args__")
else:
options = [enum__type_info]
for enum_type in options:
if (
isinstance(enum_type, type)
and issubclass(enum_type, Enum)
and not issubclass(enum_type, StrEnum)
):
# If the value is already an instance of the Enum, return it directly
if isinstance(value, enum_type):
return value

# If the value is a valid name, return the corresponding enum member
if isinstance(value, str) and value in enum_type.__members__:
return enum_type[value]

# If the value is a valid enum value, return the corresponding enum member
if isinstance(value, int) and value in enum_type._value2member_map_:
return enum_type(value)

Check warning on line 62 in zha/model.py

View check run for this annotation

Codecov / codecov/patch

zha/model.py#L61-L62

Added lines #L61 - L62 were not covered by tests

return value

@field_validator("nwk", mode="before", check_fields=False)
@classmethod
def convert_nwk(cls, nwk: Optional[Union[int, NWK]]) -> Optional[NWK]:
Expand All @@ -37,10 +71,16 @@ def convert_nwk(cls, nwk: Optional[Union[int, NWK]]) -> Optional[NWK]:
return NWK(nwk)
return nwk

@field_serializer("ieee", "device_ieee", check_fields=False)
def serialize_ieee(self, ieee: EUI64):
@field_serializer("*", mode="wrap", check_fields=False)
def serialize_enum_or_ieee(
self, value: EUI64 | Enum | Any, nxt: SerializerFunctionWrapHandler
) -> str | Any:
"""Customize how ieee is serialized."""
return str(ieee)
if isinstance(value, Enum) and not isinstance(value, StrEnum):
return value.name
if isinstance(value, EUI64):
return str(value)
return nxt(value)


class BaseEvent(BaseModel):
Expand Down

0 comments on commit 91e6cbb

Please sign in to comment.