Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Jan 22, 2024
1 parent 7238bbc commit efe62bc
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 42 deletions.
23 changes: 14 additions & 9 deletions codeforlife/models/signals/pre_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,33 +44,38 @@ def check_previous_values(
instance: AnyModel,
predicates: t.Dict[str, t.Callable[[t.Any, t.Any], bool]],
):
"""Check if the previous values are as expected.
"""Check if the previous values are as expected. If the model has not been
created yet, the previous values are None.
Args:
instance: The current instance.
predicates: A predicate for each field. It accepts the arguments
(previous_value, value) and returns True if the values are as expected.
Raises:
ValueError: If arg 'instance' has not been created yet.
Returns:
If all the previous values are as expected.
"""

if not was_created(instance):
raise ValueError("Arg 'instance' has not been created yet.")
if was_created(instance):
previous_instance = instance.__class__.objects.get(pk=instance.pk)

def get_previous_value(field: str):
return getattr(previous_instance, field)

previous_instance = instance.__class__.objects.get(pk=instance.pk)
else:
# pylint: disable-next=unused-argument
def get_previous_value(field: str):
return None

return all(
predicate(getattr(previous_instance, field), getattr(instance, field))
predicate(get_previous_value(field), getattr(instance, field))
for field, predicate in predicates.items()
)


def previous_values_are_unequal(instance: AnyModel, fields: t.Set[str]):
"""Check if all the previous values are not equal to the current values.
"""Check if all the previous values are not equal to the current values. If
the model has not been created yet, the previous values are None.
Args:
instance: The current instance.
Expand Down
105 changes: 77 additions & 28 deletions codeforlife/tests/model_view_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from django.db.models import Model
from django.db.models.query import QuerySet
from django.urls import reverse
from django.urls import reverse as _reverse
from django.utils import timezone
from django.utils.http import urlencode
from pyotp import TOTP
Expand All @@ -36,30 +36,26 @@ class ModelViewSetClient(
responses.
"""

_test_case: "ModelViewSetTestCase[AnyModelViewSet, AnyModelSerializer, AnyModel]"

@property
def basename(self):
"""Shortcut to get basename."""
Data = t.Dict[str, t.Any]

return self._test_case.basename
_test_case: "ModelViewSetTestCase[AnyModelViewSet, AnyModelSerializer, AnyModel]"

@property
def model_class(self):
def _model_class(self):
"""Shortcut to get model class."""

# pylint: disable-next=no-member
return self._test_case.get_model_class()

@property
def model_serializer_class(self):
def _model_serializer_class(self):
"""Shortcut to get model serializer class."""

# pylint: disable-next=no-member
return self._test_case.get_model_serializer_class()

@property
def model_view_set_class(self):
def _model_view_set_class(self):
"""Shortcut to get model view set class."""

# pylint: disable-next=no-member
Expand All @@ -84,7 +80,7 @@ def status_code_is_ok(status_code: int):

def assert_data_equals_model(
self,
data: t.Dict[str, t.Any],
data: Data,
model: AnyModel,
contains_subset: bool = False,
):
Expand Down Expand Up @@ -112,7 +108,7 @@ def parse_data(data):
return data.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
return data

actual_data = parse_data(self.model_serializer_class(model).data)
actual_data = parse_data(self._model_serializer_class(model).data)

if contains_subset:
# pylint: disable-next=no-member
Expand All @@ -129,16 +125,37 @@ def parse_data(data):
"Data does not equal serialized model.",
)

def _get_reverse_detail(self, model: AnyModel, **kwargs):
return reverse(
def reverse(
self,
action: str,
model: t.Optional[AnyModel] = None,
**kwargs,
):
"""Get the reverse URL for the model view set's action.
Args:
action: The name of the action.
model: The model to look up.
Returns:
The reversed URL.
"""

reverse_kwargs = kwargs.pop("kwargs", {})
if model is not None:
reverse_kwargs[self._model_view_set_class.lookup_field] = getattr(
model,
self._model_view_set_class.lookup_field,
)

return _reverse(
viewname=kwargs.pop(
"viewname",
# pylint: disable-next=no-member
f"{self._test_case.basename}-{action}",
),
kwargs=reverse_kwargs,
**kwargs,
viewname=kwargs.get("viewname", f"{self.basename}-detail"),
kwargs={
**kwargs.get("kwargs", {}),
self.model_view_set_class.lookup_field: getattr(
model, self.model_view_set_class.lookup_field
),
},
)

# pylint: disable-next=too-many-arguments
Expand Down Expand Up @@ -184,6 +201,37 @@ def generic(

return response

def create(
self,
data: Data,
status_code_assertion: StatusCodeAssertion = None,
**kwargs,
):
"""Create a model.
Args:
data: The values for each field.
status_code_assertion: The expected status code.
Returns:
The HTTP response.
"""

response: Response = self.post(
self.reverse("list"),
status_code_assertion=status_code_assertion,
**kwargs,
)

if self.status_code_is_ok(response.status_code):
# pylint: disable-next=no-member
self._test_case.assertDictContainsSubset(
data,
response.json(), # type: ignore[attr-defined]
)

return response

def retrieve(
self,
model: AnyModel,
Expand All @@ -201,7 +249,7 @@ def retrieve(
"""

response: Response = self.get(
self._get_reverse_detail(model),
self.reverse("detail", model),
status_code_assertion=status_code_assertion,
**kwargs,
)
Expand Down Expand Up @@ -232,14 +280,14 @@ def list(
The HTTP response.
"""

assert self.model_class.objects.difference(
self.model_class.objects.filter(
assert self._model_class.objects.difference(
self._model_class.objects.filter(
pk__in=[model.pk for model in models]
)
).exists(), "List must exclude some models for a valid test."

response: Response = self.get(
f"{reverse(f'{self.basename}-list')}?{urlencode(filters or {})}",
f"{self.reverse('list')}?{urlencode(filters or {})}",
status_code_assertion=status_code_assertion,
**kwargs,
)
Expand All @@ -253,22 +301,23 @@ def list(
def partial_update(
self,
model: AnyModel,
data: t.Dict[str, t.Any],
data: Data,
status_code_assertion: StatusCodeAssertion = None,
**kwargs,
):
"""Partially update a model.
Args:
model: The model to partially update.
data: The values for each field.
status_code_assertion: The expected status code.
Returns:
The HTTP response.
"""

response: Response = self.patch(
self._get_reverse_detail(model),
self.reverse("detail", model),
data=data,
status_code_assertion=status_code_assertion,
**kwargs,
Expand Down Expand Up @@ -301,7 +350,7 @@ def destroy(
"""

response: Response = self.delete(
self._get_reverse_detail(model),
self.reverse("detail", model),
status_code_assertion=status_code_assertion,
**kwargs,
)
Expand Down
7 changes: 2 additions & 5 deletions codeforlife/user/urls.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from django.urls import include, path
from rest_framework.routers import DefaultRouter

from .views import ClassViewSet, UserViewSet, SchoolViewSet
from .views import ClassViewSet, SchoolViewSet, UserViewSet

router = DefaultRouter()
router.register("classes", ClassViewSet, basename="class")
router.register("users", UserViewSet, basename="user")
router.register("schools", SchoolViewSet, basename="school")

urlpatterns = [
path("", include(router.urls)),
]
urlpatterns = router.urls

0 comments on commit efe62bc

Please sign in to comment.