Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check incoming request Authorization header #724

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/blueapi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,7 @@ class OIDCConfig(BlueapiBaseModel):
description="URL to fetch OIDC config from the provider"
)
client_id: str = Field(description="Client ID")
client_audience: str | tuple[str, ...] | None = Field(
description="Client Audience(s)"
)
client_audience: str = Field(description="Client Audience")

@cached_property
def _config_from_oidc_url(self) -> dict[str, Any]:
Expand Down Expand Up @@ -128,6 +126,10 @@ def id_token_signing_alg_values_supported(self) -> list[str]:
list[str],
self._config_from_oidc_url.get("id_token_signing_alg_values_supported"),
)

@cached_property
def introspection_endpoint(self) -> str:
return cast(str, self._config_from_oidc_url.get("introspection_endpoint"))


class CLIClientConfig(OIDCConfig):
Expand Down
3 changes: 1 addition & 2 deletions src/blueapi/service/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,7 @@ def _do_device_flow(self) -> None:
self._server_config.device_authorization_endpoint,
data={
"client_id": self._server_config.client_id,
"scope": "openid profile offline_access",
"audience": f"{self._server_config.client_audience}",
"scope": f"openid profile offline_access {self._server_config.instance}",
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
Expand Down
55 changes: 42 additions & 13 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
from opentelemetry.propagate import get_global_textmap
from opentelemetry.trace import get_tracer_provider
from pydantic import ValidationError
import requests
from starlette.responses import JSONResponse
from super_state_machine.errors import TransitionError

from blueapi.config import ApplicationConfig
from blueapi.config import ApplicationConfig, OIDCConfig
from blueapi.service import interface
from blueapi.worker import Task, TrackableTask, WorkerState
from blueapi.worker.event import TaskStatusEnum
Expand Down Expand Up @@ -57,7 +58,7 @@ def _runner() -> WorkerDispatcher:
return RUNNER


def setup_runner(config: ApplicationConfig | None = None, use_subprocess: bool = True):
def setup_runner(config: ApplicationConfig, use_subprocess: bool = True):
global RUNNER
runner = WorkerDispatcher(config, use_subprocess)
runner.start()
Expand All @@ -73,28 +74,31 @@ def teardown_runner():
RUNNER = None


@asynccontextmanager
async def lifespan(app: FastAPI):
config: ApplicationConfig = app.state.config
setup_runner(config)
yield
teardown_runner()
def lifespan(config: ApplicationConfig):
@asynccontextmanager
async def lifespan(app: FastAPI):
setup_runner(config)
yield
teardown_runner()
return lifespan


router = APIRouter()


def get_app():
def get_app(config: ApplicationConfig):
app = FastAPI(
docs_url="/docs",
title="BlueAPI Control",
lifespan=lifespan,
lifespan=lifespan(config),
version=REST_API_VERSION,
)
app.include_router(router)
app.add_exception_handler(KeyError, on_key_error_404)
app.middleware("http")(add_api_version_header)
app.middleware("http")(inject_propagated_observability_context)
if config.oidc:
app.middleware("http")(check_token_validity(config.oidc))
return app


Expand Down Expand Up @@ -382,16 +386,14 @@ def start(config: ApplicationConfig):
"%(asctime)s %(levelprefix)s %(client_addr)s"
+ " - '%(request_line)s' %(status_code)s"
)
app = get_app()
app = get_app(config)

FastAPIInstrumentor().instrument_app(
app,
tracer_provider=get_tracer_provider(),
http_capture_headers_server_request=[",*"],
http_capture_headers_server_response=[",*"],
)
app.state.config = config

uvicorn.run(app, host=config.api.host, port=config.api.port)


Expand All @@ -416,3 +418,30 @@ async def inject_propagated_observability_context(
attach(ctx)
response = await call_next(request)
return response


def check_token_validity(config: OIDCConfig):
async def check_token_validity(request: Request,
call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
if "Authorization" not in request.headers:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="No Authorization header")
authz_header = request.headers["Authorization"]
if not authz_header.startswith("Bearer "):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authorization header not Bearer token")
access_token = authz_header.removeprefix("Bearer ")

response = requests.post(
config.introspection_endpoint,
data={
"token": access_token,
"token_type_hint": "access_token"
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
response.raise_for_status()
if not (response.json["active"] is True):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authorization Bearer token not active")
response = await call_next(request)
return response
return check_token_validity
4 changes: 2 additions & 2 deletions src/blueapi/service/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ class WorkerDispatcher:

def __init__(
self,
config: ApplicationConfig | None = None,
config: ApplicationConfig,
use_subprocess: bool = True,
) -> None:
self._config = config or ApplicationConfig()
self._config = config
self._subprocess = None
self._use_subprocess = use_subprocess
self._state = EnvironmentResponse(
Expand Down
Loading