Skip to content

Commit

Permalink
Override routes for patching pygeoapi responses
Browse files Browse the repository at this point in the history
  • Loading branch information
francbartoli committed Apr 21, 2024
1 parent 02f3358 commit a5bd3d8
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 6 deletions.
36 changes: 30 additions & 6 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,30 @@ async def custom_app_exception_handler(request, e):
os.environ["HOST"] = cfg.HOST
os.environ["PORT"] = cfg.PORT

# import starlette application once env vars are set
# import pygeoapi starlette application once env vars are set
# and prepare the objects to override some core behavior
from pygeoapi.starlette_app import APP as PYGEOAPI_APP
from pygeoapi.starlette_app import url_prefix
from starlette.applications import Starlette
from starlette.routing import Mount

from app.utils.pygeoapi_utils import patch_route

static_route = PYGEOAPI_APP.routes[0]
api_app = PYGEOAPI_APP.routes[1].app
api_routes = api_app.routes

patched_routes = ()
for api_route in api_routes:
api_route_ = patch_route(api_route)
patched_routes += (api_route_,)

patched_app = Starlette(
routes=[
static_route,
Mount(url_prefix or "/", routes=list(patched_routes)),
]
)

pygeoapi_conf = Path.cwd() / os.environ["PYGEOAPI_CONFIG"]
pygeoapi_oapi = Path.cwd() / os.environ["PYGEOAPI_OPENAPI"]
Expand Down Expand Up @@ -125,22 +147,23 @@ async def custom_app_exception_handler(request, e):
)
from app.config.auth import auth_config

PYGEOAPI_APP.add_middleware(OPAMiddleware, config=auth_config)
patched_app.add_middleware(OPAMiddleware, config=auth_config)

security_schemes = [
SecurityScheme(
type="openIdConnect",
openIdConnectUrl=cfg.OIDC_WELL_KNOWN_ENDPOINT,
)
]
# Add Oauth2Middleware to the pygeoapi app
elif cfg.JWKS_ENABLED:
if cfg.API_KEY_ENABLED or cfg.OPA_ENABLED:
raise ValueError(
"OPA_ENABLED, JWKS_ENABLED and API_KEY_ENABLED are mutually exclusive"
)
from app.config.auth import auth_config

PYGEOAPI_APP.add_middleware(Oauth2Middleware, config=auth_config)
patched_app.add_middleware(Oauth2Middleware, config=auth_config)

security_schemes = [
SecurityScheme(
Expand All @@ -156,6 +179,7 @@ async def custom_app_exception_handler(request, e):
type="http", name="pygeoapi", scheme="bearer", bearerFormat="JWT"
),
]
# Add AuthorizerMiddleware to the pygeoapi app
elif cfg.API_KEY_ENABLED:
if cfg.OPA_ENABLED:
raise ValueError("OPA_ENABLED and API_KEY_ENABLED are mutually exclusive")
Expand All @@ -165,7 +189,7 @@ async def custom_app_exception_handler(request, e):

os.environ["PYGEOAPI_KEY_GLOBAL"] = cfg.PYGEOAPI_KEY_GLOBAL

PYGEOAPI_APP.add_middleware(
patched_app.add_middleware(
AuthorizerMiddleware,
public_paths=[f"{cfg.FASTGEOAPI_CONTEXT}/openapi"],
key_pattern="PYGEOAPI_KEY_",
Expand All @@ -176,11 +200,11 @@ async def custom_app_exception_handler(request, e):
]

if security_schemes:
PYGEOAPI_APP.add_middleware(
patched_app.add_middleware(
OpenapiSecurityMiddleware, security_schemes=security_schemes
)

app.mount(path=cfg.FASTGEOAPI_CONTEXT, app=PYGEOAPI_APP)
app.mount(path=cfg.FASTGEOAPI_CONTEXT, app=patched_app)

app.logger = create_logger(name="app.main")

Expand Down
1 change: 1 addition & 0 deletions app/pygeoapi/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""API package for patched pygeoapi."""
6 changes: 6 additions & 0 deletions app/pygeoapi/api/processes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Processes patched module."""


def patch_response(response):
"""Patch pygeoapi response."""
return response
81 changes: 81 additions & 0 deletions app/pygeoapi/starlette_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Starlette application override module."""

import asyncio
from typing import Callable
from typing import Union

from pygeoapi.starlette_app import api_ as geoapi
from starlette.requests import Request
from starlette.responses import HTMLResponse
from starlette.responses import JSONResponse
from starlette.responses import Response


def call_api_threadsafe(
loop: asyncio.AbstractEventLoop, api_call: Callable, *args
) -> tuple:
"""Call api in a safe thread.
The api call needs a running loop. This method is meant to be called
from a thread that has no loop running.
:param loop: The loop to use.
:param api_call: The API method to call.
:param args: Arguments to pass to the API method.
:returns: The api call result tuple.
"""
asyncio.set_event_loop(loop)
return api_call(*args)


async def get_response(
api_call,
*args,
) -> Union[Response, JSONResponse, HTMLResponse]:
"""Creates a Starlette Response object and updates matching headers.
Runs the core api handler in a separate thread in order to avoid
blocking the main event loop.
:param result: The result of the API call.
This should be a tuple of (headers, status, content).
:returns: A Response instance.
"""
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
None, call_api_threadsafe, loop, api_call, *args
)

headers, status, content = result
if headers["Content-Type"] == "text/html":
response = HTMLResponse(content=content, status_code=status)
else:
if isinstance(content, dict):
response = JSONResponse(content, status_code=status)
else:
response = Response(content, status_code=status)

if headers is not None:
response.headers.update(headers)
return response


async def patched_get_job_result(request: Request, job_id=None):
"""OGC API - Processes job result endpoint.
:param request: Starlette Request instance
:param job_id: job identifier
:returns: HTTP response
"""
if "job_id" in request.path_params:
job_id = request.path_params["job_id"]

response = await get_response(geoapi.get_job_result, request, job_id)

from app.pygeoapi.api.processes import patch_response

patched_response = patch_response(response=response)

return patched_response
12 changes: 12 additions & 0 deletions app/utils/pygeoapi_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Utilities for pygeoapi module."""

from starlette.routing import Route

from app.pygeoapi.starlette_app import patched_get_job_result


def patch_route(route: Route) -> Route:
"""Patch route behavior."""
if route.path == "/jobs/{job_id}/results":
route = Route("/jobs/{job_id}/results", patched_get_job_result)
return route

0 comments on commit a5bd3d8

Please sign in to comment.