diff --git a/components/api_server/src/initialization/database.py b/components/api_server/src/initialization/database.py index 5af9576585..a0e64c12d5 100644 --- a/components/api_server/src/initialization/database.py +++ b/components/api_server/src/initialization/database.py @@ -6,19 +6,14 @@ import pymongo from pymongo.database import Database -from shared.initialization.database import client - from .migrations import perform_migrations from .datamodel import import_datamodel from .report import import_example_reports, initialize_reports_overview from .secrets import initialize_secrets -def init_database() -> Database: # pragma: no feature-test-cover +def init_database(database: Database) -> None: # pragma: no feature-test-cover """Initialize the database contents.""" - db_client = client() - set_feature_compatibility_version(db_client.admin) - database = db_client.quality_time_db logging.info("Connected to database: %s", database) nr_reports = database.reports.count_documents({}) nr_measurements = database.measurements.count_documents({}) @@ -30,7 +25,6 @@ def init_database() -> Database: # pragma: no feature-test-cover perform_migrations(database) if os.environ.get("LOAD_EXAMPLE_REPORTS", "True").lower() == "true": import_example_reports(database) - return database def set_feature_compatibility_version(admin_database: Database) -> None: diff --git a/components/api_server/src/quality_time_server.py b/components/api_server/src/quality_time_server.py index fc0ccfa122..10a66aea0b 100644 --- a/components/api_server/src/quality_time_server.py +++ b/components/api_server/src/quality_time_server.py @@ -7,9 +7,11 @@ import logging import os -import bottle +from bottle import run -from initialization.database import init_database +from shared.initialization.database import get_database, mongo_client + +from initialization.database import init_database, set_feature_compatibility_version from initialization.bottle import init_bottle @@ -18,10 +20,14 @@ def serve() -> None: # pragma: no feature-test-cover log_level = str(os.getenv("API_SERVER_LOG_LEVEL", "WARNING")) logger = logging.getLogger() logger.setLevel(log_level) - database = init_database() - init_bottle(database) - server_port = int(os.getenv("API_SERVER_PORT", "5001")) - bottle.run(server="gevent", host="0.0.0.0", port=server_port, reloader=True, log=logger) # nosec, # noqa: S104 + with mongo_client() as client: + admin_database = get_database(client, "admin") + set_feature_compatibility_version(admin_database) + database = get_database(client) + init_database(database) + init_bottle(database) + server_port = int(os.getenv("API_SERVER_PORT", "5001")) + run(server="gevent", host="0.0.0.0", port=server_port, reloader=True, log=logger) # nosec, # noqa: S104 if __name__ == "__main__": # pragma: no feature-test-cover, pragma: no cover diff --git a/components/api_server/tests/initialization/test_database.py b/components/api_server/tests/initialization/test_database.py index 84f296b194..c6f3e87fb7 100644 --- a/components/api_server/tests/initialization/test_database.py +++ b/components/api_server/tests/initialization/test_database.py @@ -14,7 +14,6 @@ class DatabaseInitTest(DataModelTestCase): def setUp(self): """Extend to set up the Mongo client and database contents.""" super().setUp() - self.mongo_client = Mock() self.database.reports.find.return_value = [] self.database.reports.distinct.return_value = [] self.database.datamodels.find_one.return_value = None @@ -24,7 +23,6 @@ def setUp(self): self.database.sessions.find_one.return_value = {"user": "jodoe"} self.database.measurements.count_documents.return_value = 0 self.database.measurements.index_information.return_value = {} - self.mongo_client().quality_time_db = self.database def init_database(self, data_model_json: str, assert_glob_called: bool = True) -> None: """Initialize the database.""" @@ -35,9 +33,8 @@ def init_database(self, data_model_json: str, assert_glob_called: bool = True) - "open", mock_open(read_data=data_model_json), ), - patch("pymongo.MongoClient", self.mongo_client), ): - init_database() + init_database(self.database) if assert_glob_called: glob_mock.assert_called() else: diff --git a/components/api_server/tests/test_quality_time_server.py b/components/api_server/tests/test_quality_time_server.py index 7d02c83335..92c0440b16 100644 --- a/components/api_server/tests/test_quality_time_server.py +++ b/components/api_server/tests/test_quality_time_server.py @@ -2,29 +2,30 @@ import logging import unittest -from unittest.mock import patch, Mock +from unittest.mock import patch, MagicMock, Mock from quality_time_server import serve +@patch("quality_time_server.mongo_client", MagicMock()) @patch("quality_time_server.init_database", Mock()) @patch("bottle.install", Mock()) class APIServerTestCase(unittest.TestCase): """Unit tests for starting the API-server.""" - @patch("bottle.run") + @patch("quality_time_server.run") def test_start(self, mocked_run): """Test that the server is started.""" serve() mocked_run.assert_called_once() - @patch("bottle.run", Mock()) + @patch("quality_time_server.run", Mock()) def test_default_log_level(self): """Test the default logging level.""" serve() self.assertEqual("WARNING", logging.getLevelName(logging.getLogger().getEffectiveLevel())) - @patch("bottle.run", Mock()) + @patch("quality_time_server.run", Mock()) @patch( "os.getenv", Mock(side_effect=lambda key, default=None: "DEBUG" if key == "API_SERVER_LOG_LEVEL" else default), diff --git a/components/collector/src/quality_time_collector.py b/components/collector/src/quality_time_collector.py index de20dc5222..dcadcf2fa9 100644 --- a/components/collector/src/quality_time_collector.py +++ b/components/collector/src/quality_time_collector.py @@ -4,7 +4,7 @@ import logging from typing import NoReturn -from shared.initialization.database import database_connection +from shared.initialization.database import get_database, mongo_client # Make sure subclasses are registered import metric_collectors # noqa: F401 @@ -15,7 +15,8 @@ async def collect() -> NoReturn: """Collect the measurements indefinitely.""" logging.getLogger().setLevel(config.LOG_LEVEL) - await Collector(database_connection()).start() + with mongo_client() as client: + await Collector(get_database(client)).start() if __name__ == "__main__": diff --git a/components/collector/tests/base_collectors/test_collector.py b/components/collector/tests/base_collectors/test_collector.py index a61a6d1aa6..bf70785bf1 100644 --- a/components/collector/tests/base_collectors/test_collector.py +++ b/components/collector/tests/base_collectors/test_collector.py @@ -178,7 +178,7 @@ async def test_fetch_with_post_error(self): async def test_collect(self): """Test the collect method.""" with ( - patch("quality_time_collector.database_connection", return_value=self.database), + patch("quality_time_collector.get_database", return_value=self.database), self.assertRaises(RuntimeError), ): await quality_time_collector.collect() diff --git a/components/notifier/src/quality_time_notifier.py b/components/notifier/src/quality_time_notifier.py index 2d3c42799b..19f3d8cce0 100644 --- a/components/notifier/src/quality_time_notifier.py +++ b/components/notifier/src/quality_time_notifier.py @@ -5,7 +5,7 @@ import os from typing import NoReturn -from shared.initialization.database import database_connection +from shared.initialization.database import get_database, mongo_client from notifier.notifier import notify @@ -14,7 +14,8 @@ def start_notifications() -> NoReturn: """Notify indefinitely.""" logging.getLogger().setLevel(str(os.getenv("NOTIFIER_LOG_LEVEL", "WARNING"))) sleep_duration = int(os.getenv("NOTIFIER_SLEEP_DURATION", "60")) - asyncio.run(notify(database_connection(), sleep_duration)) + with mongo_client() as client: + asyncio.run(notify(get_database(client), sleep_duration)) if __name__ == "__main__": # pragma: no cover diff --git a/components/shared_code/src/shared/initialization/database.py b/components/shared_code/src/shared/initialization/database.py index f0ad3453a9..59d4dea3fa 100644 --- a/components/shared_code/src/shared/initialization/database.py +++ b/components/shared_code/src/shared/initialization/database.py @@ -1,12 +1,15 @@ """Database initialization.""" +import contextlib import os +from collections.abc import Generator import pymongo from pymongo import database -def client() -> pymongo.MongoClient: # pragma: no feature-test-cover +@contextlib.contextmanager +def mongo_client() -> Generator[pymongo.MongoClient]: # pragma: no feature-test-cover """Return a pymongo client.""" database_url = os.environ.get("DATABASE_URL") if not database_url: @@ -16,9 +19,13 @@ def client() -> pymongo.MongoClient: # pragma: no feature-test-cover db_port = os.environ.get("DATABASE_PORT", "27017") database_url = f"mongodb://{db_user}:{db_pass}@{db_host}:{db_port}" - return pymongo.MongoClient(database_url) + client: pymongo.MongoClient = pymongo.MongoClient(database_url) + try: + yield client + finally: + client.close() -def database_connection() -> database.Database: # pragma: no feature-test-cover - """Return a pymongo database.""" - return client()["quality_time_db"] +def get_database(client: pymongo.MongoClient, database_name: str = "quality_time_db") -> database.Database: + """Return one Mongo database.""" + return client.get_database(database_name) diff --git a/components/shared_code/tests/shared/database/test_connection_params.py b/components/shared_code/tests/shared/database/test_connection_params.py deleted file mode 100644 index 3b246194df..0000000000 --- a/components/shared_code/tests/shared/database/test_connection_params.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Test the database connection parameters.""" - -import unittest -from unittest.mock import Mock, patch - -from shared.initialization.database import client - - -class TestConnectionParams(unittest.TestCase): - """Test the database connection parameters.""" - - def _assert_dbclient_host_url(self, dbclient, expected_url) -> None: - """Assert that the dbclient was initialized with expected url.""" - self.assertEqual(expected_url, dbclient._init_kwargs["host"]) # noqa: SLF001 - - def test_default(self): - """Test the default url.""" - db = client() - _default_user_pass = "root:root" # nosec # noqa: S105 - self._assert_dbclient_host_url(db, f"mongodb://{_default_user_pass}@localhost:27017") - db.close() - - def test_full_url_override(self): - """Test setting full url with env var override.""" - local_url = "mongodb://localhost" - with patch("shared.initialization.database.os.environ.get", Mock(return_value=local_url)): - db = client() - self._assert_dbclient_host_url(db, local_url) - db.close() - - def test_partial_url_override(self): - """Test setting partial url with env var overrides.""" - - def _os_environ_get(variable_name, default=None): # noqa: ANN202 - """Mock method for os.environ.get calls in shared.initialization.database.""" - values = { - "DATABASE_USERNAME": "user", - "DATABASE_PASSWORD": "pass", - "DATABASE_HOST": "host", - "DATABASE_PORT": 4242, - } - return values.get(variable_name, default) - - with patch("shared.initialization.database.os.environ.get", Mock(side_effect=_os_environ_get)): - db = client() - self._assert_dbclient_host_url(db, "mongodb://user:pass@host:4242") - db.close() diff --git a/components/shared_code/tests/shared/initialization/test_database.py b/components/shared_code/tests/shared/initialization/test_database.py index 72ab2ce9bc..a3e8b9f74a 100644 --- a/components/shared_code/tests/shared/initialization/test_database.py +++ b/components/shared_code/tests/shared/initialization/test_database.py @@ -1,19 +1,58 @@ """Unit tests for the database initialization.""" +import unittest from unittest.mock import Mock, patch import mongomock +import pymongo -from shared.initialization.database import database_connection +from shared.initialization.database import get_database, mongo_client -from tests.shared.base import DataModelTestCase +OS_ENVIRON_GET = "shared.initialization.database.os.environ.get" -class DatabaseInitTest(DataModelTestCase): +class TestConnectionParams(unittest.TestCase): + """Test the database connection parameters.""" + + def _assert_dbclient_host_url(self, client: pymongo.MongoClient, expected_url: str) -> None: + """Assert that the dbclient was initialized with expected url.""" + self.assertEqual(expected_url, client._init_kwargs["host"]) # noqa: SLF001 + + def test_default(self): + """Test the default url.""" + with mongo_client() as client: + _default_user_pass = "root:root" # nosec # noqa: S105 + self._assert_dbclient_host_url(client, f"mongodb://{_default_user_pass}@localhost:27017") + + def test_full_url_override(self): + """Test setting full url with env var override.""" + local_url = "mongodb://localhost" + with patch(OS_ENVIRON_GET, Mock(return_value=local_url)), mongo_client() as client: + self._assert_dbclient_host_url(client, local_url) + + def test_partial_url_override(self): + """Test setting partial url with env var overrides.""" + + def _os_environ_get(variable_name, default=None): # noqa: ANN202 + """Mock method for os.environ.get calls in shared.initialization.database.""" + values = { + "DATABASE_USERNAME": "user", + "DATABASE_PASSWORD": "pass", + "DATABASE_HOST": "host", + "DATABASE_PORT": 4242, + } + return values.get(variable_name, default) + + with patch(OS_ENVIRON_GET, Mock(side_effect=_os_environ_get)), mongo_client() as client: + self._assert_dbclient_host_url(client, "mongodb://user:pass@host:4242") + + +class DatabaseInitTest(unittest.TestCase): """Unit tests for database initialization.""" @patch("shared.initialization.database.pymongo", return_value=mongomock) - def test_client(self, client: Mock) -> None: + def test_client(self, pymongo_mock: Mock) -> None: """Test that the client is called.""" - database_connection() - client.MongoClient.assert_called_once() + with mongo_client() as client: + get_database(client) + pymongo_mock.MongoClient.assert_called_once()