Skip to content

Commit

Permalink
Compatibility fixes for FastAPI
Browse files Browse the repository at this point in the history
  • Loading branch information
bbrondel committed Jul 31, 2024
1 parent 21cc0a4 commit 512c64f
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 41 deletions.
6 changes: 3 additions & 3 deletions Dockerfile.pqserver
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
FROM python:3.11
RUN pip install flask gunicorn sqlalchemy psycopg2
RUN pip install fastapi safir astropy uvicorn gunicorn sqlalchemy psycopg2
WORKDIR /
COPY python/lsst/consdb/__init__.py python/lsst/consdb/pqserver.py python/lsst/consdb/utils.py /consdb-pq/
COPY python/lsst/consdb/__init__.py python/lsst/consdb/pqserver.py python/lsst/consdb/utils.py /consdb_pq/
# Environment variables that must be set:
# DB_HOST DB_PASS DB_USER DB_NAME or POSTGRES_URL

# Expose the port.
EXPOSE 8080

ENTRYPOINT [ "gunicorn", "-b", "0.0.0.0:8080", "-w", "2", "consdb-pq.pqserver:app" ]
ENTRYPOINT [ "gunicorn", "-b", "0.0.0.0:8080", "-w", "2", "-k", "uvicorn.workers.UvicornWorker", "consdb_pq.pqserver:app" ]

49 changes: 31 additions & 18 deletions python/lsst/consdb/pqserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from enum import Enum
from importlib.metadata import metadata, version
from typing import Annotated, Any, Iterable, Optional

from flask import Flask, request
from fastapi import FastAPI, APIRouter, Depends, Path
import sqlalchemy
import sqlalchemy.dialects.postgresql
from pydantic import BaseModel, Field, field_validator
from safir.metadata import Metadata, get_metadata
from safir.middleware.x_forwarded import XForwardedMiddleware
from .utils import setup_logging, setup_postgres

internal_router = APIRouter()
Expand Down Expand Up @@ -108,10 +109,21 @@ def validate_instrument_name(
# Global app setup #
####################

app = FastAPI()
path_prefix = "/consdb"
app = FastAPI(
title="consdb-pqserver",
description="HTTP API for consdb",
openapi_url=f"{path_prefix}/openapi.json",
docs_url=f"{path_prefix}/docs",
redoc_url=f"{path_prefix}/redoc",
)
engine = setup_postgres()
logger = setup_logging(__name__)

app.include_router(internal_router)
app.include_router(external_router, prefix=path_prefix)
app.add_middleware(XForwardedMiddleware)


########################
# Schema preload class #
Expand Down Expand Up @@ -423,15 +435,13 @@ def handle_sql_error(e: sqlalchemy.exc.SQLAlchemyError) -> tuple[dict[str, str],
###################################


@internal_router.get(
@app.get(
"/",
description="Metadata and health check endpoint.",
include_in_schema=False,
response_model=Metadata,
response_model_exclude_none=True,
summary="Application metadata",
)
async def internal_root() -> Metadata:
async def internal_root() -> dict[str, Any]:
"""Root URL for liveness checks.
Returns
Expand All @@ -440,10 +450,13 @@ async def internal_root() -> Metadata:
JSON response with a list of instruments, observation types, and
data types.
"""
return get_metadata(
package_name="consdb-pqserver",
application_name=config.name,
)
return {
"instruments": [ "foo", "bar", "baz" ],
}
# get_metadata(
# package_name="consdb-pqserver",
# application_name=config.name,
# )


class Index(BaseModel):
Expand Down Expand Up @@ -608,7 +621,7 @@ class GenericResponse(BaseModel):


@external_router.post("/flex/{instrument}/{obs_type}/obs/{obs_id}")
def insert_flexible_metadata(
async def insert_flexible_metadata(
instrument: Annotated[str, Depends(validate_instrument_name)],
obs_type: ObsTypeEnum,
obs_id: ObservationIdType,
Expand Down Expand Up @@ -659,7 +672,7 @@ def insert_flexible_metadata(


@external_router.post("/insert/{instrument}/{table}/obs/{obs_id}")
def insert(
async def insert(
instrument: Annotated[str, Depends(validate_instrument_name)],
table: str,
obs_id: ObservationIdType,
Expand Down Expand Up @@ -703,7 +716,7 @@ def insert(


@external_router.post("/insert/{instrument}/{table}")
def insert_multiple(
async def insert_multiple(
instrument: Annotated[str, Depends(validate_instrument_name)],
table: str,
) -> dict[str, Any] | tuple[dict[str, str], int]:
Expand Down Expand Up @@ -779,7 +792,7 @@ def insert_multiple(


@external_router.get("/query/{instrument}/{obs_type}/obs/{obs_id}")
def get_all_metadata(
async def get_all_metadata(
instrument: Annotated[str, Depends(validate_instrument_name)],
obs_type: ObsTypeEnum,
obs_id: ObservationIdType,
Expand Down Expand Up @@ -828,7 +841,7 @@ def get_all_metadata(


@external_router.post("/query")
def query() -> dict[str, Any] | tuple[dict[str, str], int]:
async def query() -> dict[str, Any] | tuple[dict[str, str], int]:
"""Query the ConsDB database.
Parameters
Expand Down Expand Up @@ -866,7 +879,7 @@ def query() -> dict[str, Any] | tuple[dict[str, str], int]:


@external_router.get("/schema")
def list_instruments() -> list[str]:
async def list_instruments() -> list[str]:
"""Retrieve the list of instruments available in ConsDB."""
global instrument_tables

Expand All @@ -875,7 +888,7 @@ def list_instruments() -> list[str]:


@external_router.get("/consdb/schema/{instrument}")
def list_table(
async def list_table(
instrument: Annotated[str, Depends(validate_instrument_name)],
) -> list[str]:
"""Retrieve the list of tables for an instrument."""
Expand All @@ -887,7 +900,7 @@ def list_table(


@external_router.get("/schema/{instrument}/<table>")
def schema(instrument: Annotated[str, Depends(validate_instrument_name)], table: str) -> dict[str, list[str]]:
async def schema(instrument: Annotated[str, Depends(validate_instrument_name)], table: str) -> dict[str, list[str]]:
"""Retrieve the descriptions of columns in a ConsDB table.
Parameters
Expand Down
42 changes: 22 additions & 20 deletions tests/test_pqserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import tempfile
from pathlib import Path

from fastapi.testclient import TestClient

import pytest
from requests import Response


def _assert_http_status(response: Response, status: int):
assert response.status_code == status, f"{response.status_code} {response.json}"
assert response.status_code == status, f"{response.status_code} {response.json()}"


@pytest.fixture
Expand Down Expand Up @@ -60,20 +62,20 @@ def app(db, scope="module"):
@pytest.fixture
def client(app, scope="module"):
# NOTE: all tests share the same client, app, and database.
return app.test_client()
return TestClient(app)


def test_root(client):
response = client.get("/")
result = response.json
result = response.json()
assert "instruments" in result
assert "obs_types" in result
assert "dtypes" in result


def test_root2(client):
response = client.get("/consdb")
result = response.json
result = response.json()
assert "instruments" in result
assert "latiss" in result["instruments"]
assert "obs_types" in result
Expand All @@ -88,7 +90,7 @@ def test_flexible_metadata(client):
json={"key": "foo", "dtype": "bool", "doc": "bool key"},
)
_assert_http_status(response, 200)
result = response.json
result = response.json()
assert result == {
"message": "Key added to flexible metadata",
"key": "foo",
Expand All @@ -101,7 +103,7 @@ def test_flexible_metadata(client):
json={"key": "bar", "dtype": "int", "doc": "int key"},
)
_assert_http_status(response, 200)
result = response.json
result = response.json()
assert result == {
"message": "Key added to flexible metadata",
"key": "bar",
Expand All @@ -114,15 +116,15 @@ def test_flexible_metadata(client):
json={"key": "baz", "dtype": "float", "doc": "float key"},
)
_assert_http_status(response, 200)
result = response.json
result = response.json()
assert result["obs_type"] == "Exposure"

response = client.post(
"/consdb/flex/bad_instrument/exposure/addkey",
json={"key": "quux", "dtype": "str", "doc": "str key"},
)
_assert_http_status(response, 404)
result = response.json
result = response.json()
assert result == {
"message": "Unknown instrument",
"value": "bad_instrument",
Expand All @@ -131,7 +133,7 @@ def test_flexible_metadata(client):

response = client.get("/consdb/flex/latiss/exposure/schema")
_assert_http_status(response, 200)
result = response.json
result = response.json()
assert "foo" in result
assert "bar" in result
assert "baz" in result
Expand All @@ -142,7 +144,7 @@ def test_flexible_metadata(client):
json={"values": {"foo": True, "bar": 42, "baz": 3.14159}},
)
_assert_http_status(response, 200)
result = response.json
result = response.json()
assert result["message"] == "Flexible metadata inserted"
assert result["obs_id"] == 2024032100002

Expand All @@ -151,49 +153,49 @@ def test_flexible_metadata(client):
json={"values": {"foo": True, "bar": 42, "baz": 3.14159}},
)
_assert_http_status(response, 500)
result = response.json
result = response.json()
assert "UNIQUE" in result["message"]

response = client.post(
"/consdb/flex/latiss/exposure/obs/2024032100002",
json={"values": {"bad_key": 2.71828}},
)
_assert_http_status(response, 404)
result = response.json
result = response.json()
assert result["message"] == "Unknown key"
assert result["value"] == "bad_key"

response = client.get("/consdb/flex/latiss/exposure/obs/2024032100002")
_assert_http_status(response, 200)
result = response.json
result = response.json()
assert result == {"foo": True, "bar": 42, "baz": 3.14159}

response = client.get("/consdb/flex/latiss/exposure/obs/2024032100002?k=bar&k=baz")
_assert_http_status(response, 200)
result = response.json
result = response.json()
assert result == {"bar": 42, "baz": 3.14159}

response = client.post(
"/consdb/flex/latiss/exposure/obs/2024032100002?u=1",
json={"values": {"foo": False, "bar": 34, "baz": 2.71828}},
)
_assert_http_status(response, 200)
result = response.json
result = response.json()
assert result["message"] == "Flexible metadata inserted"

response = client.get("/consdb/flex/latiss/exposure/obs/2024032100002")
_assert_http_status(response, 200)
result = response.json
result = response.json()
assert result == {"foo": False, "bar": 34, "baz": 2.71828}

response = client.get("/consdb/flex/latiss/exposure/obs/2024032100002?k=baz")
_assert_http_status(response, 200)
result = response.json
result = response.json()
assert result == {"baz": 2.71828}

response = client.post("/consdb/flex/latiss/exposure/obs/2024032100002", json={})
_assert_http_status(response, 404)
result = response.json
result = response.json()
assert "Invalid JSON" in result["message"]
assert result["required_keys"] == ["values"]

Expand All @@ -209,7 +211,7 @@ def test_flexible_metadata(client):
},
)
_assert_http_status(response, 200)
result = response.json
result = response.json()
assert result == {
"message": "Data inserted",
"table": "cdb_latiss.exposure",
Expand All @@ -219,7 +221,7 @@ def test_flexible_metadata(client):

response = client.post("/consdb/query", json={"query": "SELECT * FROM exposure ORDER BY day_obs;"})
_assert_http_status(response, 200)
result = response.json
result = response.json()
assert len(result) == 2
assert "exposure_id" in result["columns"]
assert 20240321 in result["data"][0]
Expand Down

0 comments on commit 512c64f

Please sign in to comment.