Skip to content

Commit

Permalink
Prevent resource warnings from pymongo
Browse files Browse the repository at this point in the history
To prevent resource warnings from pymongo when running tests, ensure the Mongo client connection is closed by wrapping it in a context manager.
  • Loading branch information
fniessink committed Sep 20, 2024
1 parent 2e9db0a commit bf88ebc
Show file tree
Hide file tree
Showing 10 changed files with 83 additions and 84 deletions.
8 changes: 1 addition & 7 deletions components/api_server/src/initialization/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,14 @@
import pymongo
from pymongo.database import Database

from shared.initialization.database import client

from .migrations import perform_migrations
from .datamodel import import_datamodel
from .report import import_example_reports, initialize_reports_overview
from .secrets import initialize_secrets


def init_database() -> Database: # pragma: no feature-test-cover
def init_database(database: Database) -> None: # pragma: no feature-test-cover
"""Initialize the database contents."""
db_client = client()
set_feature_compatibility_version(db_client.admin)
database = db_client.quality_time_db
logging.info("Connected to database: %s", database)
nr_reports = database.reports.count_documents({})
nr_measurements = database.measurements.count_documents({})
Expand All @@ -30,7 +25,6 @@ def init_database() -> Database: # pragma: no feature-test-cover
perform_migrations(database)
if os.environ.get("LOAD_EXAMPLE_REPORTS", "True").lower() == "true":
import_example_reports(database)
return database


def set_feature_compatibility_version(admin_database: Database) -> None:
Expand Down
18 changes: 12 additions & 6 deletions components/api_server/src/quality_time_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import logging
import os

import bottle
from bottle import run

from initialization.database import init_database
from shared.initialization.database import get_database, mongo_client

from initialization.database import init_database, set_feature_compatibility_version
from initialization.bottle import init_bottle


Expand All @@ -18,10 +20,14 @@ def serve() -> None: # pragma: no feature-test-cover
log_level = str(os.getenv("API_SERVER_LOG_LEVEL", "WARNING"))
logger = logging.getLogger()
logger.setLevel(log_level)
database = init_database()
init_bottle(database)
server_port = int(os.getenv("API_SERVER_PORT", "5001"))
bottle.run(server="gevent", host="0.0.0.0", port=server_port, reloader=True, log=logger) # nosec, # noqa: S104
with mongo_client() as client:
admin_database = get_database(client, "admin")
set_feature_compatibility_version(admin_database)
database = get_database(client)
init_database(database)
init_bottle(database)
server_port = int(os.getenv("API_SERVER_PORT", "5001"))
run(server="gevent", host="0.0.0.0", port=server_port, reloader=True, log=logger) # nosec, # noqa: S104


if __name__ == "__main__": # pragma: no feature-test-cover, pragma: no cover
Expand Down
5 changes: 1 addition & 4 deletions components/api_server/tests/initialization/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ class DatabaseInitTest(DataModelTestCase):
def setUp(self):
"""Extend to set up the Mongo client and database contents."""
super().setUp()
self.mongo_client = Mock()
self.database.reports.find.return_value = []
self.database.reports.distinct.return_value = []
self.database.datamodels.find_one.return_value = None
Expand All @@ -24,7 +23,6 @@ def setUp(self):
self.database.sessions.find_one.return_value = {"user": "jodoe"}
self.database.measurements.count_documents.return_value = 0
self.database.measurements.index_information.return_value = {}
self.mongo_client().quality_time_db = self.database

def init_database(self, data_model_json: str, assert_glob_called: bool = True) -> None:
"""Initialize the database."""
Expand All @@ -35,9 +33,8 @@ def init_database(self, data_model_json: str, assert_glob_called: bool = True) -
"open",
mock_open(read_data=data_model_json),
),
patch("pymongo.MongoClient", self.mongo_client),
):
init_database()
init_database(self.database)
if assert_glob_called:
glob_mock.assert_called()
else:
Expand Down
9 changes: 5 additions & 4 deletions components/api_server/tests/test_quality_time_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,30 @@

import logging
import unittest
from unittest.mock import patch, Mock
from unittest.mock import patch, MagicMock, Mock

from quality_time_server import serve


@patch("quality_time_server.mongo_client", MagicMock())
@patch("quality_time_server.init_database", Mock())
@patch("bottle.install", Mock())
class APIServerTestCase(unittest.TestCase):
"""Unit tests for starting the API-server."""

@patch("bottle.run")
@patch("quality_time_server.run")
def test_start(self, mocked_run):
"""Test that the server is started."""
serve()
mocked_run.assert_called_once()

@patch("bottle.run", Mock())
@patch("quality_time_server.run", Mock())
def test_default_log_level(self):
"""Test the default logging level."""
serve()
self.assertEqual("WARNING", logging.getLevelName(logging.getLogger().getEffectiveLevel()))

@patch("bottle.run", Mock())
@patch("quality_time_server.run", Mock())
@patch(
"os.getenv",
Mock(side_effect=lambda key, default=None: "DEBUG" if key == "API_SERVER_LOG_LEVEL" else default),
Expand Down
5 changes: 3 additions & 2 deletions components/collector/src/quality_time_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from typing import NoReturn

from shared.initialization.database import database_connection
from shared.initialization.database import get_database, mongo_client

# Make sure subclasses are registered
import metric_collectors # noqa: F401
Expand All @@ -15,7 +15,8 @@
async def collect() -> NoReturn:
"""Collect the measurements indefinitely."""
logging.getLogger().setLevel(config.LOG_LEVEL)
await Collector(database_connection()).start()
with mongo_client() as client:
await Collector(get_database(client)).start()


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ async def test_fetch_with_post_error(self):
async def test_collect(self):
"""Test the collect method."""
with (
patch("quality_time_collector.database_connection", return_value=self.database),
patch("quality_time_collector.get_database", return_value=self.database),
self.assertRaises(RuntimeError),
):
await quality_time_collector.collect()
Expand Down
5 changes: 3 additions & 2 deletions components/notifier/src/quality_time_notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
from typing import NoReturn

from shared.initialization.database import database_connection
from shared.initialization.database import get_database, mongo_client

from notifier.notifier import notify

Expand All @@ -14,7 +14,8 @@ def start_notifications() -> NoReturn:
"""Notify indefinitely."""
logging.getLogger().setLevel(str(os.getenv("NOTIFIER_LOG_LEVEL", "WARNING")))
sleep_duration = int(os.getenv("NOTIFIER_SLEEP_DURATION", "60"))
asyncio.run(notify(database_connection(), sleep_duration))
with mongo_client() as client:
asyncio.run(notify(get_database(client), sleep_duration))


if __name__ == "__main__": # pragma: no cover
Expand Down
17 changes: 12 additions & 5 deletions components/shared_code/src/shared/initialization/database.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""Database initialization."""

import contextlib
import os
from collections.abc import Generator

import pymongo
from pymongo import database


def client() -> pymongo.MongoClient: # pragma: no feature-test-cover
@contextlib.contextmanager
def mongo_client() -> Generator[pymongo.MongoClient]: # pragma: no feature-test-cover
"""Return a pymongo client."""
database_url = os.environ.get("DATABASE_URL")
if not database_url:
Expand All @@ -16,9 +19,13 @@ def client() -> pymongo.MongoClient: # pragma: no feature-test-cover
db_port = os.environ.get("DATABASE_PORT", "27017")
database_url = f"mongodb://{db_user}:{db_pass}@{db_host}:{db_port}"

return pymongo.MongoClient(database_url)
client: pymongo.MongoClient = pymongo.MongoClient(database_url)
try:
yield client
finally:
client.close()


def database_connection() -> database.Database: # pragma: no feature-test-cover
"""Return a pymongo database."""
return client()["quality_time_db"]
def get_database(client: pymongo.MongoClient, database_name: str = "quality_time_db") -> database.Database:
"""Return one Mongo database."""
return client.get_database(database_name)

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,19 +1,58 @@
"""Unit tests for the database initialization."""

import unittest
from unittest.mock import Mock, patch

import mongomock
import pymongo

from shared.initialization.database import database_connection
from shared.initialization.database import get_database, mongo_client

from tests.shared.base import DataModelTestCase
OS_ENVIRON_GET = "shared.initialization.database.os.environ.get"


class DatabaseInitTest(DataModelTestCase):
class TestConnectionParams(unittest.TestCase):
"""Test the database connection parameters."""

def _assert_dbclient_host_url(self, client: pymongo.MongoClient, expected_url: str) -> None:
"""Assert that the dbclient was initialized with expected url."""
self.assertEqual(expected_url, client._init_kwargs["host"]) # noqa: SLF001

def test_default(self):
"""Test the default url."""
with mongo_client() as client:
_default_user_pass = "root:root" # nosec # noqa: S105
self._assert_dbclient_host_url(client, f"mongodb://{_default_user_pass}@localhost:27017")

def test_full_url_override(self):
"""Test setting full url with env var override."""
local_url = "mongodb://localhost"
with patch(OS_ENVIRON_GET, Mock(return_value=local_url)), mongo_client() as client:
self._assert_dbclient_host_url(client, local_url)

def test_partial_url_override(self):
"""Test setting partial url with env var overrides."""

def _os_environ_get(variable_name, default=None): # noqa: ANN202
"""Mock method for os.environ.get calls in shared.initialization.database."""
values = {
"DATABASE_USERNAME": "user",
"DATABASE_PASSWORD": "pass",
"DATABASE_HOST": "host",
"DATABASE_PORT": 4242,
}
return values.get(variable_name, default)

with patch(OS_ENVIRON_GET, Mock(side_effect=_os_environ_get)), mongo_client() as client:
self._assert_dbclient_host_url(client, "mongodb://user:pass@host:4242")


class DatabaseInitTest(unittest.TestCase):
"""Unit tests for database initialization."""

@patch("shared.initialization.database.pymongo", return_value=mongomock)
def test_client(self, client: Mock) -> None:
def test_client(self, pymongo_mock: Mock) -> None:
"""Test that the client is called."""
database_connection()
client.MongoClient.assert_called_once()
with mongo_client() as client:
get_database(client)
pymongo_mock.MongoClient.assert_called_once()

0 comments on commit bf88ebc

Please sign in to comment.