Skip to content

Commit

Permalink
Add openapi module for pygeoapi override
Browse files Browse the repository at this point in the history
  • Loading branch information
francbartoli committed Feb 4, 2024
1 parent 0257cff commit fc38616
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 53 deletions.
5 changes: 3 additions & 2 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@
from openapi_pydantic.v3.v3_0_3 import OAuthFlow
from openapi_pydantic.v3.v3_0_3 import OAuthFlows
from openapi_pydantic.v3.v3_0_3 import SecurityScheme
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.cors import CORSMiddleware

from pygeoapi.l10n import LocaleError
from pygeoapi.openapi import generate_openapi_document
from pygeoapi.provider.base import ProviderConnectionError
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.cors import CORSMiddleware


if cfg.LOG_LEVEL == "debug":
Expand Down
56 changes: 6 additions & 50 deletions app/middleware/pygeoapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
from typing import Dict
from typing import List

from app.config.app import configuration as cfg
from app.config.logging import create_logger
from openapi_pydantic.v3.v3_0_3 import OpenAPI
from app.pygeoapi.openapi import augment_security
from openapi_pydantic.v3.v3_0_3 import SecurityScheme
from pydantic_core import ValidationError
from starlette.datastructures import Headers
from starlette.datastructures import MutableHeaders
from starlette.types import ASGIApp
Expand Down Expand Up @@ -83,54 +81,12 @@ async def send_with_security(self, message: Message) -> None: # noqa: C901
self.headers.update(headers_dict)
if message_type == "http.response.body":
initial_body = message.get("body", b"").decode()
try:
openapi = OpenAPI.model_validate_json(initial_body)
except ValidationError as e:
logger.error(e)
raise
security_scheme_types = [
security_scheme.type for security_scheme in self.security_schemes
]
if all(
item in ["http", "apiKey", "oauth2", "openIdConnect"]
for item in security_scheme_types
):
security_schemes = {"securitySchemes": {}} # type: dict[str, dict]
dumped_schemes = {}
for scheme in self.security_schemes:
dumped_schemes.update(
{
f"pygeoapi {cfg.PYGEOAPI_SECURITY_SCHEME}": scheme.model_dump( # noqa B950
by_alias=True, exclude_none=True
)
}
)
security_schemes["securitySchemes"] = dumped_schemes
body = openapi.model_dump(by_alias=True, exclude_none=True)
components = body.get("components")
if components:
components.update(security_schemes)
body["components"] = components
paths = openapi.paths
if paths:
secured_paths = {}
for key, value in paths.items():
if value.get:
value.get.security = [
{f"pygeoapi {cfg.PYGEOAPI_SECURITY_SCHEME}": []}
]
if value.post:
value.post.security = [
{f"pygeoapi {cfg.PYGEOAPI_SECURITY_SCHEME}": []}
]
secured_paths.update({key: value})
if secured_paths:
body["paths"] = secured_paths
binary_body = (
OpenAPI(**body)
.model_dump_json(by_alias=True, exclude_none=True, indent=2)
.encode()
openapi_body = augment_security(
doc=initial_body, security_schemes=self.security_schemes
)
binary_body = openapi_body.model_dump_json(
by_alias=True, exclude_none=True, indent=2
).encode()
headers = MutableHeaders(raw=self.initial_message["headers"])
headers["Content-Length"] = str(len(binary_body))
message["body"] = binary_body
Expand Down
1 change: 1 addition & 0 deletions app/pygeoapi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""pygeoapi package."""
56 changes: 56 additions & 0 deletions app/pygeoapi/openapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Override vanilla openapi module."""
from typing import Dict
from typing import List

from app.config.app import configuration as cfg
from app.config.logging import create_logger
from openapi_pydantic.v3.v3_0_3 import OpenAPI
from openapi_pydantic.v3.v3_0_3 import SecurityScheme
from pydantic_core import ValidationError


logger = create_logger("app.pygeoapi.openapi")


def augment_security(doc: Dict, security_schemes: List[SecurityScheme]) -> OpenAPI:
"""Augment openapi document with security sections."""
try:
openapi = OpenAPI.model_validate_json(doc)
except ValidationError as e:
logger.error(e)
raise
security_scheme_types = [
security_scheme.type for security_scheme in security_schemes
]
_security_schemes = {"securitySchemes": {}} # type: dict[str, dict]
if all(
item in ["http", "apiKey", "oauth2", "openIdConnect"]
for item in security_scheme_types
):
dumped_schemes = {}
for scheme in security_schemes:
dumped_schemes.update(
{
f"pygeoapi {cfg.PYGEOAPI_SECURITY_SCHEME}": scheme.model_dump( # noqa B950
by_alias=True, exclude_none=True
)
}
)
_security_schemes["securitySchemes"] = dumped_schemes
content = openapi.model_dump(by_alias=True, exclude_none=True)
components = content.get("components")
if components:
components.update(_security_schemes)
content["components"] = components
paths = openapi.paths
secured_paths = {}
if paths:
for key, value in paths.items():
if value.get:
value.get.security = [{f"pygeoapi {cfg.PYGEOAPI_SECURITY_SCHEME}": []}]
if value.post:
value.post.security = [{f"pygeoapi {cfg.PYGEOAPI_SECURITY_SCHEME}": []}]
secured_paths.update({key: value})
if secured_paths:
content["paths"] = secured_paths
return OpenAPI(**content)
3 changes: 2 additions & 1 deletion cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from loguru import logger
from openapi_pydantic.v3.v3_0_3 import OpenAPI
from openapi_pydantic.v3.v3_0_3 import SecurityScheme
from rich.console import Console

from pygeoapi.l10n import LocaleError
from pygeoapi.openapi import generate_openapi_document
from pygeoapi.provider.base import ProviderConnectionError
from rich.console import Console


err_console = Console(stderr=True)
Expand Down

0 comments on commit fc38616

Please sign in to comment.