diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 54bbb99dea..5ee979c4cc 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -81,6 +81,10 @@ ResponseEventT = TypeVar("ResponseEventT", bound=BaseProxyEvent) ResponseT = TypeVar("ResponseT") +P = TypeVar("P") +T = TypeVar("T") +SimpleFunction = Callable[P, T] +DecoratorFunction = Callable[[SimpleFunction], SimpleFunction] if TYPE_CHECKING: from aws_lambda_powertools.event_handler.openapi.compat import ( @@ -951,7 +955,7 @@ def route( security: Optional[List[Dict[str, List[str]]]] = None, openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, - ): + ) -> DecoratorFunction: raise NotImplementedError() def use(self, middlewares: List[Callable[..., Response]]) -> None: @@ -1011,7 +1015,7 @@ def get( security: Optional[List[Dict[str, List[str]]]] = None, openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, - ): + ) -> DecoratorFunction: """Get route decorator with GET `method` Examples @@ -1068,7 +1072,7 @@ def post( security: Optional[List[Dict[str, List[str]]]] = None, openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, - ): + ) -> DecoratorFunction: """Post route decorator with POST `method` Examples @@ -1126,7 +1130,7 @@ def put( security: Optional[List[Dict[str, List[str]]]] = None, openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, - ): + ) -> DecoratorFunction: """Put route decorator with PUT `method` Examples @@ -1184,7 +1188,7 @@ def delete( security: Optional[List[Dict[str, List[str]]]] = None, openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, - ): + ) -> DecoratorFunction: """Delete route decorator with DELETE `method` Examples @@ -1241,7 +1245,7 @@ def patch( security: Optional[List[Dict[str, List[str]]]] = None, openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable]] = None, - ): + ) -> DecoratorFunction: """Patch route decorator with PATCH `method` Examples @@ -1301,7 +1305,7 @@ def head( security: Optional[List[Dict[str, List[str]]]] = None, openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable]] = None, - ): + ) -> DecoratorFunction: """Head route decorator with HEAD `method` Examples @@ -1356,7 +1360,7 @@ def _reset_processed_stack(self): """Reset the Processed Stack Frames""" self.processed_stack_frames.clear() - def append_context(self, **additional_context): + def append_context(self, **additional_context) -> None: """Append key=value data as routing context""" self.context.update(**additional_context) @@ -1528,7 +1532,7 @@ def __init__( self._dynamic_routes: List[Route] = [] self._static_routes: List[Route] = [] self._route_keys: List[str] = [] - self._exception_handlers: Dict[Type, Callable] = {} + self._exception_handlers: Dict[Type, DecoratorFunction] = {} self._cors = cors self._cors_enabled: bool = cors is not None self._cors_methods: Set[str] = {"OPTIONS"} @@ -1988,7 +1992,7 @@ def route( security: Optional[List[Dict[str, List[str]]]] = None, openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, - ): + ) -> DecoratorFunction: """Route decorator includes parameter `method`""" def register_resolver(func: Callable): @@ -2345,7 +2349,7 @@ def not_found(self, func: Optional[Callable] = None): return self.exception_handler(NotFoundError) return self.exception_handler(NotFoundError)(func) - def exception_handler(self, exc_class: Union[Type[Exception], List[Type[Exception]]]): + def exception_handler(self, exc_class: Union[Type[Exception], List[Type[Exception]]]) -> DecoratorFunction: def register_exception_handler(func: Callable): if isinstance(exc_class, list): # pragma: no cover for exp in exc_class: @@ -2356,7 +2360,7 @@ def register_exception_handler(func: Callable): return register_exception_handler - def _lookup_exception_handler(self, exp_type: Type) -> Optional[Callable]: + def _lookup_exception_handler(self, exp_type: Type) -> Optional[DecoratorFunction]: # Use "Method Resolution Order" to allow for matching against a base class # of an exception for cls in exp_type.__mro__: @@ -2506,12 +2510,12 @@ def _get_fields_from_routes(routes: Sequence[Route]) -> List["ModelField"]: class Router(BaseRouter): """Router helper class to allow splitting ApiGatewayResolver into multiple files""" - def __init__(self): + def __init__(self) -> "Router": self._routes: Dict[tuple, Callable] = {} self._routes_with_middleware: Dict[tuple, List[Callable]] = {} self.api_resolver: Optional[BaseRouter] = None self.context = {} # early init as customers might add context before event resolution - self._exception_handlers: Dict[Type, Callable] = {} + self._exception_handlers: Dict[Type, DecoratorFunction] = {} def route( self, @@ -2530,7 +2534,7 @@ def route( security: Optional[List[Dict[str, List[str]]]] = None, openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, - ): + ) -> DecoratorFunction: def register_route(func: Callable): # All dict keys needs to be hashable. So we'll need to do some conversions: methods = (method,) if isinstance(method, str) else tuple(method) @@ -2636,7 +2640,7 @@ def route( security: Optional[List[Dict[str, List[str]]]] = None, openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, - ): + ) -> DecoratorFunction: # NOTE: see #1552 for more context. return super().route( rule.rstrip("/"),