Skip to content

Commit f615e52

Browse files
committed
Remove BaseHTTPMiddlewares , Ensure origin host is used in STAC links
1 parent 2c23410 commit f615e52

File tree

8 files changed

+264
-209
lines changed

8 files changed

+264
-209
lines changed

pccommon/pccommon/middleware.py

Lines changed: 80 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,96 +1,99 @@
11
import asyncio
22
import logging
33
import time
4-
from typing import Awaitable, Callable
4+
from functools import wraps
5+
from typing import Any, Callable
56

6-
from fastapi import HTTPException, Request, Response
7+
from fastapi import HTTPException, Request
78
from fastapi.applications import FastAPI
8-
from starlette.middleware.base import BaseHTTPMiddleware
9-
from starlette.responses import PlainTextResponse
9+
from fastapi.dependencies.utils import (
10+
get_body_field,
11+
get_dependant,
12+
get_parameterless_sub_dependant,
13+
)
14+
from fastapi.responses import PlainTextResponse
15+
from fastapi.routing import APIRoute, request_response
1016
from starlette.status import HTTP_504_GATEWAY_TIMEOUT
11-
from starlette.types import Message
1217

1318
from pccommon.logging import get_custom_dimensions
14-
from pccommon.tracing import trace_request
1519

1620
logger = logging.getLogger(__name__)
1721

1822

19-
async def handle_exceptions(
20-
request: Request,
21-
call_next: Callable[[Request], Awaitable[Response]],
22-
) -> Response:
23-
try:
24-
return await call_next(request)
25-
except HTTPException:
23+
async def http_exception_handler(request: Request, exc: Exception) -> Any:
24+
# Log the exception with additional request info if needed
25+
logger.exception("Exception when handling request", exc_info=exc)
26+
# Return a custom response for HTTPException
27+
if isinstance(exc, HTTPException):
2628
raise
27-
except Exception as e:
29+
# Handle other exceptions, possibly with a generic response
30+
else:
2831
logger.exception(
2932
"Exception when handling request",
30-
extra=get_custom_dimensions({"stackTrace": f"{e}"}, request),
33+
extra=get_custom_dimensions({"stackTrace": f"{exc}"}, request),
3134
)
3235
raise
3336

3437

35-
class RequestTracingMiddleware(BaseHTTPMiddleware):
36-
"""Custom middleware to use opencensus request traces
37-
38-
Middleware implementations that access a Request object directly
39-
will cause subsequent middleware or route handlers to hang. See
40-
41-
https://github.com/tiangolo/fastapi/issues/394
42-
43-
for more details on this implementation.
44-
45-
An alternative approach is to use dependencies on the APIRouter, but
46-
the stac-fast api implementation makes that difficult without having
47-
to override much of the app initialization.
48-
"""
49-
50-
def __init__(self, app: FastAPI, service_name: str):
51-
super().__init__(app)
52-
self.service_name = service_name
53-
54-
async def set_body(self, request: Request) -> None:
55-
receive_ = await request._receive()
56-
57-
async def receive() -> Message:
58-
return receive_
59-
60-
request._receive = receive
61-
62-
async def dispatch(
63-
self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
64-
) -> Response:
65-
await self.set_body(request)
66-
response = await trace_request(self.service_name, request, call_next)
67-
return response
68-
69-
70-
async def timeout_middleware(
71-
request: Request,
72-
call_next: Callable[[Request], Awaitable[Response]],
73-
timeout: int,
74-
) -> Response:
75-
try:
76-
start_time = time.time()
77-
return await asyncio.wait_for(call_next(request), timeout=timeout)
78-
79-
except asyncio.TimeoutError:
80-
process_time = time.time() - start_time
81-
log_dimensions = get_custom_dimensions({"request_time": process_time}, request)
82-
83-
logger.exception(
84-
"Request timeout",
85-
extra=log_dimensions,
86-
)
87-
88-
ref_id = log_dimensions["custom_dimensions"].get("ref_id")
89-
debug_msg = f"Debug information for support: {ref_id}" if ref_id else ""
90-
91-
return PlainTextResponse(
92-
f"The request exceeded the maximum allowed time, please try again."
93-
" If the issue persists, please contact [email protected]."
94-
f"\n\n{debug_msg}",
95-
status_code=HTTP_504_GATEWAY_TIMEOUT,
96-
)
38+
def with_timeout(
39+
timeout_seconds: float,
40+
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
41+
def with_timeout_(func: Callable[..., Any]) -> Callable[..., Any]:
42+
if asyncio.iscoroutinefunction(func):
43+
logger.debug("Adding timeout to function %s", func.__name__)
44+
45+
@wraps(func)
46+
async def inner(*args: Any, **kwargs: Any) -> Any:
47+
start_time = time.monotonic()
48+
try:
49+
return await asyncio.wait_for(
50+
func(*args, **kwargs), timeout=timeout_seconds
51+
)
52+
except asyncio.TimeoutError as e:
53+
process_time = time.monotonic() - start_time
54+
# don't have a request object here to get custom dimensions.
55+
log_dimensions = {
56+
"request_time": process_time,
57+
}
58+
logger.exception(
59+
f"Request timeout {e}",
60+
extra=log_dimensions,
61+
)
62+
63+
ref_id = log_dimensions.get("ref_id")
64+
debug_msg = (
65+
f" Debug information for support: {ref_id}" if ref_id else ""
66+
)
67+
68+
return PlainTextResponse(
69+
f"The request exceeded the maximum allowed time, please"
70+
" try again. If the issue persists, please contact "
71+
72+
f"\n\n{debug_msg}",
73+
status_code=HTTP_504_GATEWAY_TIMEOUT,
74+
)
75+
76+
return inner
77+
else:
78+
return func
79+
80+
return with_timeout_
81+
82+
83+
def add_timeout(app: FastAPI, timeout_seconds: float) -> None:
84+
for route in app.router.routes:
85+
if isinstance(route, APIRoute):
86+
new_endpoint = with_timeout(timeout_seconds)(route.endpoint)
87+
route.endpoint = new_endpoint
88+
route.dependant = get_dependant(path=route.path_format, call=route.endpoint)
89+
for depends in route.dependencies[::-1]:
90+
route.dependant.dependencies.insert(
91+
0,
92+
get_parameterless_sub_dependant(
93+
depends=depends, path=route.path_format
94+
),
95+
)
96+
route.body_field = get_body_field(
97+
dependant=route.dependant, name=route.unique_id
98+
)
99+
route.app = request_response(route.get_route_handler())

pccommon/tests/test_timeouts.py

Lines changed: 16 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import asyncio
2-
import random
3-
from typing import Awaitable, Callable
2+
from typing import Any
43

54
import pytest
6-
from fastapi import FastAPI, Request, Response
7-
from fastapi.responses import PlainTextResponse
5+
from fastapi import FastAPI
6+
7+
# from fastapi.responses import PlainTextResponse
88
from httpx import AsyncClient
9-
from starlette.status import HTTP_200_OK, HTTP_504_GATEWAY_TIMEOUT
9+
from starlette.status import HTTP_504_GATEWAY_TIMEOUT
1010

11-
from pccommon.middleware import timeout_middleware
11+
from pccommon.middleware import add_timeout
1212

1313
TIMEOUT_SECONDS = 2
1414
BASE_URL = "http://test"
@@ -20,80 +20,22 @@
2020
app.state.service_name = "test"
2121

2222

23-
@app.middleware("http")
24-
async def _timeout_middleware(
25-
request: Request, call_next: Callable[[Request], Awaitable[Response]]
26-
) -> Response:
27-
"""Add a timeout to all requests."""
28-
return await timeout_middleware(request, call_next, timeout=TIMEOUT_SECONDS)
29-
30-
31-
# Test endpoint to sleep for a configurable amount of time, which may exceed the
32-
# timeout middleware setting
33-
@app.get("/sleep", response_class=PlainTextResponse)
34-
async def route_for_test(t: int) -> str:
35-
await asyncio.sleep(t)
36-
return "Done"
37-
38-
39-
# Test endpoint to sleep and confirm that the task is cancelled after the timeout
40-
@app.get("/cancel", response_class=PlainTextResponse)
41-
async def route_for_cancel_test(t: int) -> str:
42-
for i in range(t):
43-
await asyncio.sleep(1)
44-
if i > TIMEOUT_SECONDS:
45-
raise Exception("Task should have been cancelled")
46-
47-
return "Done"
48-
49-
50-
# Test middleware
51-
# ===============
52-
53-
54-
async def success_response(client: AsyncClient, timeout: int) -> None:
55-
print("making request")
56-
response = await client.get("/sleep", params={"t": timeout})
57-
assert response.status_code == HTTP_200_OK
58-
assert response.text == "Done"
23+
@app.get("/asleep")
24+
async def asleep() -> Any:
25+
await asyncio.sleep(1)
26+
return {}
5927

6028

61-
async def timeout_response(client: AsyncClient, timeout: int) -> None:
62-
response = await client.get("/sleep", params={"t": timeout})
63-
assert response.status_code == HTTP_504_GATEWAY_TIMEOUT
64-
65-
66-
@pytest.mark.asyncio
67-
async def test_timeout() -> None:
68-
async with AsyncClient(app=app, base_url=BASE_URL) as client:
69-
await timeout_response(client, 10)
70-
29+
# Run this after registering the routes
7130

72-
@pytest.mark.asyncio
73-
async def test_no_timeout() -> None:
74-
async with AsyncClient(app=app, base_url=BASE_URL) as client:
75-
await success_response(client, 1)
31+
add_timeout(app, timeout_seconds=0.001)
7632

7733

7834
@pytest.mark.asyncio
79-
async def test_multiple_requests() -> None:
80-
async with AsyncClient(app=app, base_url=BASE_URL) as client:
81-
timeout_tasks = []
82-
for _ in range(100):
83-
t = TIMEOUT_SECONDS + random.randint(1, 10)
84-
timeout_tasks.append(asyncio.ensure_future(timeout_response(client, t)))
85-
86-
await asyncio.gather(*timeout_tasks)
87-
88-
success_tasks = []
89-
for _ in range(100):
90-
t = TIMEOUT_SECONDS - 1
91-
success_tasks.append(asyncio.ensure_future(success_response(client, t)))
35+
async def test_add_timeout() -> None:
9236

93-
await asyncio.gather(*success_tasks)
37+
client = AsyncClient(app=app, base_url=BASE_URL)
9438

39+
response = await client.get("/asleep")
9540

96-
@pytest.mark.asyncio
97-
async def test_request_cancelled() -> None:
98-
async with AsyncClient(app=app, base_url=BASE_URL) as client:
99-
await client.get("/cancel", params={"t": 10})
41+
assert response.status_code == HTTP_504_GATEWAY_TIMEOUT

pcstac/pcstac/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ class Settings(BaseSettings):
9898
api_version: str = f"v{API_VERSION}"
9999
rate_limits: RateLimits = RateLimits()
100100
back_pressures: BackPressures = BackPressures()
101-
request_timout: int = Field(env=REQUEST_TIMEOUT_ENV_VAR, default=30)
101+
request_timeout: int = Field(env=REQUEST_TIMEOUT_ENV_VAR, default=30)
102102

103103
def get_tiler_href(self, request: Request) -> str:
104104
"""Generates the tiler HREF.

pcstac/pcstac/main.py

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""FastAPI application using PGStac."""
22
import logging
33
import os
4-
from typing import Any, Awaitable, Callable, Dict
4+
from typing import Any, Dict
55

6-
from fastapi import FastAPI, HTTPException, Request, Response
6+
from brotli_asgi import BrotliMiddleware
7+
from fastapi import FastAPI, Request
78
from fastapi.exceptions import RequestValidationError, StarletteHTTPException
89
from fastapi.openapi.utils import get_openapi
910
from fastapi.responses import ORJSONResponse
@@ -16,9 +17,8 @@
1617

1718
from pccommon.logging import ServiceName, init_logging
1819
from pccommon.middleware import (
19-
RequestTracingMiddleware,
20-
handle_exceptions,
21-
timeout_middleware,
20+
add_timeout,
21+
http_exception_handler,
2222
)
2323
from pccommon.openapi import fixup_schema
2424
from pccommon.redis import connect_to_redis
@@ -32,6 +32,7 @@
3232
get_settings,
3333
)
3434
from pcstac.errors import PC_DEFAULT_STATUS_CODES
35+
from pcstac.middleware import ProxyHeaderHostMiddleware
3536
from pcstac.search import PCSearch, PCSearchGetRequest, RedisBaseItemCache
3637

3738
DEBUG: bool = os.getenv("DEBUG") == "TRUE" or False
@@ -70,13 +71,16 @@
7071
search_get_request_model=search_get_request_model,
7172
search_post_request_model=search_post_request_model,
7273
response_class=ORJSONResponse,
74+
middlewares=[BrotliMiddleware, ProxyHeaderHostMiddleware],
7375
exceptions={**DEFAULT_STATUS_CODES, **PC_DEFAULT_STATUS_CODES},
7476
)
7577

7678
app: FastAPI = api.app
7779

7880
app.state.service_name = ServiceName.STAC
7981

82+
add_timeout(app, app_settings.request_timeout)
83+
8084
# Note: If requests are being sent through an application gateway like
8185
# nginx-ingress, you may need to configure CORS through that system.
8286
app.add_middleware(
@@ -86,25 +90,6 @@
8690
allow_headers=["*"],
8791
)
8892

89-
app.add_middleware(RequestTracingMiddleware, service_name=ServiceName.STAC)
90-
91-
92-
@app.middleware("http")
93-
async def _timeout_middleware(
94-
request: Request, call_next: Callable[[Request], Awaitable[Response]]
95-
) -> Response:
96-
"""Add a timeout to all requests."""
97-
return await timeout_middleware(
98-
request, call_next, timeout=app_settings.request_timout
99-
)
100-
101-
102-
@app.middleware("http")
103-
async def _handle_exceptions(
104-
request: Request, call_next: Callable[[Request], Awaitable[Response]]
105-
) -> Response:
106-
return await handle_exceptions(request, call_next)
107-
10893

10994
@app.on_event("startup")
11095
async def startup_event() -> None:
@@ -119,13 +104,7 @@ async def shutdown_event() -> None:
119104
await close_db_connection(app)
120105

121106

122-
@app.exception_handler(HTTPException)
123-
async def http_exception_handler(
124-
request: Request, exc: HTTPException
125-
) -> PlainTextResponse:
126-
return PlainTextResponse(
127-
str(exc.detail), status_code=exc.status_code, headers=exc.headers
128-
)
107+
app.add_exception_handler(Exception, http_exception_handler)
129108

130109

131110
@app.exception_handler(StarletteHTTPException)

0 commit comments

Comments
 (0)