diff --git a/tests/test_model.py b/tests/test_model.py index 604cf9d0..ca14b474 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,8 +1,11 @@ """Tests for the ZHA model module.""" +from enum import Enum + from zigpy.types import NWK from zigpy.types.named import EUI64 +from zha.model import BaseModel from zha.zigbee.device import DeviceInfo, ZHAEvent @@ -83,3 +86,24 @@ def test_ser_deser_zha_event(): '"quirk_class":"test","quirk_id":"test","manufacturer_code":0,"power_source":"test",' '"lqi":1,"rssi":2,"last_seen":"","available":true,"device_type":"test","signature":{"foo":"bar"}}' ) + + +def test_raw_enum_handling(): + """Test the conversions for enums that are not StrEnums.""" + + class SomeEnum(Enum): + """Some enum for testing.""" + + FOO = 1 + BAR = 2 + + class JunkModel(BaseModel): + """Model for testing raw enum handling.""" + + foo: SomeEnum + + model = JunkModel(foo=SomeEnum.FOO.name) + + assert model.foo == SomeEnum.FOO + + assert model.model_dump()["foo"] == SomeEnum.FOO.name diff --git a/zha/model.py b/zha/model.py index eb366603..a143df52 100644 --- a/zha/model.py +++ b/zha/model.py @@ -1,11 +1,14 @@ """Shared models for ZHA.""" +from enum import Enum, StrEnum import logging -from typing import Literal, Optional, Union +from typing import Any, Literal, Optional, Union from pydantic import ( BaseModel as PydanticBaseModel, ConfigDict, + SerializerFunctionWrapHandler, + ValidationInfo, field_serializer, field_validator, ) @@ -29,6 +32,37 @@ def convert_ieee(cls, ieee: Optional[Union[str, EUI64]]) -> Optional[EUI64]: return EUI64.convert(ieee) return ieee + @field_validator("*", mode="before", check_fields=False) + @classmethod + def convert_enum( + cls: type, value: int | Enum | Any, validation_info: ValidationInfo + ) -> Enum | Any: + """Convert enum to Pydantic model.""" + enum__type_info = cls.model_fields[validation_info.field_name].annotation + if hasattr(enum__type_info, "__args__"): + options = getattr(enum__type_info, "__args__") + else: + options = [enum__type_info] + for enum_type in options: + if ( + isinstance(enum_type, type) + and issubclass(enum_type, Enum) + and not issubclass(enum_type, StrEnum) + ): + # If the value is already an instance of the Enum, return it directly + if isinstance(value, enum_type): + return value + + # If the value is a valid name, return the corresponding enum member + if isinstance(value, str) and value in enum_type.__members__: + return enum_type[value] + + # If the value is a valid enum value, return the corresponding enum member + if isinstance(value, int) and value in enum_type._value2member_map_: + return enum_type(value) + + return value + @field_validator("nwk", mode="before", check_fields=False) @classmethod def convert_nwk(cls, nwk: Optional[Union[int, NWK]]) -> Optional[NWK]: @@ -37,10 +71,16 @@ def convert_nwk(cls, nwk: Optional[Union[int, NWK]]) -> Optional[NWK]: return NWK(nwk) return nwk - @field_serializer("ieee", "device_ieee", check_fields=False) - def serialize_ieee(self, ieee: EUI64): + @field_serializer("*", mode="wrap", check_fields=False) + def serialize_enum_or_ieee( + self, value: EUI64 | Enum | Any, nxt: SerializerFunctionWrapHandler + ) -> str | Any: """Customize how ieee is serialized.""" - return str(ieee) + if isinstance(value, Enum) and not isinstance(value, StrEnum): + return value.name + if isinstance(value, EUI64): + return str(value) + return nxt(value) class BaseEvent(BaseModel):