From ed427d0bd46a44d9064a9ffb49b5f3dd849521db Mon Sep 17 00:00:00 2001 From: Frank Niessink Date: Thu, 19 Sep 2024 14:03:07 +0200 Subject: [PATCH] Prevent resource warnings from pymongo. To prevent resource warnings from pymongo when running tests, ensure the Mongo client connection is closed by wrapping it in a context manager. --- .../api_server/src/initialization/database.py | 13 +++++---- .../api_server/src/quality_time_server.py | 11 +++++--- .../tests/initialization/test_database.py | 5 ++-- .../collector/src/quality_time_collector.py | 4 +-- .../tests/base_collectors/test_collector.py | 2 +- .../notifier/src/quality_time_notifier.py | 4 +-- .../src/shared/initialization/database.py | 24 +++++++++++++---- .../shared/database/test_connection_params.py | 27 +++++++++---------- .../shared/initialization/test_database.py | 12 ++++++--- 9 files changed, 61 insertions(+), 41 deletions(-) diff --git a/components/api_server/src/initialization/database.py b/components/api_server/src/initialization/database.py index 5af9576585..a4f49d1f2e 100644 --- a/components/api_server/src/initialization/database.py +++ b/components/api_server/src/initialization/database.py @@ -6,7 +6,7 @@ import pymongo from pymongo.database import Database -from shared.initialization.database import client +from shared.initialization.database import get_database from .migrations import perform_migrations from .datamodel import import_datamodel @@ -14,11 +14,10 @@ from .secrets import initialize_secrets -def init_database() -> Database: # pragma: no feature-test-cover +def init_database(client: pymongo.MongoClient) -> Database: # pragma: no feature-test-cover """Initialize the database contents.""" - db_client = client() - set_feature_compatibility_version(db_client.admin) - database = db_client.quality_time_db + set_feature_compatibility_version(client) + database = get_database(client) logging.info("Connected to database: %s", database) nr_reports = database.reports.count_documents({}) nr_measurements = database.measurements.count_documents({}) @@ -33,12 +32,12 @@ def init_database() -> Database: # pragma: no feature-test-cover return database -def set_feature_compatibility_version(admin_database: Database) -> None: +def set_feature_compatibility_version(client: pymongo.MongoClient) -> None: """Set the feature compatibility version to the current MongoDB version to prepare for upgrade to the next version. See https://docs.mongodb.com/manual/reference/command/setFeatureCompatibilityVersion/ """ - admin_database.command("setFeatureCompatibilityVersion", "7.0", confirm=True) + get_database(client, "admin").command("setFeatureCompatibilityVersion", "7.0", confirm=True) def create_indexes(database: Database) -> None: diff --git a/components/api_server/src/quality_time_server.py b/components/api_server/src/quality_time_server.py index fc0ccfa122..7bdc089bab 100644 --- a/components/api_server/src/quality_time_server.py +++ b/components/api_server/src/quality_time_server.py @@ -9,6 +9,8 @@ import bottle +from shared.initialization.database import mongo_client + from initialization.database import init_database from initialization.bottle import init_bottle @@ -18,10 +20,11 @@ def serve() -> None: # pragma: no feature-test-cover log_level = str(os.getenv("API_SERVER_LOG_LEVEL", "WARNING")) logger = logging.getLogger() logger.setLevel(log_level) - database = init_database() - init_bottle(database) - server_port = int(os.getenv("API_SERVER_PORT", "5001")) - bottle.run(server="gevent", host="0.0.0.0", port=server_port, reloader=True, log=logger) # nosec, # noqa: S104 + with mongo_client() as client: + database = init_database(client) + init_bottle(database) + server_port = int(os.getenv("API_SERVER_PORT", "5001")) + bottle.run(server="gevent", host="0.0.0.0", port=server_port, reloader=True, log=logger) # nosec, # noqa: S104 if __name__ == "__main__": # pragma: no feature-test-cover, pragma: no cover diff --git a/components/api_server/tests/initialization/test_database.py b/components/api_server/tests/initialization/test_database.py index 84f296b194..2b2ed2bae5 100644 --- a/components/api_server/tests/initialization/test_database.py +++ b/components/api_server/tests/initialization/test_database.py @@ -24,7 +24,7 @@ def setUp(self): self.database.sessions.find_one.return_value = {"user": "jodoe"} self.database.measurements.count_documents.return_value = 0 self.database.measurements.index_information.return_value = {} - self.mongo_client().quality_time_db = self.database + self.mongo_client.get_database.return_value = self.database def init_database(self, data_model_json: str, assert_glob_called: bool = True) -> None: """Initialize the database.""" @@ -35,9 +35,8 @@ def init_database(self, data_model_json: str, assert_glob_called: bool = True) - "open", mock_open(read_data=data_model_json), ), - patch("pymongo.MongoClient", self.mongo_client), ): - init_database() + init_database(self.mongo_client) if assert_glob_called: glob_mock.assert_called() else: diff --git a/components/collector/src/quality_time_collector.py b/components/collector/src/quality_time_collector.py index de20dc5222..f202dd1f4f 100644 --- a/components/collector/src/quality_time_collector.py +++ b/components/collector/src/quality_time_collector.py @@ -4,7 +4,7 @@ import logging from typing import NoReturn -from shared.initialization.database import database_connection +from shared.initialization.database import get_database # Make sure subclasses are registered import metric_collectors # noqa: F401 @@ -15,7 +15,7 @@ async def collect() -> NoReturn: """Collect the measurements indefinitely.""" logging.getLogger().setLevel(config.LOG_LEVEL) - await Collector(database_connection()).start() + await Collector(get_database()).start() if __name__ == "__main__": diff --git a/components/collector/tests/base_collectors/test_collector.py b/components/collector/tests/base_collectors/test_collector.py index a61a6d1aa6..bf70785bf1 100644 --- a/components/collector/tests/base_collectors/test_collector.py +++ b/components/collector/tests/base_collectors/test_collector.py @@ -178,7 +178,7 @@ async def test_fetch_with_post_error(self): async def test_collect(self): """Test the collect method.""" with ( - patch("quality_time_collector.database_connection", return_value=self.database), + patch("quality_time_collector.get_database", return_value=self.database), self.assertRaises(RuntimeError), ): await quality_time_collector.collect() diff --git a/components/notifier/src/quality_time_notifier.py b/components/notifier/src/quality_time_notifier.py index 2d3c42799b..4758ec2337 100644 --- a/components/notifier/src/quality_time_notifier.py +++ b/components/notifier/src/quality_time_notifier.py @@ -5,7 +5,7 @@ import os from typing import NoReturn -from shared.initialization.database import database_connection +from shared.initialization.database import get_database from notifier.notifier import notify @@ -14,7 +14,7 @@ def start_notifications() -> NoReturn: """Notify indefinitely.""" logging.getLogger().setLevel(str(os.getenv("NOTIFIER_LOG_LEVEL", "WARNING"))) sleep_duration = int(os.getenv("NOTIFIER_SLEEP_DURATION", "60")) - asyncio.run(notify(database_connection(), sleep_duration)) + asyncio.run(notify(get_database(), sleep_duration)) if __name__ == "__main__": # pragma: no cover diff --git a/components/shared_code/src/shared/initialization/database.py b/components/shared_code/src/shared/initialization/database.py index f0ad3453a9..c7f9ea7a05 100644 --- a/components/shared_code/src/shared/initialization/database.py +++ b/components/shared_code/src/shared/initialization/database.py @@ -1,12 +1,15 @@ """Database initialization.""" +import contextlib import os +from collections.abc import Generator import pymongo from pymongo import database -def client() -> pymongo.MongoClient: # pragma: no feature-test-cover +@contextlib.contextmanager +def mongo_client() -> Generator[pymongo.MongoClient]: # pragma: no feature-test-cover """Return a pymongo client.""" database_url = os.environ.get("DATABASE_URL") if not database_url: @@ -16,9 +19,20 @@ def client() -> pymongo.MongoClient: # pragma: no feature-test-cover db_port = os.environ.get("DATABASE_PORT", "27017") database_url = f"mongodb://{db_user}:{db_pass}@{db_host}:{db_port}" - return pymongo.MongoClient(database_url) + client: pymongo.MongoClient = pymongo.MongoClient(database_url) + try: + yield client + finally: + client.close() -def database_connection() -> database.Database: # pragma: no feature-test-cover - """Return a pymongo database.""" - return client()["quality_time_db"] +def get_databases(*database_names: str) -> tuple[database.Database, ...]: # pragma: no feature-test-cover + """Return multiple Mongo databases.""" + with mongo_client() as client: + return tuple(client.get_database(database_name) for database_name in database_names) + + +def get_database(database_name: str = "quality_time_db") -> database.Database: # pragma: no feature-test-cover + """Return one Mongo database.""" + with mongo_client() as client: + return client.get_database(database_name) diff --git a/components/shared_code/tests/shared/database/test_connection_params.py b/components/shared_code/tests/shared/database/test_connection_params.py index 3b246194df..2a0e91d239 100644 --- a/components/shared_code/tests/shared/database/test_connection_params.py +++ b/components/shared_code/tests/shared/database/test_connection_params.py @@ -3,30 +3,31 @@ import unittest from unittest.mock import Mock, patch -from shared.initialization.database import client +import pymongo + +from shared.initialization.database import mongo_client + +OS_ENVIRON_GET = "shared.initialization.database.os.environ.get" class TestConnectionParams(unittest.TestCase): """Test the database connection parameters.""" - def _assert_dbclient_host_url(self, dbclient, expected_url) -> None: + def assert_mongo_client_host_url(self, client: pymongo.MongoClient, expected_url: str) -> None: """Assert that the dbclient was initialized with expected url.""" - self.assertEqual(expected_url, dbclient._init_kwargs["host"]) # noqa: SLF001 + self.assertEqual(expected_url, client._init_kwargs["host"]) # noqa: SLF001 def test_default(self): """Test the default url.""" - db = client() _default_user_pass = "root:root" # nosec # noqa: S105 - self._assert_dbclient_host_url(db, f"mongodb://{_default_user_pass}@localhost:27017") - db.close() + with mongo_client() as client: + self.assert_mongo_client_host_url(client, f"mongodb://{_default_user_pass}@localhost:27017") def test_full_url_override(self): """Test setting full url with env var override.""" local_url = "mongodb://localhost" - with patch("shared.initialization.database.os.environ.get", Mock(return_value=local_url)): - db = client() - self._assert_dbclient_host_url(db, local_url) - db.close() + with patch(OS_ENVIRON_GET, Mock(return_value=local_url)), mongo_client() as client: + self.assert_mongo_client_host_url(client, local_url) def test_partial_url_override(self): """Test setting partial url with env var overrides.""" @@ -41,7 +42,5 @@ def _os_environ_get(variable_name, default=None): # noqa: ANN202 } return values.get(variable_name, default) - with patch("shared.initialization.database.os.environ.get", Mock(side_effect=_os_environ_get)): - db = client() - self._assert_dbclient_host_url(db, "mongodb://user:pass@host:4242") - db.close() + with patch(OS_ENVIRON_GET, Mock(side_effect=_os_environ_get)), mongo_client() as client: + self.assert_mongo_client_host_url(client, "mongodb://user:pass@host:4242") diff --git a/components/shared_code/tests/shared/initialization/test_database.py b/components/shared_code/tests/shared/initialization/test_database.py index 72ab2ce9bc..e4e347fb17 100644 --- a/components/shared_code/tests/shared/initialization/test_database.py +++ b/components/shared_code/tests/shared/initialization/test_database.py @@ -4,7 +4,7 @@ import mongomock -from shared.initialization.database import database_connection +from shared.initialization.database import get_database, get_databases from tests.shared.base import DataModelTestCase @@ -13,7 +13,13 @@ class DatabaseInitTest(DataModelTestCase): """Unit tests for database initialization.""" @patch("shared.initialization.database.pymongo", return_value=mongomock) - def test_client(self, client: Mock) -> None: + def test_get_database(self, client: Mock) -> None: """Test that the client is called.""" - database_connection() + get_database() + client.MongoClient.assert_called_once() + + @patch("shared.initialization.database.pymongo", return_value=mongomock) + def test_get_databases(self, client: Mock) -> None: + """Test that the client is called.""" + get_databases("database 1", "database 2") client.MongoClient.assert_called_once()