Skip to content

Commit

Permalink
Prevent resource warnings from pymongo.
Browse files Browse the repository at this point in the history
To prevent resource warnings from pymongo when running tests, ensure the Mongo client connection is closed by wrapping it in a context manager.
  • Loading branch information
fniessink committed Sep 19, 2024
1 parent 98c2f54 commit ed427d0
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 41 deletions.
13 changes: 6 additions & 7 deletions components/api_server/src/initialization/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,18 @@
import pymongo
from pymongo.database import Database

from shared.initialization.database import client
from shared.initialization.database import get_database

from .migrations import perform_migrations
from .datamodel import import_datamodel
from .report import import_example_reports, initialize_reports_overview
from .secrets import initialize_secrets


def init_database() -> Database: # pragma: no feature-test-cover
def init_database(client: pymongo.MongoClient) -> Database: # pragma: no feature-test-cover
"""Initialize the database contents."""
db_client = client()
set_feature_compatibility_version(db_client.admin)
database = db_client.quality_time_db
set_feature_compatibility_version(client)
database = get_database(client)
logging.info("Connected to database: %s", database)
nr_reports = database.reports.count_documents({})
nr_measurements = database.measurements.count_documents({})
Expand All @@ -33,12 +32,12 @@ def init_database() -> Database: # pragma: no feature-test-cover
return database


def set_feature_compatibility_version(admin_database: Database) -> None:
def set_feature_compatibility_version(client: pymongo.MongoClient) -> None:
"""Set the feature compatibility version to the current MongoDB version to prepare for upgrade to the next version.
See https://docs.mongodb.com/manual/reference/command/setFeatureCompatibilityVersion/
"""
admin_database.command("setFeatureCompatibilityVersion", "7.0", confirm=True)
get_database(client, "admin").command("setFeatureCompatibilityVersion", "7.0", confirm=True)


def create_indexes(database: Database) -> None:
Expand Down
11 changes: 7 additions & 4 deletions components/api_server/src/quality_time_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import bottle

from shared.initialization.database import mongo_client

from initialization.database import init_database
from initialization.bottle import init_bottle

Expand All @@ -18,10 +20,11 @@ def serve() -> None: # pragma: no feature-test-cover
log_level = str(os.getenv("API_SERVER_LOG_LEVEL", "WARNING"))
logger = logging.getLogger()
logger.setLevel(log_level)
database = init_database()
init_bottle(database)
server_port = int(os.getenv("API_SERVER_PORT", "5001"))
bottle.run(server="gevent", host="0.0.0.0", port=server_port, reloader=True, log=logger) # nosec, # noqa: S104
with mongo_client() as client:
database = init_database(client)
init_bottle(database)
server_port = int(os.getenv("API_SERVER_PORT", "5001"))
bottle.run(server="gevent", host="0.0.0.0", port=server_port, reloader=True, log=logger) # nosec, # noqa: S104


if __name__ == "__main__": # pragma: no feature-test-cover, pragma: no cover
Expand Down
5 changes: 2 additions & 3 deletions components/api_server/tests/initialization/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def setUp(self):
self.database.sessions.find_one.return_value = {"user": "jodoe"}
self.database.measurements.count_documents.return_value = 0
self.database.measurements.index_information.return_value = {}
self.mongo_client().quality_time_db = self.database
self.mongo_client.get_database.return_value = self.database

def init_database(self, data_model_json: str, assert_glob_called: bool = True) -> None:
"""Initialize the database."""
Expand All @@ -35,9 +35,8 @@ def init_database(self, data_model_json: str, assert_glob_called: bool = True) -
"open",
mock_open(read_data=data_model_json),
),
patch("pymongo.MongoClient", self.mongo_client),
):
init_database()
init_database(self.mongo_client)
if assert_glob_called:
glob_mock.assert_called()
else:
Expand Down
4 changes: 2 additions & 2 deletions components/collector/src/quality_time_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from typing import NoReturn

from shared.initialization.database import database_connection
from shared.initialization.database import get_database

# Make sure subclasses are registered
import metric_collectors # noqa: F401
Expand All @@ -15,7 +15,7 @@
async def collect() -> NoReturn:
"""Collect the measurements indefinitely."""
logging.getLogger().setLevel(config.LOG_LEVEL)
await Collector(database_connection()).start()
await Collector(get_database()).start()


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ async def test_fetch_with_post_error(self):
async def test_collect(self):
"""Test the collect method."""
with (
patch("quality_time_collector.database_connection", return_value=self.database),
patch("quality_time_collector.get_database", return_value=self.database),
self.assertRaises(RuntimeError),
):
await quality_time_collector.collect()
Expand Down
4 changes: 2 additions & 2 deletions components/notifier/src/quality_time_notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
from typing import NoReturn

from shared.initialization.database import database_connection
from shared.initialization.database import get_database

from notifier.notifier import notify

Expand All @@ -14,7 +14,7 @@ def start_notifications() -> NoReturn:
"""Notify indefinitely."""
logging.getLogger().setLevel(str(os.getenv("NOTIFIER_LOG_LEVEL", "WARNING")))
sleep_duration = int(os.getenv("NOTIFIER_SLEEP_DURATION", "60"))
asyncio.run(notify(database_connection(), sleep_duration))
asyncio.run(notify(get_database(), sleep_duration))


if __name__ == "__main__": # pragma: no cover
Expand Down
24 changes: 19 additions & 5 deletions components/shared_code/src/shared/initialization/database.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""Database initialization."""

import contextlib
import os
from collections.abc import Generator

import pymongo
from pymongo import database


def client() -> pymongo.MongoClient: # pragma: no feature-test-cover
@contextlib.contextmanager
def mongo_client() -> Generator[pymongo.MongoClient]: # pragma: no feature-test-cover
"""Return a pymongo client."""
database_url = os.environ.get("DATABASE_URL")
if not database_url:
Expand All @@ -16,9 +19,20 @@ def client() -> pymongo.MongoClient: # pragma: no feature-test-cover
db_port = os.environ.get("DATABASE_PORT", "27017")
database_url = f"mongodb://{db_user}:{db_pass}@{db_host}:{db_port}"

return pymongo.MongoClient(database_url)
client: pymongo.MongoClient = pymongo.MongoClient(database_url)
try:
yield client
finally:
client.close()


def database_connection() -> database.Database: # pragma: no feature-test-cover
"""Return a pymongo database."""
return client()["quality_time_db"]
def get_databases(*database_names: str) -> tuple[database.Database, ...]: # pragma: no feature-test-cover
"""Return multiple Mongo databases."""
with mongo_client() as client:
return tuple(client.get_database(database_name) for database_name in database_names)


def get_database(database_name: str = "quality_time_db") -> database.Database: # pragma: no feature-test-cover
"""Return one Mongo database."""
with mongo_client() as client:
return client.get_database(database_name)
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,31 @@
import unittest
from unittest.mock import Mock, patch

from shared.initialization.database import client
import pymongo

from shared.initialization.database import mongo_client

OS_ENVIRON_GET = "shared.initialization.database.os.environ.get"


class TestConnectionParams(unittest.TestCase):
"""Test the database connection parameters."""

def _assert_dbclient_host_url(self, dbclient, expected_url) -> None:
def assert_mongo_client_host_url(self, client: pymongo.MongoClient, expected_url: str) -> None:
"""Assert that the dbclient was initialized with expected url."""
self.assertEqual(expected_url, dbclient._init_kwargs["host"]) # noqa: SLF001
self.assertEqual(expected_url, client._init_kwargs["host"]) # noqa: SLF001

def test_default(self):
"""Test the default url."""
db = client()
_default_user_pass = "root:root" # nosec # noqa: S105
self._assert_dbclient_host_url(db, f"mongodb://{_default_user_pass}@localhost:27017")
db.close()
with mongo_client() as client:
self.assert_mongo_client_host_url(client, f"mongodb://{_default_user_pass}@localhost:27017")

def test_full_url_override(self):
"""Test setting full url with env var override."""
local_url = "mongodb://localhost"
with patch("shared.initialization.database.os.environ.get", Mock(return_value=local_url)):
db = client()
self._assert_dbclient_host_url(db, local_url)
db.close()
with patch(OS_ENVIRON_GET, Mock(return_value=local_url)), mongo_client() as client:
self.assert_mongo_client_host_url(client, local_url)

def test_partial_url_override(self):
"""Test setting partial url with env var overrides."""
Expand All @@ -41,7 +42,5 @@ def _os_environ_get(variable_name, default=None): # noqa: ANN202
}
return values.get(variable_name, default)

with patch("shared.initialization.database.os.environ.get", Mock(side_effect=_os_environ_get)):
db = client()
self._assert_dbclient_host_url(db, "mongodb://user:pass@host:4242")
db.close()
with patch(OS_ENVIRON_GET, Mock(side_effect=_os_environ_get)), mongo_client() as client:
self.assert_mongo_client_host_url(client, "mongodb://user:pass@host:4242")
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import mongomock

from shared.initialization.database import database_connection
from shared.initialization.database import get_database, get_databases

from tests.shared.base import DataModelTestCase

Expand All @@ -13,7 +13,13 @@ class DatabaseInitTest(DataModelTestCase):
"""Unit tests for database initialization."""

@patch("shared.initialization.database.pymongo", return_value=mongomock)
def test_client(self, client: Mock) -> None:
def test_get_database(self, client: Mock) -> None:
"""Test that the client is called."""
database_connection()
get_database()
client.MongoClient.assert_called_once()

@patch("shared.initialization.database.pymongo", return_value=mongomock)
def test_get_databases(self, client: Mock) -> None:
"""Test that the client is called."""
get_databases("database 1", "database 2")
client.MongoClient.assert_called_once()

0 comments on commit ed427d0

Please sign in to comment.