Skip to content

Commit

Permalink
Prevent resource warnings from pymongo.
Browse files Browse the repository at this point in the history
To prevent resource warnings from pymongo when running tests, ensure the Mongo client connection is closed by wrapping it in a context manager.
  • Loading branch information
fniessink committed Sep 19, 2024
1 parent fe098f4 commit 9c7c4e0
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 32 deletions.
7 changes: 3 additions & 4 deletions components/api_server/src/initialization/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pymongo
from pymongo.database import Database

from shared.initialization.database import client
from shared.initialization.database import get_databases

from .migrations import perform_migrations
from .datamodel import import_datamodel
Expand All @@ -16,9 +16,8 @@

def init_database() -> Database: # pragma: no feature-test-cover
"""Initialize the database contents."""
db_client = client()
set_feature_compatibility_version(db_client.admin)
database = db_client.quality_time_db
admin_database, database = get_databases("admin", "quality_time_db")
set_feature_compatibility_version(admin_database)
logging.info("Connected to database: %s", database)
nr_reports = database.reports.count_documents({})
nr_measurements = database.measurements.count_documents({})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def setUp(self):
self.database.sessions.find_one.return_value = {"user": "jodoe"}
self.database.measurements.count_documents.return_value = 0
self.database.measurements.index_information.return_value = {}
self.mongo_client().quality_time_db = self.database
self.mongo_client().get_database.return_value = self.database

def init_database(self, data_model_json: str, assert_glob_called: bool = True) -> None:
"""Initialize the database."""
Expand Down
4 changes: 2 additions & 2 deletions components/collector/src/quality_time_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from typing import NoReturn

from shared.initialization.database import database_connection
from shared.initialization.database import get_database

# Make sure subclasses are registered
import metric_collectors # noqa: F401
Expand All @@ -15,7 +15,7 @@
async def collect() -> NoReturn:
"""Collect the measurements indefinitely."""
logging.getLogger().setLevel(config.LOG_LEVEL)
await Collector(database_connection()).start()
await Collector(get_database()).start()


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ async def test_fetch_with_post_error(self):
async def test_collect(self):
"""Test the collect method."""
with (
patch("quality_time_collector.database_connection", return_value=self.database),
patch("quality_time_collector.get_database", return_value=self.database),
self.assertRaises(RuntimeError),
):
await quality_time_collector.collect()
Expand Down
4 changes: 2 additions & 2 deletions components/notifier/src/quality_time_notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
from typing import NoReturn

from shared.initialization.database import database_connection
from shared.initialization.database import get_database

from notifier.notifier import notify

Expand All @@ -14,7 +14,7 @@ def start_notifications() -> NoReturn:
"""Notify indefinitely."""
logging.getLogger().setLevel(str(os.getenv("NOTIFIER_LOG_LEVEL", "WARNING")))
sleep_duration = int(os.getenv("NOTIFIER_SLEEP_DURATION", "60"))
asyncio.run(notify(database_connection(), sleep_duration))
asyncio.run(notify(get_database(), sleep_duration))


if __name__ == "__main__": # pragma: no cover
Expand Down
24 changes: 19 additions & 5 deletions components/shared_code/src/shared/initialization/database.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""Database initialization."""

import contextlib
import os
from collections.abc import Generator

import pymongo
from pymongo import database


def client() -> pymongo.MongoClient: # pragma: no feature-test-cover
@contextlib.contextmanager
def mongo_client() -> Generator[pymongo.MongoClient]: # pragma: no feature-test-cover
"""Return a pymongo client."""
database_url = os.environ.get("DATABASE_URL")
if not database_url:
Expand All @@ -16,9 +19,20 @@ def client() -> pymongo.MongoClient: # pragma: no feature-test-cover
db_port = os.environ.get("DATABASE_PORT", "27017")
database_url = f"mongodb://{db_user}:{db_pass}@{db_host}:{db_port}"

return pymongo.MongoClient(database_url)
client: pymongo.MongoClient = pymongo.MongoClient(database_url)
try:
yield client
finally:
client.close()


def database_connection() -> database.Database: # pragma: no feature-test-cover
"""Return a pymongo database."""
return client()["quality_time_db"]
def get_databases(*database_names: str) -> tuple[database.Database, ...]: # pragma: no feature-test-cover
"""Return multiple Mongo databases."""
with mongo_client() as client:
return tuple(client.get_database(database_name) for database_name in database_names)


def get_database(database_name: str = "quality_time_db") -> database.Database: # pragma: no feature-test-cover
"""Return one Mongo database."""
with mongo_client() as client:
return client.get_database(database_name)
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,31 @@
import unittest
from unittest.mock import Mock, patch

from shared.initialization.database import client
import pymongo

from shared.initialization.database import mongo_client

OS_ENVIRON_GET = "shared.initialization.database.os.environ.get"


class TestConnectionParams(unittest.TestCase):
"""Test the database connection parameters."""

def _assert_dbclient_host_url(self, dbclient, expected_url) -> None:
def assert_mongo_client_host_url(self, client: pymongo.MongoClient, expected_url: str) -> None:
"""Assert that the dbclient was initialized with expected url."""
self.assertEqual(expected_url, dbclient._init_kwargs["host"]) # noqa: SLF001
self.assertEqual(expected_url, client._init_kwargs["host"]) # noqa: SLF001

def test_default(self):
"""Test the default url."""
db = client()
_default_user_pass = "root:root" # nosec # noqa: S105
self._assert_dbclient_host_url(db, f"mongodb://{_default_user_pass}@localhost:27017")
db.close()
with mongo_client() as client:
self.assert_mongo_client_host_url(client, f"mongodb://{_default_user_pass}@localhost:27017")

def test_full_url_override(self):
"""Test setting full url with env var override."""
local_url = "mongodb://localhost"
with patch("shared.initialization.database.os.environ.get", Mock(return_value=local_url)):
db = client()
self._assert_dbclient_host_url(db, local_url)
db.close()
with patch(OS_ENVIRON_GET, Mock(return_value=local_url)), mongo_client() as client:
self.assert_mongo_client_host_url(client, local_url)

def test_partial_url_override(self):
"""Test setting partial url with env var overrides."""
Expand All @@ -41,7 +42,5 @@ def _os_environ_get(variable_name, default=None): # noqa: ANN202
}
return values.get(variable_name, default)

with patch("shared.initialization.database.os.environ.get", Mock(side_effect=_os_environ_get)):
db = client()
self._assert_dbclient_host_url(db, "mongodb://user:pass@host:4242")
db.close()
with patch(OS_ENVIRON_GET, Mock(side_effect=_os_environ_get)), mongo_client() as client:
self.assert_mongo_client_host_url(client, "mongodb://user:pass@host:4242")
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import mongomock

from shared.initialization.database import database_connection
from shared.initialization.database import get_database, get_databases

from tests.shared.base import DataModelTestCase

Expand All @@ -13,7 +13,13 @@ class DatabaseInitTest(DataModelTestCase):
"""Unit tests for database initialization."""

@patch("shared.initialization.database.pymongo", return_value=mongomock)
def test_client(self, client: Mock) -> None:
def test_get_database(self, client: Mock) -> None:
"""Test that the client is called."""
database_connection()
get_database()
client.MongoClient.assert_called_once()

@patch("shared.initialization.database.pymongo", return_value=mongomock)
def test_get_databases(self, client: Mock) -> None:
"""Test that the client is called."""
get_databases("database 1", "database 2")
client.MongoClient.assert_called_once()

0 comments on commit 9c7c4e0

Please sign in to comment.