Skip to content

Commit

Permalink
Add more tests for the queries route
Browse files Browse the repository at this point in the history
As a result, `test_utils` and the `queries` route handler have also been updated.
  • Loading branch information
pkhalaj committed May 29, 2024
1 parent d034e57 commit 7a37297
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 25 deletions.
8 changes: 2 additions & 6 deletions trolldb/api/routes/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ async def queries(
time_min: datetime.datetime = Query(default=None), # noqa: B008
time_max: datetime.datetime = Query(default=None)) -> list[str]: # noqa: B008
"""Please consult the auto-generated documentation by FastAPI."""
# We
pipelines = Pipelines()

if platform:
Expand All @@ -42,10 +41,7 @@ async def queries(
start_time = PipelineAttribute("start_time")
end_time = PipelineAttribute("end_time")
pipelines += (
(start_time >= time_min) |
(start_time <= time_max) |
(end_time >= time_min) |
(end_time <= time_max)
((start_time >= time_min) & (start_time <= time_max)) |
((end_time >= time_min) & (end_time <= time_max))
)

return await get_ids(collection.aggregate(pipelines))
86 changes: 68 additions & 18 deletions trolldb/test_utils/mongodb_database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""The module which provides testing utilities to make MongoDB databases/collections and fill them with test data."""
from contextlib import contextmanager
from copy import deepcopy
from datetime import datetime, timedelta
from random import choices, randint, shuffle
from typing import Iterator
Expand Down Expand Up @@ -114,17 +115,23 @@ def like_mongodb_document(self) -> dict:
class TestDatabase:
"""A static class which encloses functionalities to prepare and fill the test database with test data."""

unique_platform_names: list[str] = ["PA", "PB", "PC"]
"""The unique platform names that will be used to generate the sample of all platform names."""

# We suppress ruff (S311) here as we are not generating anything cryptographic here!
platform_names = choices(["PA", "PB", "PC"], k=10) # noqa: S311
platform_names = choices(["PA", "PB", "PC"], k=20) # noqa: S311
"""Example platform names.
Warning:
The value of this variable changes randomly every time. What you see above is just an example which has been
generated as a result of building the documentation!
"""

unique_sensors: list[str] = ["SA", "SB", "SC"]
"""The unique sensor names that will be used to generate the sample of all sensor names."""

# We suppress ruff (S311) here as we are not generating anything cryptographic here!
sensors = choices(["SA", "SB", "SC"], k=10) # noqa: S311
sensors = choices(["SA", "SB", "SC"], k=20) # noqa: S311
"""Example sensor names.
Warning:
Expand Down Expand Up @@ -192,6 +199,23 @@ def write_test_data(cls):
collection.delete_many({})
collection.insert_many(cls.documents)

@classmethod
def get_all_documents_from_database(cls) -> list[dict]:
    """Fetches every document stored in the main test collection.

    Returns:
        All documents currently in the database. This matches the content of
        :obj:`~TestDatabase.documents` plus the ``_id`` fields which MongoDB
        assigns on insertion.
    """
    database_config = test_app_config.database
    with mongodb_for_test_context() as client:
        collection = client[database_config.main_database_name][database_config.main_collection_name]
        return list(collection.find({}))

@classmethod
def find_min_max_datetime(cls):
"""Finds the minimum and the maximum for both the ``start_time`` and the ``end_time``.
Expand All @@ -212,27 +236,53 @@ def find_min_max_datetime(cls):
_max=dict(_id=None, _time="1900-01-01T00:00:00"))
)

with mongodb_for_test_context() as client:
collection = client[
test_app_config.database.main_database_name
][
test_app_config.database.main_collection_name
]
documents = collection.find({})
documents = cls.get_all_documents_from_database()

for document in documents:
for k in ["start_time", "end_time"]:
dt = document[k].isoformat()
if dt > result[k]["_max"]["_time"]:
result[k]["_max"]["_time"] = dt
result[k]["_max"]["_id"] = str(document["_id"])
for document in documents:
for k in ["start_time", "end_time"]:
dt = document[k].isoformat()
if dt > result[k]["_max"]["_time"]:
result[k]["_max"]["_time"] = dt
result[k]["_max"]["_id"] = str(document["_id"])

if dt < result[k]["_min"]["_time"]:
result[k]["_min"]["_time"] = dt
result[k]["_min"]["_id"] = str(document["_id"])
if dt < result[k]["_min"]["_time"]:
result[k]["_min"]["_time"] = dt
result[k]["_min"]["_id"] = str(document["_id"])

return result

@classmethod
def match_query(cls, platform=None, sensor=None, time_min=None, time_max=None):
    """Finds the IDs of all documents which match the given query criteria.

    A criterion which is ``None`` does not constrain the result. Documents
    are kept in database order.

    Args:
        platform: An optional collection of acceptable platform names.
        sensor: An optional collection of acceptable sensor names.
        time_min: An optional lower bound for the document time interval.
        time_max: An optional upper bound for the document time interval.

    Returns:
        The stringified ``_id`` of every document which satisfies all the
        given criteria.
    """
    def _matches(document: dict) -> bool:
        """Checks a single document against all the given criteria."""
        if platform and document["platform_name"] not in platform:
            return False
        if sensor and document["sensor"] not in sensor:
            return False
        if time_min and time_max:
            # Keep the document if its [start_time, end_time] interval
            # overlaps the queried [time_min, time_max] interval.
            return not (document["end_time"] < time_min or document["start_time"] > time_max)
        if time_min:
            return document["end_time"] >= time_min
        if time_max:
            # NOTE(review): this mirrors the original check which compares
            # ``end_time`` (not ``start_time``) against ``time_max`` — confirm
            # this is the intended semantics of the API route.
            return document["end_time"] <= time_max
        return True

    # A single filtering pass replaces the original deepcopy + O(n^2)
    # ``list.remove`` approach; the result (a list of ID strings) is the same.
    return [str(document["_id"]) for document in cls.get_all_documents_from_database() if _matches(document)]

Check notice on line 284 in trolldb/test_utils/mongodb_database.py

View check run for this annotation

codefactor.io / CodeFactor

trolldb/test_utils/mongodb_database.py#L255-L284

Complex Method

@classmethod
def prepare(cls):
"""Prepares the MongoDB instance by first resetting the database and filling it with generated test data."""
Expand Down
98 changes: 97 additions & 1 deletion trolldb/tests/tests_api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@
"""

from collections import Counter
from datetime import datetime

import pytest
from fastapi import status

from trolldb.test_utils.common import http_get
from trolldb.test_utils.common import http_get, test_app_config
from trolldb.test_utils.mongodb_database import TestDatabase, mongodb_for_test_context

main_database_name = test_app_config.database.main_database_name
main_collection_name = test_app_config.database.main_collection_name


def collections_exists(test_collection_names: list[str], expected_collection_name: list[str]) -> bool:
"""Checks if the test and expected list of collection names match."""
Expand All @@ -26,6 +30,44 @@ def document_ids_are_correct(test_ids: list[str], expected_ids: list[str]) -> bo
return Counter(test_ids) == Counter(expected_ids)


def single_query_is_correct(key: str, value: str | datetime) -> bool:
    """Checks that querying the API with a single ``key=value`` pair agrees with the database content."""
    api_result = http_get(f"queries?{key}={value}").json()
    expected = TestDatabase.match_query(**{key: value})
    return Counter(api_result) == Counter(expected)


def query_results_are_correct(keys: list[str], values_list: list[list[str | datetime]]) -> bool:
    """Checks if the retrieved result from querying the database via the API matches the expected result.

    There can be more than one query `key/value` pair.

    Args:
        keys:
            A list of all query keys, e.g. ``keys=["platform", "sensor"]``
        values_list:
            A list in which each element is a list of values itself. The `nth` element corresponds to the `nth` key in
            the ``keys``.

    Returns:
        A boolean flag indicating whether the retrieved result matches the expected result.
    """
    # Pair each key with its list of values; ``strict=True`` guards against
    # mismatched lengths of ``keys`` and ``values_list``.
    queries = dict(zip(keys, values_list, strict=True))

    # Assemble a single query string covering all key/value pairs.
    query_string = "&".join(
        f"{key}={value}" for key, values in queries.items() for value in values
    )

    api_result = http_get(f"queries?{query_string}").json()
    expected = TestDatabase.match_query(**queries)
    return Counter(api_result) == Counter(expected)


@pytest.mark.usefixtures("_test_server_fixture")
def test_root():
"""Checks that the server is up and running, i.e. the root routes responds with 200."""
Expand Down Expand Up @@ -85,3 +127,57 @@ def test_collections_negative():
def test_datetime():
"""Checks that the datetime route works properly."""
assert http_get("datetime").json() == TestDatabase.find_min_max_datetime()


@pytest.mark.usefixtures("_test_server_fixture")
def test_queries_all():
    """Tests that the queries route returns all documents when no actual queries are given."""
    expected_ids = [str(document["_id"]) for document in TestDatabase.get_all_documents_from_database()]
    assert document_ids_are_correct(http_get("queries").json(), expected_ids)


@pytest.mark.usefixtures("_test_server_fixture")
@pytest.mark.parametrize(("key", "values"), [
    ("platform", TestDatabase.unique_platform_names),
    ("sensor", TestDatabase.unique_sensors)
])
def test_queries_platform_or_sensor(key, values):
    """Tests the platform and sensor queries, one at a time.

    There is only a single key in the query, but it has multiple corresponding values.
    """
    # Iterate from 1 up to (and including) ``len(values)`` so that the full
    # value list is exercised as well. The original ``values[:i]`` with
    # ``i in range(len(values))`` started with a vacuous empty query and
    # never tested all the values at once.
    for i in range(1, len(values) + 1):
        assert query_results_are_correct(
            [key],
            [values[:i]]
        )


@pytest.mark.usefixtures("_test_server_fixture")
def test_queries_mix_platform_sensor():
    """Tests a mix of platform and sensor queries."""
    # Each pair gives the number of platform and sensor values to query with.
    counts = [(1, 1), (1, 3), (2, 2), (3, 1), (3, 3)]
    for n_platforms, n_sensors in counts:
        assert query_results_are_correct(
            ["platform", "sensor"],
            [TestDatabase.unique_platform_names[:n_platforms], TestDatabase.unique_sensors[:n_sensors]]
        )


@pytest.mark.usefixtures("_test_server_fixture")
def test_queries_time():
    """Checks that a single time query works properly."""
    datetime_info = http_get("datetime").json()
    time_min = datetime.fromisoformat(datetime_info["start_time"]["_min"]["_time"])
    time_max = datetime.fromisoformat(datetime_info["end_time"]["_max"]["_time"])

    assert single_query_is_correct("time_min", time_min)
    assert single_query_is_correct("time_max", time_max)

0 comments on commit 7a37297

Please sign in to comment.