Skip to content

Commit

Permalink
Return a ServiceUnavailable Error when an SQL db is not reachable, and
Browse files Browse the repository at this point in the history
retry connecting
  • Loading branch information
chaen committed Oct 31, 2023
1 parent b726d2a commit daeb49e
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 21 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ dependencies:
- aiohttp
- aiomysql
- aiosqlite
- asyncache
- azure-core
- cachetools
########
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ install_requires =
aiohttp
aiomysql
aiosqlite
asyncache
azure-core
cachetools
m2crypto >=0.38.0
Expand Down
12 changes: 12 additions & 0 deletions src/diracx/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def __init__(self, status_code: int, data):

class DiracError(RuntimeError):
http_status_code = status.HTTP_400_BAD_REQUEST
http_headers: dict[str, str] | None = None

def __init__(self, detail: str = "Unknown"):
self.detail = detail
Expand Down Expand Up @@ -42,3 +43,14 @@ class JobNotFound(Exception):
def __init__(self, job_id: int):
self.job_id: int = job_id
super().__init__(f"Job {job_id} not found")


class RouteUnavailableError(DiracError):
    """The route is not available (bad init)."""

    # Reported to the client as "503 Service Unavailable"
    http_status_code = status.HTTP_503_SERVICE_UNAVAILABLE
    # Extra response header; a numeric Retry-After value is a delay in seconds
    http_headers = {"Retry-After": "10"}


class DBConnectionError(DiracError):
    """Used whenever we encounter a problem with the DB connection."""
28 changes: 24 additions & 4 deletions src/diracx/db/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@

from pydantic import parse_obj_as
from sqlalchemy import Column as RawColumn
from sqlalchemy import DateTime, Enum, MetaData
from sqlalchemy import DateTime, Enum, MetaData, select
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import expression

from diracx.core.exceptions import InvalidQueryError
from diracx.core.exceptions import DBConnectionError, InvalidQueryError
from diracx.core.extensions import select_from_extension
from diracx.core.settings import SqlalchemyDsn

Expand Down Expand Up @@ -146,7 +147,10 @@ async def engine_context(self) -> AsyncIterator[None]:
"""
assert self._engine is None, "engine_context cannot be nested"

engine = create_async_engine(self._db_url)
# Set the pool_recycle to 30mn
# That should prevent the problem of MySQL expiring connection
# after 60mn by default
engine = create_async_engine(self._db_url, pool_recycle=60 * 30)
self._engine = engine

yield
Expand All @@ -166,8 +170,12 @@ async def __aenter__(self):
This is called by the Dependency mechanism (see ``db_transaction``),
It will create a new connection/transaction for each route call.
"""
assert self._conn.get() is None, "BaseSQLDB context cannot be nested"
try:
self._conn.set(await self.engine.connect().__aenter__())
except Exception as e:
raise DBConnectionError("Cannot connect to DB") from e

self._conn.set(await self.engine.connect().__aenter__())
return self

async def __aexit__(self, exc_type, exc, tb):
Expand All @@ -181,6 +189,18 @@ async def __aexit__(self, exc_type, exc, tb):
await self._conn.get().__aexit__(exc_type, exc, tb)
self._conn.set(None)

async def ping(self) -> tuple[bool, str]:
    """
    Check whether the connection to the DB is still working.

    A trivial ``SELECT 1`` is issued on the current connection.
    We could enable SQLAlchemy's ``pool_pre_ping`` on the engine instead,
    but that check would then be run every time a connection is checked
    out of the pool.

    :return: ``(True, "")`` when the query succeeds, otherwise
        ``(False, <repr of the OperationalError>)``
    """
    try:
        await self.conn.scalar(select(1))
    except OperationalError as exc:
        return False, repr(exc)
    else:
        return True, ""


def apply_search_filters(table, stmt, search):
# Apply any filters
Expand Down
80 changes: 64 additions & 16 deletions src/diracx/routers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from typing import Any, AsyncContextManager, AsyncGenerator, Iterable, TypeVar

import dotenv
from asyncache import cached
from cachetools import TTLCache
from fastapi import APIRouter, Depends, Request
from fastapi.dependencies.models import Dependant
from fastapi.middleware.cors import CORSMiddleware
Expand All @@ -15,7 +17,12 @@
from pydantic import parse_raw_as

from diracx.core.config import ConfigSource
from diracx.core.exceptions import DiracError, DiracHttpResponse
from diracx.core.exceptions import (
DBConnectionError,
DiracError,
DiracHttpResponse,
RouteUnavailableError,
)
from diracx.core.extensions import select_from_extension
from diracx.core.utils import dotenv_files_from_environment
from diracx.db.os.utils import BaseOSDB
Expand All @@ -27,6 +34,8 @@

T = TypeVar("T")
T2 = TypeVar("T2", bound=AsyncContextManager)
T3 = TypeVar("T3", bound=BaseSQLDB)


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -59,21 +68,33 @@ def create_app_inner(
# Override the configuration source
app.dependency_overrides[ConfigSource.create] = config_source.read_config

fail_startup = True
# Add the SQL DBs to the application
available_sql_db_classes: set[type[BaseSQLDB]] = set()
for db_name, db_url in database_urls.items():
sql_db_classes = BaseSQLDB.available_implementations(db_name)
# The first DB is the highest priority one
sql_db = sql_db_classes[0](db_url=db_url)
app.lifetime_functions.append(sql_db.engine_context)
# Add overrides for all the DB classes, including those from extensions
# This means vanilla DiracX routers get an instance of the extension's DB
for sql_db_class in sql_db_classes:
assert sql_db_class.transaction not in app.dependency_overrides
available_sql_db_classes.add(sql_db_class)
app.dependency_overrides[sql_db_class.transaction] = partial(
db_transaction, sql_db
)
try:
sql_db_classes = BaseSQLDB.available_implementations(db_name)

# The first DB is the highest priority one
sql_db = sql_db_classes[0](db_url=db_url)

app.lifetime_functions.append(sql_db.engine_context)
# Add overrides for all the DB classes, including those from extensions
# This means vanilla DiracX routers get an instance of the extension's DB
for sql_db_class in sql_db_classes:
assert sql_db_class.transaction not in app.dependency_overrides
available_sql_db_classes.add(sql_db_class)
app.dependency_overrides[sql_db_class.transaction] = partial(
db_transaction, sql_db
)

# At least one DB works, so we do not fail the startup
fail_startup = False
except Exception:
logger.exception(f"Failed to initialize DB {db_name}, {db_url}")

if fail_startup:
raise Exception("No SQL database could be initialized, aborting")

# Add the OpenSearch DBs to the application
available_os_db_classes: set[type[BaseOSDB]] = set()
Expand Down Expand Up @@ -199,7 +220,9 @@ def create_app() -> DiracFastAPI:

def dirac_error_handler(request: Request, exc: DiracError) -> Response:
    """
    Turn a ``DiracError`` into a JSON HTTP response.

    :param request: the request being served (not used)
    :param exc: the error to report to the client
    :return: a JSON response using the status code, ``detail`` message and
        optional extra headers carried by the exception
    """
    return JSONResponse(
        status_code=exc.http_status_code,
        content={"detail": exc.detail},
        # ``http_headers`` may be None, in which case no extra header is passed
        headers=exc.http_headers,
    )


Expand All @@ -225,9 +248,34 @@ def find_dependents(
yield from find_dependents(dependency.dependencies, cls)


# Cache backing ``is_db_alive``: entries expire after 10 seconds (``ttl``)
# and at most 1024 of them are kept.
_db_alive_cache: TTLCache = TTLCache(maxsize=1024, ttl=10)


@cached(_db_alive_cache)
async def is_db_alive(db: T3):
    """
    Ping the DB, caching the result.

    :param db: the DB whose ``ping`` method is called
    :raises: DBConnectionError (with the reason given by ``ping``) if the
        DB does not answer
    """
    alive, why = await db.ping()
    logger.debug("Pinged db %s, is_alive %s, reason %s", type(db), alive, why)
    if alive:
        return
    raise DBConnectionError(why)


async def db_transaction(db: T2) -> AsyncGenerator[T2, None]:
    """
    Initiate a DB transaction.

    :param db: the DB object, used as an async context manager
    :raises: RouteUnavailableError in case the connection to the DB fails
    """

    # Entering the context already triggers a connection to the DB
    # that may fail, hence the try/except
    try:
        async with db:
            # Check whether the connection still works before executing the query
            await is_db_alive(db)
            yield db
    except DBConnectionError as e:
        # Reported to the client as a 503 rather than a generic error
        raise RouteUnavailableError(repr(e)) from e


async def db_session(db: T2) -> AsyncGenerator[T2, None]:
Expand Down
10 changes: 9 additions & 1 deletion tests/db/test_dummyDB.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from diracx.core.exceptions import InvalidQueryError
from diracx.core.exceptions import DBConnectionError, InvalidQueryError
from diracx.db.sql.dummy.db import DummyDB

# Each DB test class must defined a fixture looking like this one
Expand Down Expand Up @@ -67,3 +67,11 @@ async def test_insert_and_summary(dummy_db: DummyDB):
}
],
)


async def test_bad_connection():
    """Entering the DB context must raise ``DBConnectionError`` when the DB is unreachable."""
    # NOTE(review): the URL below looks mangled by e-mail obfuscation in the
    # page this was copied from -- confirm the intended host.
    dummy_db = DummyDB("mysql+aiomysql://tata:[email protected]:3306/name")
    async with dummy_db.engine_context():
        with pytest.raises(DBConnectionError):
            async with dummy_db:
                # ``ping`` is a coroutine function: without ``await`` it would
                # never actually run
                await dummy_db.ping()
12 changes: 12 additions & 0 deletions tests/routers/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,15 @@ def test_installation_metadata(test_client):

assert r.status_code == 200
assert r.json()


def test_unavailable_db(monkeypatch, test_client):
    # TODO: this does not work yet because ``test_client`` is already
    # initialized by the time the environment variable is overridden below;
    # presumably the new DB URL is never picked up. CI currently reports a
    # 404 instead of the expected 503.
    monkeypatch.setenv(
        "DIRACX_DB_URL_JOBDB", "mysql+aiomysql://tata:[email protected]:3306/name"
    )

    r = test_client.get("/api/job/123")
    assert r.status_code == 503

Check failure on line 28 in tests/routers/test_generic.py

View workflow job for this annotation

GitHub Actions / pytest

test_unavailable_db assert 404 == 503 + where 404 = <Response [404 Not Found]>.status_code

Check failure on line 28 in tests/routers/test_generic.py

View workflow job for this annotation

GitHub Actions / pytest-integration

test_unavailable_db assert 404 == 503 + where 404 = <Response [404 Not Found]>.status_code
assert r.json()

0 comments on commit daeb49e

Please sign in to comment.