From efe62bc039d28965727909213887316592c41bf1 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Mon, 22 Jan 2024 22:45:55 +0000 Subject: [PATCH] minor fixes --- codeforlife/models/signals/pre_save.py | 23 +++--- codeforlife/tests/model_view_set.py | 105 ++++++++++++++++++------- codeforlife/user/urls.py | 7 +- 3 files changed, 93 insertions(+), 42 deletions(-) diff --git a/codeforlife/models/signals/pre_save.py b/codeforlife/models/signals/pre_save.py index 17e64017..306fea05 100644 --- a/codeforlife/models/signals/pre_save.py +++ b/codeforlife/models/signals/pre_save.py @@ -44,33 +44,38 @@ def check_previous_values( instance: AnyModel, predicates: t.Dict[str, t.Callable[[t.Any, t.Any], bool]], ): - """Check if the previous values are as expected. + """Check if the previous values are as expected. If the model has not been + created yet, the previous values are None. Args: instance: The current instance. predicates: A predicate for each field. It accepts the arguments (previous_value, value) and returns True if the values are as expected. - Raises: - ValueError: If arg 'instance' has not been created yet. - Returns: If all the previous values are as expected. """ - if not was_created(instance): - raise ValueError("Arg 'instance' has not been created yet.") + if was_created(instance): + previous_instance = instance.__class__.objects.get(pk=instance.pk) + + def get_previous_value(field: str): + return getattr(previous_instance, field) - previous_instance = instance.__class__.objects.get(pk=instance.pk) + else: + # pylint: disable-next=unused-argument + def get_previous_value(field: str): + return None return all( - predicate(getattr(previous_instance, field), getattr(instance, field)) + predicate(get_previous_value(field), getattr(instance, field)) for field, predicate in predicates.items() ) def previous_values_are_unequal(instance: AnyModel, fields: t.Set[str]): - """Check if all the previous values are not equal to the current values. + """Check if all the previous values are not equal to the current values. If + the model has not been created yet, the previous values are None. Args: instance: The current instance. diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index 1216c7b2..5bce60e9 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -11,7 +11,7 @@ from django.db.models import Model from django.db.models.query import QuerySet -from django.urls import reverse +from django.urls import reverse as _reverse from django.utils import timezone from django.utils.http import urlencode from pyotp import TOTP @@ -36,30 +36,26 @@ class ModelViewSetClient( responses. """ - _test_case: "ModelViewSetTestCase[AnyModelViewSet, AnyModelSerializer, AnyModel]" - - @property - def basename(self): - """Shortcut to get basename.""" + Data = t.Dict[str, t.Any] - return self._test_case.basename + _test_case: "ModelViewSetTestCase[AnyModelViewSet, AnyModelSerializer, AnyModel]" @property - def model_class(self): + def _model_class(self): """Shortcut to get model class.""" # pylint: disable-next=no-member return self._test_case.get_model_class() @property - def model_serializer_class(self): + def _model_serializer_class(self): """Shortcut to get model serializer class.""" # pylint: disable-next=no-member return self._test_case.get_model_serializer_class() @property - def model_view_set_class(self): + def _model_view_set_class(self): """Shortcut to get model view set class.""" # pylint: disable-next=no-member @@ -84,7 +80,7 @@ def status_code_is_ok(status_code: int): def assert_data_equals_model( self, - data: t.Dict[str, t.Any], + data: Data, model: AnyModel, contains_subset: bool = False, ): @@ -112,7 +108,7 @@ def parse_data(data): return data.strftime("%Y-%m-%dT%H:%M:%S.%fZ") return data - actual_data = parse_data(self.model_serializer_class(model).data) + actual_data = parse_data(self._model_serializer_class(model).data) if contains_subset: # pylint: disable-next=no-member @@ -129,16 +125,37 @@ def parse_data(data): "Data does not equal serialized model.", ) - def _get_reverse_detail(self, model: AnyModel, **kwargs): - return reverse( + def reverse( + self, + action: str, + model: t.Optional[AnyModel] = None, + **kwargs, + ): + """Get the reverse URL for the model view set's action. + + Args: + action: The name of the action. + model: The model to look up. + + Returns: + The reversed URL. + """ + + reverse_kwargs = kwargs.pop("kwargs", {}) + if model is not None: + reverse_kwargs[self._model_view_set_class.lookup_field] = getattr( + model, + self._model_view_set_class.lookup_field, + ) + + return _reverse( + viewname=kwargs.pop( + "viewname", + # pylint: disable-next=no-member + f"{self._test_case.basename}-{action}", + ), + kwargs=reverse_kwargs, **kwargs, - viewname=kwargs.get("viewname", f"{self.basename}-detail"), - kwargs={ - **kwargs.get("kwargs", {}), - self.model_view_set_class.lookup_field: getattr( - model, self.model_view_set_class.lookup_field - ), - }, ) # pylint: disable-next=too-many-arguments @@ -184,6 +201,37 @@ def generic( return response + def create( + self, + data: Data, + status_code_assertion: StatusCodeAssertion = None, + **kwargs, + ): + """Create a model. + + Args: + data: The values for each field. + status_code_assertion: The expected status code. + + Returns: + The HTTP response. + """ + + response: Response = self.post( + self.reverse("list"), + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if self.status_code_is_ok(response.status_code): + # pylint: disable-next=no-member + self._test_case.assertDictContainsSubset( + data, + response.json(), # type: ignore[attr-defined] + ) + + return response + def retrieve( self, model: AnyModel, @@ -201,7 +249,7 @@ def retrieve( """ response: Response = self.get( - self._get_reverse_detail(model), + self.reverse("detail", model), status_code_assertion=status_code_assertion, **kwargs, ) @@ -232,14 +280,14 @@ def list( The HTTP response. """ - assert self.model_class.objects.difference( - self.model_class.objects.filter( + assert self._model_class.objects.difference( + self._model_class.objects.filter( pk__in=[model.pk for model in models] ) ).exists(), "List must exclude some models for a valid test." response: Response = self.get( - f"{reverse(f'{self.basename}-list')}?{urlencode(filters or {})}", + f"{self.reverse('list')}?{urlencode(filters or {})}", status_code_assertion=status_code_assertion, **kwargs, ) @@ -253,7 +301,7 @@ def list( def partial_update( self, model: AnyModel, - data: t.Dict[str, t.Any], + data: Data, status_code_assertion: StatusCodeAssertion = None, **kwargs, ): @@ -261,6 +309,7 @@ def partial_update( Args: model: The model to partially update. + data: The values for each field. status_code_assertion: The expected status code. Returns: @@ -268,7 +317,7 @@ def partial_update( """ response: Response = self.patch( - self._get_reverse_detail(model), + self.reverse("detail", model), data=data, status_code_assertion=status_code_assertion, **kwargs, @@ -301,7 +350,7 @@ def destroy( """ response: Response = self.delete( - self._get_reverse_detail(model), + self.reverse("detail", model), status_code_assertion=status_code_assertion, **kwargs, ) diff --git a/codeforlife/user/urls.py b/codeforlife/user/urls.py index e797e5a4..a203c079 100644 --- a/codeforlife/user/urls.py +++ b/codeforlife/user/urls.py @@ -1,13 +1,10 @@ -from django.urls import include, path from rest_framework.routers import DefaultRouter -from .views import ClassViewSet, UserViewSet, SchoolViewSet +from .views import ClassViewSet, SchoolViewSet, UserViewSet router = DefaultRouter() router.register("classes", ClassViewSet, basename="class") router.register("users", UserViewSet, basename="user") router.register("schools", SchoolViewSet, basename="school") -urlpatterns = [ - path("", include(router.urls)), -] +urlpatterns = router.urls