diff --git a/src/rsp_restspawner/spawner.py b/src/rsp_restspawner/spawner.py index 8da1f07..cebdbf4 100644 --- a/src/rsp_restspawner/spawner.py +++ b/src/rsp_restspawner/spawner.py @@ -3,13 +3,13 @@ from __future__ import annotations import asyncio -from collections.abc import AsyncIterator, Awaitable, Callable +from collections.abc import AsyncIterator, Callable, Coroutine from dataclasses import dataclass from datetime import timedelta from enum import Enum from functools import wraps from pathlib import Path -from typing import Any, Optional, TypeVar, cast +from typing import Any, Optional, ParamSpec, TypeVar from httpx import AsyncClient, HTTPError from httpx_sse import ServerSentEvent, aconnect_sse @@ -23,7 +23,8 @@ SpawnFailedError, ) -F = TypeVar("F", bound=Callable[..., Awaitable[Any]]) +P = ParamSpec("P") +T = TypeVar("T") __all__ = [ "LabStatus", @@ -107,17 +108,19 @@ def to_dict(self) -> dict[str, int | str]: } -def _convert_exception(f: F) -> F: +def _convert_exception( + f: Callable[P, Coroutine[None, None, T]] +) -> Callable[P, Coroutine[None, None, T]]: """Convert ``httpx`` exceptions to `ControllerWebError`.""" @wraps(f) - async def wrapper(*args: Any, **kwargs: Any) -> Any: + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: try: return await f(*args, **kwargs) except HTTPError as e: raise ControllerWebError.from_exception(e) from e - return cast(F, wrapper) + return wrapper class RSPRestSpawner(Spawner):