Skip to content

Commit

Permalink
clean up some more
Browse files Browse the repository at this point in the history
  • Loading branch information
dmulcahey committed Oct 12, 2024
1 parent 871bf80 commit 86164ab
Show file tree
Hide file tree
Showing 27 changed files with 39 additions and 3,166 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ enabled = true
ignore-words-list = "hass"

[tool.mypy]
plugins = "pydantic.mypy"
python_version = "3.12"
check_untyped_defs = true
disallow_incomplete_defs = true
Expand Down
15 changes: 8 additions & 7 deletions zhaws/client/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Client implementation for the zhaws.client."""

from __future__ import annotations

import asyncio
Expand All @@ -12,9 +13,9 @@
from aiohttp.http_websocket import WSMsgType
from async_timeout import timeout

from zha.event import EventBase
from zhaws.client.model.commands import CommandResponse, ErrorResponse
from zhaws.client.model.messages import Message
from zhaws.event import EventBase
from zhaws.server.websocket.api.model import WebSocketCommand

SIZE_PARSE_JSON_EXECUTOR = 8192
Expand Down Expand Up @@ -85,13 +86,13 @@ async def async_send_command(
async with timeout(20):
await self._send_json_message(command.json(exclude_none=True))
return await future
except asyncio.TimeoutError:
_LOGGER.error("Timeout waiting for response")
except TimeoutError:
_LOGGER.exception("Timeout waiting for response")
return CommandResponse.parse_obj(
{"message_id": message_id, "success": False}
)
except Exception as err:
_LOGGER.error("Error sending command: %s", err, exc_info=err)
_LOGGER.exception("Error sending command", exc_info=err)
finally:
self._result_futures.pop(message_id)

Expand All @@ -112,7 +113,7 @@ async def connect(self) -> None:
max_msg_size=0,
)
except client_exceptions.ClientError as err:
_LOGGER.error("Error connecting to server: %s", err)
_LOGGER.exception("Error connecting to server", exc_info=err)
raise err

async def listen_loop(self) -> None:
Expand Down Expand Up @@ -193,7 +194,7 @@ def _handle_incoming_message(self, msg: dict) -> None:
try:
message = Message.parse_obj(msg).__root__
except Exception as err:
_LOGGER.error("Error parsing message: %s", msg, exc_info=err)
_LOGGER.exception("Error parsing message: %s", msg, exc_info=err)

if message.message_type == "result":
future = self._result_futures.get(message.message_id)
Expand Down Expand Up @@ -230,7 +231,7 @@ def _handle_incoming_message(self, msg: dict) -> None:
try:
self.emit(message.event_type, message)
except Exception as err:
_LOGGER.error("Error handling event: %s", err, exc_info=err)
_LOGGER.exception("Error handling event", exc_info=err)

async def _send_json_message(self, message: str) -> None:
"""Send a message.
Expand Down
2 changes: 1 addition & 1 deletion zhaws/client/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from async_timeout import timeout
from zigpy.types.named import EUI64

from zha.event import EventBase
from zhaws.client.client import Client
from zhaws.client.helpers import (
AlarmControlPanelHelper,
Expand Down Expand Up @@ -45,7 +46,6 @@
ZHAEvent,
)
from zhaws.client.proxy import DeviceProxy, GroupProxy
from zhaws.event import EventBase
from zhaws.server.const import ControllerEvents, EventTypes
from zhaws.server.websocket.api.model import WebSocketCommand

Expand Down
3 changes: 2 additions & 1 deletion zhaws/client/model/events.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Event models for zhawss.
Events are unprompted messages from the server -> client and they contain only the data that is necessary to handle the event.
Events are unprompted messages from the server -> client and they contain only the data that is necessary to
handle the event.
"""

from typing import Annotated, Any, Literal, Optional, Union
Expand Down
12 changes: 4 additions & 8 deletions zhaws/client/model/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

from typing import Annotated, Any, Literal, Optional, Union

from pydantic import validator
from pydantic import field_validator
from pydantic.fields import Field
from zigpy.types.named import EUI64

from zhaws.event import EventBase
from zha.event import EventBase
from zhaws.model import BaseModel


Expand Down Expand Up @@ -599,12 +599,8 @@ class Group(BaseModel):
],
]

# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@validator("members", pre=True, always=True, each_item=False, check_fields=False)
def convert_member_ieee(
cls, members: dict[str, dict], values: dict[str, Any], **kwargs: Any
) -> dict[EUI64, Device]:
@field_validator("members", mode="before", check_fields=False)
def convert_member_ieee(cls, members: dict[str, dict]) -> dict[EUI64, GroupMember]:
"""Convert member IEEE to EUI64."""
return {EUI64.convert(k): GroupMember(**v) for k, v in members.items()}

Expand Down
2 changes: 1 addition & 1 deletion zhaws/client/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

from typing import TYPE_CHECKING, Any

from zha.event import EventBase
from zhaws.client.model.events import PlatformEntityStateChangedEvent
from zhaws.client.model.types import (
ButtonEntity,
Device as DeviceModel,
Group as GroupModel,
)
from zhaws.event import EventBase

if TYPE_CHECKING:
from zhaws.client.client import Client
Expand Down
72 changes: 0 additions & 72 deletions zhaws/event.py

This file was deleted.

21 changes: 5 additions & 16 deletions zhaws/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, no_type_check

from pydantic import BaseModel as PydanticBaseModel, ConfigDict, validator
from pydantic import BaseModel as PydanticBaseModel, ConfigDict, field_validator
from zigpy.types.named import EUI64

if TYPE_CHECKING:
Expand All @@ -17,29 +17,18 @@ class BaseModel(PydanticBaseModel):

model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")

# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@validator("ieee", pre=True, always=True, each_item=False, check_fields=False)
def convert_ieee(
cls, ieee: Optional[Union[str, EUI64]], values: dict[str, Any], **kwargs: Any
) -> Optional[EUI64]:
@field_validator("ieee", mode="before", check_fields=False)
def convert_ieee(cls, ieee: Optional[Union[str, EUI64]]) -> Optional[EUI64]:
"""Convert ieee to EUI64."""
if ieee is None:
return None
if isinstance(ieee, str):
return EUI64.convert(ieee)
return ieee

# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@validator(
"device_ieee", pre=True, always=True, each_item=False, check_fields=False
)
@field_validator("device_ieee", mode="before", check_fields=False)
def convert_device_ieee(
cls,
device_ieee: Optional[Union[str, EUI64]],
values: dict[str, Any],
**kwargs: Any,
cls, device_ieee: Optional[Union[str, EUI64]]
) -> Optional[EUI64]:
"""Convert device ieee to EUI64."""
if device_ieee is None:
Expand Down
3 changes: 1 addition & 2 deletions zhaws/server/const.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""Constants."""

from enum import StrEnum
from typing import Final

from zhaws.backports.enum import StrEnum


class APICommands(StrEnum):
"""WS API commands."""
Expand Down
14 changes: 8 additions & 6 deletions zhaws/server/websocket/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Client classes for zhawss."""

from __future__ import annotations

import asyncio
from collections.abc import Callable
import json
import logging
from typing import TYPE_CHECKING, Any, Callable, Literal
from typing import TYPE_CHECKING, Any, Literal

from pydantic import ValidationError
from websockets.server import WebSocketServerProtocol
Expand Down Expand Up @@ -124,9 +126,9 @@ def _send_data(self, data: dict[str, Any]) -> None:
async def _handle_incoming_message(self, message: str | bytes) -> None:
"""Handle an incoming message."""
_LOGGER.info("Message received: %s", message)
handlers: dict[
str, tuple[Callable, WebSocketCommand]
] = self._client_manager.server.data[WEBSOCKET_API]
handlers: dict[str, tuple[Callable, WebSocketCommand]] = (
self._client_manager.server.data[WEBSOCKET_API]
)

loaded_message = json.loads(message)
_LOGGER.debug(
Expand Down Expand Up @@ -188,9 +190,9 @@ def will_accept_message(self, message: dict[str, Any]) -> bool:
class ClientListenRawZCLCommand(WebSocketCommand):
"""Listen to raw ZCL data."""

command: Literal[
command: Literal[APICommands.CLIENT_LISTEN_RAW_ZCL] = (
APICommands.CLIENT_LISTEN_RAW_ZCL
] = APICommands.CLIENT_LISTEN_RAW_ZCL
)


class ClientListenCommand(WebSocketCommand):
Expand Down
7 changes: 6 additions & 1 deletion zhaws/server/websocket/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import websockets

from zha.application.discovery import PLATFORMS
from zhaws.server.config.model import ServerConfiguration
from zhaws.server.const import APICommands
from zhaws.server.decorators import periodic
Expand All @@ -35,12 +36,16 @@ class Server:
def __init__(self, *, configuration: ServerConfiguration) -> None:
"""Initialize the server."""
self._config = configuration
self._ws_server: websockets.Serve | None = None
self._ws_server: websockets.WebSocketServer | None = None
self._controller: Controller = Controller(self)
self._client_manager: ClientManager = ClientManager(self)
self._stopped_event: asyncio.Event = asyncio.Event()
self._tracked_tasks: list[asyncio.Task] = []
self._tracked_completable_tasks: list[asyncio.Task] = []
self.data: dict[Any, Any] = {}
for platform in PLATFORMS:
self.data.setdefault(platform, [])
self._register_api_commands()
self._register_api_commands()
self._tracked_tasks.append(
asyncio.create_task(
Expand Down
4 changes: 2 additions & 2 deletions zhaws/server/zigbee/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from pydantic import Field
from zigpy.types.named import EUI64

from zha.zigbee.device import Device
from zha.zigbee.group import Group, GroupMemberReference
from zhaws.server.const import DEVICES, DURATION, GROUPS, APICommands
from zhaws.server.websocket.api import decorators, register_api_command
from zhaws.server.websocket.api.model import WebSocketCommand
from zhaws.server.zigbee.controller import Controller
from zhaws.server.zigbee.device import Device
from zhaws.server.zigbee.group import Group, GroupMemberReference

if TYPE_CHECKING:
from zhaws.server.websocket.client import Client
Expand Down
Loading

0 comments on commit 86164ab

Please sign in to comment.