diff --git a/Makefile b/Makefile index 489aae6..7680ea0 100644 --- a/Makefile +++ b/Makefile @@ -47,6 +47,11 @@ setup-db: | tr '\n' ' ' \ | sed "s|value='\(.*\)'|sqlite3 '\1' < data/setup.sql|") +populate-db: + $(shell yq -o='shell' '.env_variables.SQLITE_FILE' config/local.toml \ + | tr '\n' ' ' \ + | sed "s|value='\(.*\)'|sqlite3 '\1' < data/initial_data.sql|") + .PHONY: all test venv run clean coverage diff --git a/app/data_service/data_service.py b/app/data_service/data_service.py index d1df15b..4d3bdfe 100644 --- a/app/data_service/data_service.py +++ b/app/data_service/data_service.py @@ -1,7 +1,7 @@ from uuid import UUID from .sqlite3 import Sqlite3Driver -from app.models import Survey +from app.models import Survey, TextQuestion class DataService: @@ -21,4 +21,10 @@ def get_survey_if_open(self, survey_uid: UUID) -> Survey | None: return None def insert_survey(self, survey: Survey) -> None: - return self._driver.insert_survey(survey) + return self._driver.insert_survey(survey=survey) + + def get_text_question(self, question_uid: UUID) -> TextQuestion | None: + return self._driver.get_text_question(question_uid=question_uid) + + def get_text_questions_from_survey(self, survey_uid: UUID) -> list[TextQuestion]: + return self._driver.get_text_questions_from_survey(survey_uid=survey_uid) diff --git a/app/data_service/sqlite3/driver.py b/app/data_service/sqlite3/driver.py index ff6b9b8..daefba9 100644 --- a/app/data_service/sqlite3/driver.py +++ b/app/data_service/sqlite3/driver.py @@ -5,7 +5,7 @@ from uuid import UUID from .model_factory import fetch_query_results_as_model -from app.models import Survey +from app.models import Survey, TextQuestion class Sqlite3Driver: @@ -46,3 +46,23 @@ def insert_survey(self, survey: Survey) -> None: with self._get_cursor() as cursor: cursor.execute(query, (str(survey.uid), survey.name, survey.is_open)) + + def get_text_question(self, question_uid: UUID) -> TextQuestion | None: + query = "SELECT * FROM text_question WHERE uid = ? LIMIT 1;" + + with self._get_cursor() as cursor: + cursor.execute(query, (str(question_uid),)) + result = fetch_query_results_as_model(cursor, TextQuestion) + + if len(result) > 0: + return result[0] + return None + + def get_text_questions_from_survey(self, survey_uid: UUID) -> list[TextQuestion]: + query = "SELECT * FROM text_question WHERE survey_uid = ?;" + + with self._get_cursor() as cursor: + cursor.execute(query, (str(survey_uid),)) + results = fetch_query_results_as_model(cursor, TextQuestion) + + return results diff --git a/app/models/__init__.py b/app/models/__init__.py index 5a56549..15ef524 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -1,3 +1,4 @@ -__all__ = ["Survey"] +__all__ = ["Survey", "TextQuestion"] from .survey import Survey +from .text_question import TextQuestion diff --git a/app/models/base_data_model.py b/app/models/base_data_model.py new file mode 100644 index 0000000..8ce47c8 --- /dev/null +++ b/app/models/base_data_model.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel + + +class BaseDataModel(BaseModel): + """ """ + + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return NotImplemented + return all( + self.__getattribute__(field) == other.__getattribute__(field) + for field in self.__annotations__.keys() + ) diff --git a/app/models/question.py b/app/models/question.py deleted file mode 100644 index e69de29..0000000 diff --git a/app/models/survey.py b/app/models/survey.py index 92b24ee..c400e2a 100644 --- a/app/models/survey.py +++ b/app/models/survey.py @@ -1,19 +1,11 @@ -from pydantic import BaseModel, Field from uuid import UUID, uuid4 +from pydantic import Field -class Survey(BaseModel): +from .base_data_model import BaseDataModel + + +class Survey(BaseDataModel): uid: UUID = Field(default_factory=uuid4) name: str is_open: bool - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Survey): - return NotImplemented - return all( - ( - self.uid == other.uid, - self.name == other.name, - self.is_open == other.is_open, - ) - ) diff --git a/app/models/text_question.py b/app/models/text_question.py new file mode 100644 index 0000000..7542581 --- /dev/null +++ b/app/models/text_question.py @@ -0,0 +1,10 @@ +from uuid import UUID, uuid4 +from pydantic import Field + +from .base_data_model import BaseDataModel + + +class TextQuestion(BaseDataModel): + uid: UUID = Field(default_factory=uuid4) + survey_uid: UUID + question: str diff --git a/app/routes.py b/app/routes.py index 889dc57..cbf253c 100644 --- a/app/routes.py +++ b/app/routes.py @@ -46,12 +46,13 @@ def get_open_surveys() -> Dict[str, Any]: def get_survey(uid: str) -> str: data_service: DataService = app.data_service # type: ignore[attr-defined] survey_uid = UUID(uid) - survey = data_service.get_survey_if_open(survey_uid) + survey = data_service.get_survey_if_open(survey_uid=survey_uid) + questions = data_service.get_text_questions_from_survey(survey_uid=survey_uid) if survey is None: abort(404, "Could not find a survey with that UUID.") - return render_template("survey.html", survey=survey) + return render_template("survey.html", survey=survey, questions=questions) @app.route("/surveys/new") diff --git a/app/templates/survey.html b/app/templates/survey.html index 2611135..4e0573b 100644 --- a/app/templates/survey.html +++ b/app/templates/survey.html @@ -1,8 +1,14 @@ {% extends '_layout.html' %} {% block content %} -

{{survey.name}}

+

{{survey.name}}

-Not Implemented +
+{% for question in questions %} +

{{question.question}}

+{% endfor %} +
+ +Answering Questions is not yet implemented. {% endblock %} diff --git a/data/initial_data.sql b/data/initial_data.sql new file mode 100644 index 0000000..220a10a --- /dev/null +++ b/data/initial_data.sql @@ -0,0 +1,23 @@ +INSERT INTO survey ( + uid + , name + , is_open +) VALUES ( + '4b5bfb06-2060-4abf-b5fd-3bae5dcf72b9' + , 'Example Survey #1' + , TRUE +); + +INSERT INTO text_question ( + uid + , survey_uid + , question +) VALUES ( + '9c9facb5-f360-4155-852a-8e2ac04607ea' + , '4b5bfb06-2060-4abf-b5fd-3bae5dcf72b9' + , 'What is your name?' +), ( + 'ee947616-3d16-4095-bc8f-603be72022d3' + , '4b5bfb06-2060-4abf-b5fd-3bae5dcf72b9' + , 'Are you sure?' +); diff --git a/data/setup.sql b/data/setup.sql index 72c5fb4..ded0dd9 100644 --- a/data/setup.sql +++ b/data/setup.sql @@ -8,20 +8,11 @@ CREATE TABLE IF NOT EXISTS survey ( , name TEXT ); -CREATE TABLE IF NOT EXISTS question ( +CREATE TABLE IF NOT EXISTS text_question ( uid TEXT PRIMARY KEY , survey_uid TEXT , question TEXT , FOREIGN KEY(survey_uid) REFERENCES survey(uid) ); -CREATE TABLE IF NOT EXISTS ranking ( - uid TEXT PRIMARY KEY - , question_uid TEXT - , first_dimension TEXT - , second_dimension TEXT - , third_dimension TEXT - , FOREIGN KEY(question_uid) REFERENCES question(uid) -); - COMMIT; diff --git a/tests/conftest.py b/tests/conftest.py index 65024f5..3e0e794 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,10 @@ import pytest from pathlib import Path -from uuid import uuid4 +from uuid import uuid4, UUID from app.routes import app from app.data_service.sqlite3 import Sqlite3Driver -from app.models import Survey +from app.models import Survey, TextQuestion @pytest.fixture @@ -47,9 +47,18 @@ def app_client(): @pytest.fixture -def new_survey_open(): - # survey with a random UUID +def open_survey() -> Survey: return Survey( + uid=UUID("f21ccd82-83d6-40bc-8e60-703382f73860"), name="Test Survey", is_open=True, ) + + +@pytest.fixture +def text_question() -> TextQuestion: + return TextQuestion( + uid=UUID("8c7d6885-0ebb-4870-a8e1-a4630497a089"), + survey_uid=UUID("f21ccd82-83d6-40bc-8e60-703382f73860"), + question="When is it time?", + ) diff --git a/tests/test_assets/test_db_data.sql b/tests/test_assets/test_db_data.sql index bcbe657..11a8445 100644 --- a/tests/test_assets/test_db_data.sql +++ b/tests/test_assets/test_db_data.sql @@ -12,9 +12,37 @@ INSERT INTO survey ( ( "00000000-a087-4fb6-a123-24ff30263530" , true - , "Open Test Survey" + , "Open Test Survey - 1Q" +), +( + "00000000-b37a-32b3-19d9-72ec921021e3" + , true + , "Open Test Survey - 2Qs" ); - +INSERT INTO text_question ( + uid + , survey_uid + , question +) VALUES ( + "11111111-9c88-4b81-9de4-bac7444fbb0a" + , "00000000-9c88-4b81-9de4-bac7444fbb0a" + , "What story would you tell your best friend about this company?" +), +( + "11111111-a087-4fb6-a123-24ff30263530" + , "00000000-a087-4fb6-a123-24ff30263530" + , "What stands out to you about current quarterly plan?" +), +( + "11111111-b37a-32b3-19d9-72ec921021e3" + , "00000000-b37a-32b3-19d9-72ec921021e3" + , "What story?" +), +( + "11111111-b37a-44a1-19d9-72ec921021e3" + , "00000000-b37a-32b3-19d9-72ec921021e3" + , "What story?" +); COMMIT; diff --git a/tests/unit/data_service/sqlite3/test_driver.py b/tests/unit/data_service/sqlite3/test_driver.py index de21f60..cd4f432 100644 --- a/tests/unit/data_service/sqlite3/test_driver.py +++ b/tests/unit/data_service/sqlite3/test_driver.py @@ -2,73 +2,142 @@ import pytest from uuid import UUID -from app.models import Survey +from app.models import Survey, TextQuestion +from app.data_service.sqlite3 import Sqlite3Driver + + +class TestDriver: + def test_sqlite3_driver_enforces_foreign_key_constraints(self, empty_db_driver): + with empty_db_driver._get_cursor() as cursor: + with pytest.raises(sqlite3.IntegrityError): + cursor.execute( + "INSERT INTO text_question (uid, survey_uid, question) VALUES (" + ' "ed7b2f97-cd9d-4786-9266-a9397172397b", ' + ' "f54e6029-a7bd-4b74-a4a4-e0bbbe1435eb", ' + ' "WHAT AM I???"' + ");" + ) + + +class TestDriverSurveyMethods: + def test_sqlite3_driver_can_get_list_open_surveys(self, populated_db_driver): + surveys = populated_db_driver.get_open_surveys() + + assert len(surveys) == 2 + for survey in surveys: + assert isinstance(survey, Survey) + + @pytest.mark.parametrize( + "survey_uid", + ( + UUID("00000000-9c88-4b81-9de4-bac7444fbb0a"), + UUID("00000000-a087-4fb6-a123-24ff30263530"), + ), + ) + def test_driver_get_survey_returns_appropriate_survey( + self, populated_db_driver, survey_uid + ): + survey = populated_db_driver.get_survey(survey_uid=survey_uid) + assert survey is not None + assert survey.uid == survey_uid + + @pytest.mark.parametrize( + "survey_uid", + ( + UUID("99999999-9c88-4b81-9de4-bac7444fbb0a"), + UUID("99999999-a087-4fb6-a123-24ff30263530"), + ), + ) + def test_driver_get_survey_returns_none_when_no_survey_found( + self, populated_db_driver, survey_uid + ): + survey = populated_db_driver.get_survey(survey_uid=survey_uid) + assert survey is None + + def test_sqlite3_driver_throws_error_if_adding_a_survey_that_already_exists( + self, + populated_db_driver, + ): + surveys = populated_db_driver.get_open_surveys() + survey = surveys[0] - -def test_sqlite3_driver_enforces_foreign_key_constraints(empty_db_driver): - with empty_db_driver._get_cursor() as cursor: with pytest.raises(sqlite3.IntegrityError): - cursor.execute( - "INSERT INTO question (uid, survey_uid, question) VALUES (" - ' "ed7b2f97-cd9d-4786-9266-a9397172397b", ' - ' "f54e6029-a7bd-4b74-a4a4-e0bbbe1435eb", ' - ' "WHAT AM I???"' - ");" - ) - - -def test_sqlite3_driver_can_get_list_open_surveys(populated_db_driver): - surveys = populated_db_driver.get_open_surveys() - - assert len(surveys) == 1 - for survey in surveys: - assert isinstance(survey, Survey) - - -@pytest.mark.parametrize( - "survey_uid", - ( - "00000000-9c88-4b81-9de4-bac7444fbb0a", - "00000000-a087-4fb6-a123-24ff30263530", - ), -) -def test_driver_get_survey_returns_appropriate_survey(populated_db_driver, survey_uid): - survey = populated_db_driver.get_survey(survey_uid=survey_uid) - assert survey is not None - assert survey.uid == UUID(survey_uid) - - -@pytest.mark.parametrize( - "survey_uid", - ( - "99999999-9c88-4b81-9de4-bac7444fbb0a", - "99999999-a087-4fb6-a123-24ff30263530", - ), -) -def test_driver_get_survey_returns_none_when_no_survey_found( - populated_db_driver, survey_uid -): - survey = populated_db_driver.get_survey(survey_uid=UUID(survey_uid)) - assert survey is None - - -def test_sqlite3_driver_throws_error_if_adding_a_survey_that_already_exists( - populated_db_driver, -): - surveys = populated_db_driver.get_open_surveys() - survey = surveys[0] - - with pytest.raises(sqlite3.IntegrityError): - populated_db_driver.insert_survey(survey) - - -def test_sqlite3_driver_can_query_an_added_survey( - empty_db_driver, - new_survey_open, -): - empty_db_driver.insert_survey(new_survey_open) - - surveys = empty_db_driver.get_open_surveys() - - assert len(surveys) == 1 - assert surveys[0] == new_survey_open + populated_db_driver.insert_survey(survey) + + def test_sqlite3_driver_can_query_an_added_survey( + self, + empty_db_driver, + open_survey, + ): + empty_db_driver.insert_survey(open_survey) + + surveys = empty_db_driver.get_open_surveys() + + assert len(surveys) == 1 + assert surveys[0] == open_survey + + +class TestDriverQuestionMethods: + @pytest.mark.parametrize( + "text_question_uid", + ( + UUID("11111111-9c88-4b81-9de4-bac7444fbb0a"), + UUID("11111111-a087-4fb6-a123-24ff30263530"), + UUID("11111111-b37a-32b3-19d9-72ec921021e3"), + ), + ) + def test_driver_get_test_question( + self, populated_db_driver: Sqlite3Driver, text_question_uid: UUID + ): + question = populated_db_driver.get_text_question(question_uid=text_question_uid) + + assert isinstance(question, TextQuestion) + assert question.uid == text_question_uid + + def test_get_text_question_returns_none_if_no_question_with_uid_exists( + self, populated_db_driver: Sqlite3Driver + ): + question = populated_db_driver.get_text_question( + question_uid=UUID("209d67a3-d354-4cd8-afc4-7e6479582086") + ) + + assert question is None + + @pytest.mark.parametrize( + "survey_uid, expected_tq_uids", + ( + ( + UUID("00000000-9c88-4b81-9de4-bac7444fbb0a"), + {UUID("11111111-9c88-4b81-9de4-bac7444fbb0a")}, + ), + ( + UUID("00000000-a087-4fb6-a123-24ff30263530"), + {UUID("11111111-a087-4fb6-a123-24ff30263530")}, + ), + ( + UUID("00000000-b37a-32b3-19d9-72ec921021e3"), + { + UUID("11111111-b37a-44a1-19d9-72ec921021e3"), + UUID("11111111-b37a-32b3-19d9-72ec921021e3"), + }, + ), + ( + UUID("99999999-a087-4fb6-a123-24ff30263530"), + {}, + ), + ), + ) + def test_driver_can_list_questions_related_to_a_survey_uid( + self, + populated_db_driver: Sqlite3Driver, + survey_uid: UUID, + expected_tq_uids: set[UUID], + ): + text_questions = populated_db_driver.get_text_questions_from_survey( + survey_uid=survey_uid + ) + + assert len(text_questions) == len(expected_tq_uids) + for tq in text_questions: + assert isinstance(tq, TextQuestion) + assert tq.uid in expected_tq_uids diff --git a/tests/unit/data_service/test_data_service.py b/tests/unit/data_service/test_data_service.py index a21a733..70d830d 100644 --- a/tests/unit/data_service/test_data_service.py +++ b/tests/unit/data_service/test_data_service.py @@ -4,24 +4,34 @@ from app.data_service import DataService from app.data_service.sqlite3 import Sqlite3Driver -from app.models import Survey +from app.models import Survey, TextQuestion @pytest.fixture -def surveys() -> list[Survey]: +def surveys(open_survey) -> list[Survey]: + open_survey.uid = UUID("74bce4cf-0875-471b-a7c4-f25c7ef42864") return [ - Survey( - name="test survey", - is_open=True, + open_survey, + ] + + +@pytest.fixture +def questions() -> list[TextQuestion]: + return [ + TextQuestion( + question="What's up?", + survey_uid=UUID("74bce4cf-0875-471b-a7c4-f25c7ef42864"), ) ] @pytest.fixture -def data_service(surveys) -> DataService: +def data_service(surveys, questions) -> DataService: mock_driver = Mock(spec=Sqlite3Driver) mock_driver.get_open_surveys.return_value = surveys + mock_driver.get_text_questions_from_survey.return_value = questions + mock_driver.get_text_question.return_value = questions[0] mock_driver.insert_survey.return_value = None return DataService(driver=mock_driver) @@ -73,11 +83,32 @@ def test_get_survey_if_open_returns_none_if_survey_closed( survey = Survey(name="test", is_open=False) data_service._driver.get_survey.return_value = survey - returned_survey = data_service.get_survey_if_open(survey.uid) + returned_survey = data_service.get_survey_if_open(survey_uid=survey.uid) assert returned_survey is None - def test_can_insert_a_survey( - self, data_service: DataService, new_survey_open: Survey + def test_can_insert_a_survey(self, data_service: DataService, open_survey: Survey): + data_service.insert_survey(open_survey) + data_service._driver.insert_survey.assert_called_once_with(survey=open_survey) + + def test_get_questions_from_survey( + self, + data_service: DataService, + ): + survey_uid = UUID("ee50dd84-86a0-4a9d-a632-ec6670e2cd89") + questions = data_service.get_text_questions_from_survey(survey_uid=survey_uid) + + assert len(questions) == 1 + assert isinstance(questions[0], TextQuestion) + data_service._driver.get_text_questions_from_survey.assert_called_once_with( + survey_uid=survey_uid + ) + + def test_get_question( + self, + data_service: DataService, ): - data_service.insert_survey(new_survey_open) - data_service._driver.insert_survey.assert_called_once_with(new_survey_open) + question_uid = UUID("ee50dd84-86a0-4a9d-a632-ec6670e2cd89") + question = data_service.get_text_question(question_uid=question_uid) + + assert isinstance(question, TextQuestion) + assert question == data_service._driver.get_text_question.return_value diff --git a/tests/unit/models/test_models.py b/tests/unit/models/test_models.py index 3bfc195..ae32af8 100644 --- a/tests/unit/models/test_models.py +++ b/tests/unit/models/test_models.py @@ -1,6 +1,6 @@ import pytest -from app.models import Survey +from app.models import Survey, TextQuestion @pytest.mark.parametrize( @@ -15,6 +15,21 @@ }, ], [Survey, {"is_open": True, "name": "Test Survey"}], # .uid should be optional + [ + TextQuestion, + { + "question": "How are you today?", + "survey_uid": "63163031-ce99-46c3-a70b-c3df75a51258", + }, + ], # .uid should be optional + [ + TextQuestion, + { + "uid": "bb92a5f5-7d62-4e77-9cbb-c8c903c4e65f", + "question": "How are you today?", + "survey_uid": "63163031-ce99-46c3-a70b-c3df75a51258", + }, + ], # type cast uid ], ) def test_all_models_enforce_type_hints(model_class, arguments): @@ -36,15 +51,59 @@ def test_all_models_enforce_type_hints(model_class, arguments): (Survey(name="Test", is_open=True), Survey(name="Test", is_open=True), False), ( Survey( - uid="bb92a5f5-7d62-4e77-9cbb-c8c903c4e65f", name="Test", is_open=True + uid="bb92a5f5-7d62-4e77-9cbb-c8c903c4e65f", + name="Test", + is_open=True, ), Survey( - uid="bb92a5f5-7d62-4e77-9cbb-c8c903c4e65f", name="Test", is_open=True + uid="bb92a5f5-7d62-4e77-9cbb-c8c903c4e65f", + name="Test", + is_open=True, ), True, ), + ( + TextQuestion( + question="What is?", + survey_uid="63163031-ce99-46c3-a70b-c3df75a51258", + ), + TextQuestion( + question="What is?", + survey_uid="63163031-ce99-46c3-a70b-c3df75a51258", + ), + False, + ), # UIDs should be different + ( + TextQuestion( + question="What is?", + uid="aa11a5f5-7d42-4e77-9cbb-c8c903c4e65f", + survey_uid="63163031-ce99-46c3-a70b-c3df75a51258", + ), + TextQuestion( + question="What is?", + uid="aa11a5f5-7d42-4e77-9cbb-c8c903c4e65f", + survey_uid="63163031-ce99-46c3-a70b-c3df75a51258", + ), + True, + ), + ( + TextQuestion( + uid="aa11a5f5-7d42-4e77-9cbb-c8c903c4e65f", + question="What is?", + survey_uid="63163031-ce99-46c3-a70b-c3df75a51258", + ), + Survey( + uid="63163031-ce99-46c3-a70b-c3df75a51258", + name="Test", + is_open=True, + ), + False, + ), ), ) def test_equality_function_works_on_all_functions(value1, value2, expected): + """ + Equality of these models means that all fields have the same values. + """ result = value1 == value2 assert result == expected diff --git a/tests/unit/test_routes.py b/tests/unit/test_routes.py index 443bd2f..5f96f3d 100644 --- a/tests/unit/test_routes.py +++ b/tests/unit/test_routes.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock from app.data_service import DataService -from app.models import Survey +from app.models import Survey, TextQuestion from tests.unit.test_routes_utils import ( assert_response_is_valid_htmx, @@ -13,10 +13,15 @@ @pytest.fixture -def mock_data_service_class(monkeypatch): +def mock_data_service(monkeypatch, open_survey: Survey, text_question: TextQuestion): data_service = MagicMock() monkeypatch.setattr("app.routes.DataService", data_service) + data_service.return_value.get_survey_if_open.return_value = open_survey + data_service.return_value.get_text_questions_from_survey.return_value = ( + text_question + ) + yield data_service monkeypatch.undo() @@ -40,7 +45,7 @@ class TestGetHTMLEndpoints: def test_get_requests_provide_an_html_page( self, app_client, - mock_data_service_class, + mock_data_service, endpoint: str, ): response = app_client.get(endpoint) @@ -53,13 +58,38 @@ def test_get_requests_provide_an_html_page( def test_get_request_returns_404( self, app_client, - mock_data_service_class, + mock_data_service, endpoint: str, ): - mock_data_service_class.return_value.get_survey_if_open.return_value = None + mock_data_service.return_value.get_survey_if_open.return_value = None response = app_client.get(endpoint) assert response.status_code == 404 + @pytest.mark.parametrize( + "slug, expected_data_service_calls", + ( + ( + "/surveys/00000000-a087-4fb6-a123-24ff30263530", + ("get_survey_if_open", "get_text_questions_from_survey"), + ), + ), + ) + def test_get_requests_make_expected_data_service_calls( + self, + app_client, + mock_data_service, + slug: str, + expected_data_service_calls: str, + ): + response = app_client.get(slug) + + for method in expected_data_service_calls: + assert_mocked_class_has_method_call_on_object( + mock_class=mock_data_service, + method_call=method, + ) + assert_response_is_valid_html(response) + class TestGetHTMXEndpoints: @@ -71,44 +101,48 @@ class TestGetHTMXEndpoints: test_cases = ( ( "/surveys", - "get_open_surveys", + ("get_open_surveys",), ), ) @pytest.mark.parametrize( - "slug, expected_data_service_call", + "slug, expected_data_service_calls", test_cases, ) def test_get_requests_return_partial_html_if_htmx_headers_are_present( self, app_client, - mock_data_service_class, + mock_data_service, slug: str, - expected_data_service_call: str, + expected_data_service_calls: str, ): response = app_client.get(slug, headers={"Hx-Request": "true"}) - assert_mocked_class_has_method_call_on_object( - mock_class=mock_data_service_class, - method_call=expected_data_service_call, - ) + + for method in expected_data_service_calls: + assert_mocked_class_has_method_call_on_object( + mock_class=mock_data_service, + method_call=method, + ) assert_response_is_valid_htmx(response) @pytest.mark.parametrize( - "slug, expected_data_service_call", + "slug, expected_data_service_calls", test_cases, ) def test_htmx_endpoints_returns_html_page_if_htmx_headers_are_not_present( self, app_client, - mock_data_service_class, + mock_data_service, slug: str, - expected_data_service_call: str, + expected_data_service_calls: str, ): response = app_client.get(slug) - assert_mocked_class_has_method_call_on_object( - mock_class=mock_data_service_class, - method_call=expected_data_service_call, - ) + + for method in expected_data_service_calls: + assert_mocked_class_has_method_call_on_object( + mock_class=mock_data_service, + method_call=method, + ) assert_response_is_valid_html(response) @@ -153,12 +187,12 @@ def test_post_requests_can_return_htmx( data: dict[str, str], expected_data_service_call: str, argument_class: str, - mock_data_service_class: DataService, + mock_data_service: DataService, ): response = app_client.post(slug, data=data, headers={"Hx-Request": "true"}) assert_mocked_class_has_method_call_on_object( - mock_class=mock_data_service_class, + mock_class=mock_data_service, method_call=expected_data_service_call, argument_types=[argument_class], ) @@ -175,12 +209,12 @@ def test_post_requests_can_return_html( data: dict[str, str], expected_data_service_call: str, argument_class: str, - mock_data_service_class: DataService, + mock_data_service: DataService, ): response = app_client.post(slug, data=data) assert_mocked_class_has_method_call_on_object( - mock_class=mock_data_service_class, + mock_class=mock_data_service, method_call=expected_data_service_call, argument_types=[argument_class], ) diff --git a/tests/unit/test_routes_utils.py b/tests/unit/test_routes_utils.py index 0f81d7a..39b8ab2 100644 --- a/tests/unit/test_routes_utils.py +++ b/tests/unit/test_routes_utils.py @@ -47,3 +47,5 @@ def assert_mocked_class_has_method_call_on_object( assert len(method.call_args.args) == len(argument_types) for arg, of_type in zip(method.call_args.args, argument_types): assert isinstance(arg, of_type) + else: + method.assert_called_once()