From cdec8f8a5caae69c604225dc9d4a13ce79d5a8a4 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 6 Feb 2024 22:01:12 +0000 Subject: [PATCH] fix: default action assertions --- codeforlife/tests/model_view_set.py | 258 ++++++++++++---------------- 1 file changed, 110 insertions(+), 148 deletions(-) diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index e65af3e9..8dbaba22 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -18,10 +18,12 @@ from pyotp import TOTP from rest_framework import status from rest_framework.response import Response +from rest_framework.serializers import DateTimeField from rest_framework.test import APIClient, APIRequestFactory, APITestCase from ..permissions import Permission from ..serializers import ModelSerializer +from ..types import JsonDict from ..user.models import ( AuthFactor, NonSchoolTeacherUser, @@ -86,7 +88,7 @@ def _assert_response(self, response: Response, make_assertions: t.Callable): def _assert_response_json( self, response: Response, - make_assertions: t.Callable[[Data], None], + make_assertions: t.Callable[[JsonDict], None], ): self._assert_response( response, @@ -98,7 +100,7 @@ def _assert_response_json( def _assert_response_json_bulk( self, response: Response, - make_assertions: t.Callable[[t.List[Data]], None], + make_assertions: t.Callable[[t.List[JsonDict]], None], data: t.List[Data], ): def _make_assertions(): @@ -109,14 +111,6 @@ def _make_assertions(): self._assert_response(response, _make_assertions) - def _assert_data_contains_subset(self, request: Data, response: Data): - for key, value in request.items(): - response_value = response[key] - if isinstance(value, dict): - self._assert_data_contains_subset(value, response_value) - else: - assert value == response_value - @staticmethod def status_code_is_ok(status_code: int): """Check if the status code is greater than or equal to 200 and less @@ -131,62 +125,34 @@ def status_code_is_ok(status_code: int): return 200 <= status_code < 300 - def assert_data_equals_model( + def _assert_serialized_model_equals_json_model( self, - data: Data, model: AnyModel, - model_serializer_class: t.Optional[ - t.Type[ModelSerializer[AnyModel]] - ] = None, + json_model: JsonDict, contains_subset: bool = False, ): - # pylint: disable=line-too-long - """Check if the data equals the current state of the model instance. + model_view_set = self._model_view_set_class() + model_serializer_class = model_view_set.get_serializer_class() + serialized_model = model_serializer_class(model).data - Args: - data: The data to check. - model: The model instance. - model_serializer_class: The serializer used to serialize the model's data. - contains_subset: A flag designating whether the data is a subset of the serialized model. + datetime_to_representation = DateTimeField().to_representation - Returns: - A flag designating if the data equals the current state of the model - instance. - """ - # pylint: enable=line-too-long + def datetime_values_to_representation(data: ModelViewSetClient.Data): + for key, value in data.copy().items(): + if isinstance(value, dict): + datetime_values_to_representation(value) + elif isinstance(value, datetime): + data[key] = datetime_to_representation(value) - def parse_data(data): - if isinstance(data, list): - return [parse_data(value) for value in data] - if isinstance(data, dict): - return {key: parse_data(value) for key, value in data.items()} - if isinstance(data, datetime): - return data.strftime("%Y-%m-%dT%H:%M:%S.%fZ") - return data - - if model_serializer_class is None: - model_serializer_class = ( - # pylint: disable-next=no-member - self._test_case.model_serializer_class - or self._model_view_set_class().get_serializer_class() - ) + datetime_values_to_representation(serialized_model) - actual_data = parse_data(model_serializer_class(model).data) - - if contains_subset: - # pylint: disable-next=no-member - self._test_case.assertDictContainsSubset( - data, - actual_data, - "Data is not a subset of serialized model.", - ) - else: - # pylint: disable-next=no-member - self._test_case.assertDictEqual( - data, - actual_data, - "Data does not equal serialized model.", - ) + ( + # pylint: disable=no-member + self._test_case.assertDictContainsSubset + if contains_subset + else self._test_case.assertDictEqual + # pylint: enable=no-member + )(json_model, serialized_model) # pylint: disable-next=too-many-arguments def generic( @@ -231,11 +197,17 @@ def generic( return response + def _assert_create(self, json_model: JsonDict): + model = self._model_class.objects.get( + **{self._lookup_field: json_model[self._lookup_field]} + ) + self._assert_serialized_model_equals_json_model(model, json_model) + def create( self, data: Data, status_code_assertion: StatusCodeAssertion = status.HTTP_201_CREATED, - assert_data_contains_subset: bool = True, + make_assertions: bool = True, **kwargs, ): # pylint: disable=line-too-long @@ -244,7 +216,7 @@ def create( Args: data: The values for each field. status_code_assertion: The expected status code. - assert_data_contains_subset: Assert if the request model is a subset of the response model. + make_assertions: A flag designating whether to make the default assertions. Returns: The HTTP response. @@ -260,13 +232,8 @@ def create( **kwargs, ) - if assert_data_contains_subset: - self._assert_response_json( - response, - make_assertions=lambda actual_data: ( - self._assert_data_contains_subset(data, actual_data) - ), - ) + if make_assertions: + self._assert_response_json(response, self._assert_create) return response @@ -274,7 +241,7 @@ def bulk_create( self, data: t.List[Data], status_code_assertion: StatusCodeAssertion = status.HTTP_201_CREATED, - assert_data_contains_subset: bool = True, + make_assertions: bool = True, **kwargs, ): # pylint: disable=line-too-long @@ -283,7 +250,7 @@ def bulk_create( Args: data: The values for each field, for each model. status_code_assertion: The expected status code. - assert_data_contains_subset: Assert if the request models are a subset of the response models. + make_assertions: A flag designating whether to make the default assertions. Returns: The HTTP response. @@ -299,13 +266,13 @@ def bulk_create( **kwargs, ) - if assert_data_contains_subset: + if make_assertions: - def make_assertions(actual_data: t.List[ModelViewSetClient.Data]): - for model, actual_model in zip(data, actual_data): - self._assert_data_contains_subset(model, actual_model) + def _make_assertions(json_models: t.List[JsonDict]): + for json_model in json_models: + self._assert_create(json_model) - self._assert_response_json_bulk(response, make_assertions, data) + self._assert_response_json_bulk(response, _make_assertions, data) return response @@ -313,9 +280,7 @@ def retrieve( self, model: AnyModel, status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK, - model_serializer_class: t.Optional[ - t.Type[ModelSerializer[AnyModel]] - ] = None, + make_assertions: bool = True, **kwargs, ): # pylint: disable=line-too-long @@ -324,7 +289,7 @@ def retrieve( Args: model: The model to retrieve. status_code_assertion: The expected status code. - model_serializer_class: The serializer used to serialize the model's data. + make_assertions: A flag designating whether to make the default assertions. Returns: The HTTP response. @@ -338,14 +303,15 @@ def retrieve( **kwargs, ) - self._assert_response_json( - response, - make_assertions=lambda actual_data: self.assert_data_equals_model( - actual_data, - model, - model_serializer_class, - ), - ) + if make_assertions: + self._assert_response_json( + response, + make_assertions=lambda json_model: ( + self._assert_serialized_model_equals_json_model( + model, json_model + ) + ), + ) return response @@ -353,9 +319,7 @@ def list( self, models: t.Iterable[AnyModel], status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK, - model_serializer_class: t.Optional[ - t.Type[ModelSerializer[AnyModel]] - ] = None, + make_assertions: bool = True, filters: ListFilters = None, **kwargs, ): @@ -365,7 +329,7 @@ def list( Args: models: The model list to retrieve. status_code_assertion: The expected status code. - model_serializer_class: The serializer used to serialize the model's data. + make_assertions: A flag designating whether to make the default assertions. filters: The filters to apply to the list. Returns: @@ -389,26 +353,31 @@ def list( **kwargs, ) - def _make_assertions(actual_data: ModelViewSetClient.Data): - for data, model in zip(actual_data["data"], models): - self.assert_data_equals_model( - data, - model, - model_serializer_class, - ) + if make_assertions: + + def _make_assertions(response_json: JsonDict): + json_models = t.cast(t.List[JsonDict], response_json["data"]) + for model, json_model in zip(models, json_models): + self._assert_serialized_model_equals_json_model( + model, json_model + ) - self._assert_response_json(response, _make_assertions) + self._assert_response_json(response, _make_assertions) return response + def _assert_partial_update(self, model: AnyModel, json_model: JsonDict): + model.refresh_from_db() + self._assert_serialized_model_equals_json_model( + model, json_model, contains_subset=True + ) + def partial_update( self, model: AnyModel, data: Data, status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK, - model_serializer_class: t.Optional[ - t.Type[ModelSerializer[AnyModel]] - ] = None, + make_assertions: bool = True, **kwargs, ): # pylint: disable=line-too-long @@ -418,7 +387,7 @@ def partial_update( model: The model to partially update. data: The values for each field. status_code_assertion: The expected status code. - model_serializer_class: The serializer used to serialize the model's data. + make_assertions: A flag designating whether to make the default assertions. Returns: The HTTP response. @@ -434,17 +403,14 @@ def partial_update( **kwargs, ) - def _make_assertions(actual_data: ModelViewSetClient.Data): - model.refresh_from_db() - self.assert_data_equals_model( - actual_data, - model, - model_serializer_class, - contains_subset=True, + if make_assertions: + self._assert_response_json( + response, + make_assertions=lambda json_model: ( + self._assert_partial_update(model, json_model) + ), ) - self._assert_response_json(response, _make_assertions) - return response def bulk_partial_update( @@ -452,9 +418,7 @@ def bulk_partial_update( models: t.List[AnyModel], data: t.List[Data], status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK, - model_serializer_class: t.Optional[ - t.Type[ModelSerializer[AnyModel]] - ] = None, + make_assertions: bool = True, **kwargs, ): # pylint: disable=line-too-long @@ -464,7 +428,7 @@ def bulk_partial_update( models: The models to partially update. data: The values for each field, for each model. status_code_assertion: The expected status code. - model_serializer_class: The serializer used to serialize the model's data. + make_assertions: A flag designating whether to make the default assertions. Returns: The HTTP response. @@ -480,39 +444,43 @@ def bulk_partial_update( **kwargs, ) - def make_assertions(actual_data: t.List[ModelViewSetClient.Data]): - models.sort(key=lambda model: getattr(model, self._lookup_field)) + if make_assertions: - for data, model in zip(actual_data, models): - model.refresh_from_db() - self.assert_data_equals_model( - data, - model, - model_serializer_class, - contains_subset=True, + def _make_assertions(json_models: t.List[JsonDict]): + models.sort( + key=lambda model: getattr(model, self._lookup_field) ) + for model, json_model in zip(models, json_models): + self._assert_partial_update(model, json_model) - self._assert_response_json_bulk(response, make_assertions, data) + self._assert_response_json_bulk(response, _make_assertions, data) return response + def _assert_destroy(self, lookup_values: t.List): + assert not self._model_class.objects.filter( + **{f"{self._lookup_field}__in": lookup_values} + ).exists() + def destroy( self, model: AnyModel, status_code_assertion: StatusCodeAssertion = status.HTTP_204_NO_CONTENT, - anonymized: bool = False, + make_assertions: bool = True, **kwargs, ): + # pylint: disable=line-too-long """Destroy a model. Args: model: The model to destroy. status_code_assertion: The expected status code. - anonymized: Whether or not the data is anonymized. + make_assertions: A flag designating whether to make the default assertions. Returns: The HTTP response. """ + # pylint: enable=line-too-long response: Response = self.delete( # pylint: disable-next=no-member @@ -521,36 +489,33 @@ def destroy( **kwargs, ) - if not anonymized: - - def _make_assertions(): - # pylint: disable-next=no-member - with self._test_case.assertRaises( - model.DoesNotExist # type: ignore[attr-defined] - ): - model.refresh_from_db() - - self._assert_response(response, _make_assertions) + if make_assertions: + self._assert_response( + response, + make_assertions=lambda: self._assert_destroy([model.pk]), + ) return response def bulk_destroy( self, - lookup_values: t.List[t.Any], + lookup_values: t.List, status_code_assertion: StatusCodeAssertion = status.HTTP_204_NO_CONTENT, - anonymized: bool = False, + make_assertions: bool = True, **kwargs, ): + # pylint: disable=line-too-long """Bulk destroy many instances of a model. Args: lookup_values: The models to lookup and destroy. status_code_assertion: The expected status code. - anonymized: Whether or not the data is anonymized. + make_assertions: A flag designating whether to make the default assertions. Returns: The HTTP response. """ + # pylint: enable=line-too-long response: Response = self.delete( # pylint: disable-next=no-member @@ -561,14 +526,11 @@ def bulk_destroy( **kwargs, ) - if not anonymized: - - def _make_assertions(): - assert not self._model_class.objects.filter( - **{f"{self._lookup_field}__in": lookup_values} - ).exists() - - self._assert_response(response, _make_assertions) + if make_assertions: + self._assert_response( + response, + make_assertions=lambda: self._assert_destroy(lookup_values), + ) return response