From 7b1a9fee10a2040623e9cf3cccece12e5684e707 Mon Sep 17 00:00:00 2001 From: Gary Yendell Date: Wed, 15 May 2024 16:17:55 +0000 Subject: [PATCH] Improve typing This passes pyright although we are only testing mypy for now as pvi does not yet comply with pyright and could cause unsolvable typing conflicts here if we use those APIs. --- pyproject.toml | 1 + src/odin_fastcs/http_connection.py | 10 ++- src/odin_fastcs/odin_controller.py | 102 ++++++++++++++++++----------- src/odin_fastcs/util.py | 24 +------ 4 files changed, 73 insertions(+), 64 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f39de9c..76d6fdc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dev = [ "sphinx-design", "tox-direct", "types-mock", + "types-requests", ] [project.scripts] diff --git a/src/odin_fastcs/http_connection.py b/src/odin_fastcs/http_connection.py index 98b2acf..d6f8019 100644 --- a/src/odin_fastcs/http_connection.py +++ b/src/odin_fastcs/http_connection.py @@ -4,7 +4,7 @@ ValueType = bool | int | float | str JsonElementary = str | int | float | bool | None -JsonType = JsonElementary | list[JsonElementary] | Mapping[str, JsonElementary] +JsonType = JsonElementary | list["JsonType"] | Mapping[str, "JsonType"] class HTTPConnection: @@ -44,7 +44,7 @@ def get_session(self) -> ClientSession: raise ConnectionRefusedError("Session is not open") - async def get(self, uri: str, headers: dict | None = None) -> JsonType: + async def get(self, uri: str, headers: dict | None = None) -> dict[str, JsonType]: """Perform HTTP GET request and return response content as JSON. Args: @@ -55,7 +55,11 @@ async def get(self, uri: str, headers: dict | None = None) -> JsonType: """ session = self.get_session() async with session.get(self.full_url(uri), headers=headers) as response: - return await response.json() + match await response.json(): + case dict() as d: + return d + case _: + raise ValueError(f"Got unexpected response:\n{response}") async def get_bytes(self, uri: str) -> tuple[ClientResponse, bytes]: """Perform HTTP GET request and return response content as bytes. diff --git a/src/odin_fastcs/odin_controller.py b/src/odin_fastcs/odin_controller.py index e44e6d1..179559f 100644 --- a/src/odin_fastcs/odin_controller.py +++ b/src/odin_fastcs/odin_controller.py @@ -1,11 +1,12 @@ import asyncio import logging +from collections.abc import Mapping from dataclasses import dataclass from typing import Any from fastcs.attributes import AttrR, AttrRW, AttrW, Handler from fastcs.connections.ip_connection import IPConnectionSettings -from fastcs.controller import Controller +from fastcs.controller import Controller, SubController from fastcs.datatypes import Bool, Float, Int, String from fastcs.util import snake_to_pascal @@ -31,48 +32,50 @@ class ParamTreeHandler(Handler): async def put( self, - controller: Any, + controller: "OdinController", attr: AttrW[Any], value: Any, ) -> None: try: response = await controller._connection.put(self.path, value) - if "error" in response: - raise AdapterResponseError(response["error"]) + match response: + case {"error": error}: + raise AdapterResponseError(error) except Exception as e: logging.error("Update loop failed for %s:\n%s", self.path, e) async def update( self, - controller: Any, + controller: "OdinController", attr: AttrR[Any], ) -> None: try: response = await controller._connection.get(self.path) - # TODO: Don't like this... - value = response[self.path.split("/")[-1]] + + # TODO: This would be nicer if the key was 'value' so we could match + parameter = self.path.split("/")[-1] + value = response.get(parameter, None) + if value is None: + raise ValueError(f"{parameter} not found in response:\n{response}") + await attr.set(value) except Exception as e: logging.error("Update loop failed for %s:\n%s", self.path, e) - - - -class OdinController(Controller): +class OdinController(SubController): def __init__( self, connection: HTTPConnection, - param_tree: dict[str, Any], + param_tree: Mapping[str, Any], api_prefix: str, process_prefix: str, ): - super().__init__() + super().__init__(process_prefix) self._connection = connection self._param_tree = param_tree self._api_prefix = api_prefix - self._path = process_prefix async def _create_parameter_tree(self): parameters = create_odin_parameters(self._param_tree) @@ -129,9 +132,16 @@ def __init__(self, settings: IPConnectionSettings) -> None: async def initialise(self) -> None: self._connection.open() - adapters: list[str] = ( - await self._connection.get(f"{self.API_PREFIX}/adapters") - )["adapters"] + adapters_response = await self._connection.get(f"{self.API_PREFIX}/adapters") + match adapters_response: + case {"adapters": [*adapter_list]}: + adapters = tuple(a for a in adapter_list if isinstance(a, str)) + if len(adapters) != len(adapter_list): + raise ValueError(f"Received invalid adapters list:\n{adapter_list}") + case _: + raise ValueError( + f"Did not find valid adapters in response:\n{adapters_response}" + ) for adapter in adapters: if adapter in IGNORED_ADAPTERS: @@ -139,11 +149,16 @@ async def initialise(self) -> None: # Get full parameter tree and split into parameters at the root and under # an index where there are N identical trees for each underlying process - response: dict[str, Any] = await self._connection.get( + response = await self._connection.get( f"{self.API_PREFIX}/{adapter}", headers=REQUEST_METADATA_HEADER ) + assert isinstance(response, Mapping) root_tree = {k: v for k, v in response.items() if not k.isdigit()} - indexed_trees = {k: v for k, v in response.items() if k.isdigit()} + indexed_trees = { + k: v + for k, v in response.items() + if k.isdigit() and isinstance(v, Mapping) + } odin_controller = OdinController( self._connection, @@ -171,38 +186,45 @@ async def connect(self) -> None: class FPOdinController(OdinController): - def __init__(self, settings: IPConnectionSettings, api: str = "0.1"): + def __init__( + self, + connection: HTTPConnection, + param_tree: Mapping[str, Any], + api: str = "0.1", + ): super().__init__( - settings, f"api/{api}/fp", "FP", param_tree=True, process_params=False + connection, + param_tree, + f"api/{api}/fp", + "FP", ) class FROdinController(OdinController): - def __init__(self, settings: IPConnectionSettings, api: str = "0.1"): + def __init__( + self, + connection: HTTPConnection, + param_tree: Mapping[str, Any], + api: str = "0.1", + ): super().__init__( - settings, f"api/{api}/fr", "FR", param_tree=True, process_params=True + connection, + param_tree, + f"api/{api}/fr", + "FR", ) class MLOdinController(OdinController): - def __init__(self, settings: IPConnectionSettings, api: str = "0.1"): - super().__init__( - settings, - f"api/{api}/meta_listener", - "ML", - param_tree=True, - process_params=False, - ) - - -class OdinDetectorController(OdinController): def __init__( - self, adapter_name: str, settings: IPConnectionSettings, api: str = "0.1" + self, + connection: HTTPConnection, + param_tree: Mapping[str, Any], + api: str = "0.1", ): super().__init__( - settings, - f"api/{api}/{adapter_name}", - adapter_name.capitalize(), - param_tree=True, - process_params=False, + connection, + param_tree, + f"api/{api}/meta_listener", + "ML", ) diff --git a/src/odin_fastcs/util.py b/src/odin_fastcs/util.py index dfe189f..2a4fff2 100644 --- a/src/odin_fastcs/util.py +++ b/src/odin_fastcs/util.py @@ -9,7 +9,7 @@ def is_metadata_object(v: Any) -> bool: @dataclass class OdinParameter: - uri: str + uri: list[str] """Full URI.""" metadata: dict[str, Any] """JSON response from GET of parameter.""" @@ -32,12 +32,12 @@ def create_odin_parameters(metadata: Mapping[str, Any]) -> list[OdinParameter]: """ return [ OdinParameter(uri=uri, metadata=metadata) - for uri, metadata in _walk_odin_metadata(metadata) + for uri, metadata in _walk_odin_metadata(metadata, []) ] def _walk_odin_metadata( - tree: dict[str, Any], path: str = None + tree: Mapping[str, Any], path: list[str] ) -> Iterator[tuple[list[str], dict[str, Any]]]: """Walk through tree and yield the leaves and their paths. @@ -49,7 +49,6 @@ def _walk_odin_metadata( (path to leaf, value of leaf) """ - path = path or [] for node_name, node_value in tree.items(): if node_name: node_path = path + [node_name] @@ -96,20 +95,3 @@ def infer_metadata(parameter: int | float | bool | str, uri: list[str]): "type": type(parameter).__name__, "writeable": "config" in uri, } - - -def tag_key_clashes(parameters: list[OdinParameter]): - """Find key clashes between subsystems and tag parameters to use extended name. - - Modifies list of parameters in place. - - Args: - parameters: Parameters to search - - """ - for idx, parameter in enumerate(parameters): - for other in parameters[idx + 1 :]: - if parameter.key == other.key: - parameter.has_unique_key = False - other.has_unique_key = False - break