Skip to content

Commit

Permalink
fix: reusable type and simplify base serializer test case
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Feb 7, 2024
1 parent dbcea33 commit c8123f2
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 87 deletions.
24 changes: 9 additions & 15 deletions codeforlife/serializers/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from rest_framework.serializers import ModelSerializer as _ModelSerializer
from rest_framework.serializers import ValidationError as _ValidationError

from ..types import DataDict
from .base import BaseSerializer

AnyModel = t.TypeVar("AnyModel", bound=Model)
Expand All @@ -24,7 +25,7 @@ class ModelSerializer(
):
"""Base model serializer for all model serializers."""

instance: AnyModel
instance: t.Optional[AnyModel]

@property
def view(self):
Expand All @@ -35,18 +36,14 @@ def view(self):
return t.cast(ModelViewSet[AnyModel], super().view)

# pylint: disable-next=useless-parent-delegation
def update(
self,
instance: AnyModel,
validated_data: t.Dict[str, t.Any],
) -> AnyModel:
def update(self, instance: AnyModel, validated_data: DataDict) -> AnyModel:
return super().update(instance, validated_data)

# pylint: disable-next=useless-parent-delegation
def create(self, validated_data: t.Dict[str, t.Any]) -> AnyModel:
def create(self, validated_data: DataDict) -> AnyModel:
return super().create(validated_data)

def validate(self, attrs: t.Dict[str, t.Any]):
def validate(self, attrs: DataDict):
return attrs


Expand All @@ -72,7 +69,7 @@ class Meta:
list_serializer_class = UserListSerializer
"""

instance: t.List[AnyModel]
instance: t.Optional[t.List[AnyModel]]
batch_size: t.Optional[int] = None

@property
Expand All @@ -96,10 +93,7 @@ def get_model_class(cls) -> t.Type[AnyModel]:
0
]

def create(
self,
validated_data: t.List[t.Dict[str, t.Any]],
) -> t.List[AnyModel]:
def create(self, validated_data: t.List[DataDict]) -> t.List[AnyModel]:
"""Bulk create many instances of a model.
https://www.django-rest-framework.org/api-guide/serializers/#customizing-multiple-create
Expand All @@ -120,7 +114,7 @@ def create(
def update(
self,
instance: t.List[AnyModel],
validated_data: t.List[t.Dict[str, t.Any]],
validated_data: t.List[DataDict],
) -> t.List[AnyModel]:
"""Bulk update many instances of a model.
Expand Down Expand Up @@ -148,7 +142,7 @@ def update(

return instance

def validate(self, attrs: t.List[t.Dict[str, t.Any]]):
def validate(self, attrs: t.List[DataDict]):
# If performing a bulk create.
if self.instance is None:
if len(attrs) == 0:
Expand Down
100 changes: 37 additions & 63 deletions codeforlife/tests/model_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from rest_framework.test import APIRequestFactory

from ..serializers import ModelSerializer
from ..types import JsonDict, KwArgs
from ..types import DataDict
from ..user.models import User

AnyModel = t.TypeVar("AnyModel", bound=Model)
Expand All @@ -28,8 +28,6 @@ class ModelSerializerTestCase(TestCase, t.Generic[AnyModel]):

request_factory = APIRequestFactory()

Data = t.Dict[str, t.Any]

@classmethod
def setUpClass(cls):
attr_name = "model_serializer_class"
Expand Down Expand Up @@ -81,73 +79,59 @@ def __exit__(self, *args, **kwargs):

return Wrapper(self.assertRaises(ValidationError, *args, **kwargs))

# pylint: disable-next=too-many-arguments
def _assert_validate(
def init_request(
self,
value,
error_code: str,
user: t.Optional[User],
request_kwargs: t.Optional[KwArgs],
get_validate: t.Callable[
[ModelSerializer[AnyModel]], t.Callable[[t.Any], t.Any]
],
method: str,
user: t.Optional[User] = None,
**kwargs,
):
kwargs.setdefault("context", {})
context: t.Dict[str, t.Any] = kwargs["context"]
"""Initialize a generic HTTP request.
if "request" not in context:
request_kwargs = request_kwargs or {}
request_kwargs.setdefault("method", "POST")
request_kwargs.setdefault("path", "/")
request_kwargs.setdefault("data", "")
request_kwargs.setdefault("content_type", "application/json")
Create an instance of DRF's Request object. Note this does not send the
HTTP request.
request = self.request_factory.generic(**request_kwargs)
if user is not None:
request.user = user
Args:
method: The HTTP method.
user: The user making the request.
context["request"] = request
Returns:
An instance of DRF's Request object.
"""

serializer = self.model_serializer_class(**kwargs)
kwargs.setdefault("path", "/")
kwargs.setdefault("content_type", "application/json")

with self.assert_raises_validation_error(error_code):
get_validate(serializer)(value)
request = self.request_factory.generic(method, **kwargs)
if user:
request.user = user

return request

def assert_validate(
self,
attrs: t.Union[JsonDict, t.List[JsonDict]],
attrs: t.Union[DataDict, t.List[DataDict]],
error_code: str,
user: t.Optional[User] = None,
request_kwargs: t.Optional[KwArgs] = None,
*args,
**kwargs,
):
"""Asserts that calling validate() raises the expected error code.
Args:
attrs: The attributes to pass to validate().
error_code: The expected error code to be raised.
user: The requesting user.
request_kwargs: The kwargs used to initialize the request.
"""

self._assert_validate(
attrs,
error_code,
user,
request_kwargs,
get_validate=lambda serializer: serializer.validate,
**kwargs,
)
serializer = self.model_serializer_class(*args, **kwargs)
with self.assert_raises_validation_error(error_code):
serializer.validate(attrs) # type: ignore[arg-type]

# pylint: disable-next=too-many-arguments
def assert_validate_field(
self,
name: str,
value,
error_code: str,
user: t.Optional[User] = None,
request_kwargs: t.Optional[KwArgs] = None,
*args,
**kwargs,
):
"""Asserts that calling validate_field() raises the expected error code.
Expand All @@ -156,25 +140,15 @@ def assert_validate_field(
name: The name of the field.
value: The value to pass to validate_field().
error_code: The expected error code to be raised.
user: The requesting user.
request_kwargs: The kwargs used to initialize the request.
"""

def get_validate(serializer: ModelSerializer[AnyModel]):
validate_field = getattr(serializer, f"validate_{name}")
assert callable(validate_field)
return validate_field

self._assert_validate(
value,
error_code,
user,
request_kwargs,
get_validate,
**kwargs,
)

def _assert_data_is_subset_of_model(self, data: Data, model):
serializer = self.model_serializer_class(*args, **kwargs)
validate_field = getattr(serializer, f"validate_{name}")
assert callable(validate_field)
with self.assert_raises_validation_error(error_code):
validate_field(value)

def _assert_data_is_subset_of_model(self, data: DataDict, model):
assert isinstance(model, Model)

for field, value in data.copy().items():
Expand All @@ -190,9 +164,9 @@ def _assert_data_is_subset_of_model(self, data: Data, model):

def assert_create(
self,
validated_data: Data,
validated_data: DataDict,
*args,
new_data: t.Optional[Data] = None,
new_data: t.Optional[DataDict] = None,
**kwargs,
):
"""Assert that the data used to create the model is a subset of the
Expand All @@ -211,9 +185,9 @@ def assert_create(
def assert_update(
self,
instance: AnyModel,
validated_data: Data,
validated_data: DataDict,
*args,
new_data: t.Optional[Data] = None,
new_data: t.Optional[DataDict] = None,
**kwargs,
):
"""Assert that the data used to update the model is a subset of the
Expand Down
15 changes: 7 additions & 8 deletions codeforlife/tests/model_view_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from ..permissions import Permission
from ..serializers import ModelSerializer
from ..types import JsonDict
from ..types import DataDict, JsonDict
from ..user.models import (
AuthFactor,
NonSchoolTeacherUser,
Expand Down Expand Up @@ -77,7 +77,6 @@ def _lookup_field(self):
else lookup_field
)

Data = t.Dict[str, t.Any]
StatusCodeAssertion = t.Optional[t.Union[int, t.Callable[[int], bool]]]
ListFilters = t.Optional[t.Dict[str, str]]

Expand All @@ -101,7 +100,7 @@ def _assert_response_json_bulk(
self,
response: Response,
make_assertions: t.Callable[[t.List[JsonDict]], None],
data: t.List[Data],
data: t.List[DataDict],
):
def _make_assertions():
response_json = response.json() # type: ignore[attr-defined]
Expand Down Expand Up @@ -137,7 +136,7 @@ def _assert_serialized_model_equals_json_model(

datetime_to_representation = DateTimeField().to_representation

def datetime_values_to_representation(data: ModelViewSetClient.Data):
def datetime_values_to_representation(data: DataDict):
for key, value in data.copy().items():
if isinstance(value, dict):
datetime_values_to_representation(value)
Expand Down Expand Up @@ -215,7 +214,7 @@ def _assert_create(self, json_model: JsonDict):

def create(
self,
data: Data,
data: DataDict,
status_code_assertion: StatusCodeAssertion = status.HTTP_201_CREATED,
make_assertions: bool = True,
**kwargs,
Expand Down Expand Up @@ -249,7 +248,7 @@ def create(

def bulk_create(
self,
data: t.List[Data],
data: t.List[DataDict],
status_code_assertion: StatusCodeAssertion = status.HTTP_201_CREATED,
make_assertions: bool = True,
**kwargs,
Expand Down Expand Up @@ -385,7 +384,7 @@ def _assert_partial_update(self, model: AnyModel, json_model: JsonDict):
def partial_update(
self,
model: AnyModel,
data: Data,
data: DataDict,
status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK,
make_assertions: bool = True,
**kwargs,
Expand Down Expand Up @@ -426,7 +425,7 @@ def partial_update(
def bulk_partial_update(
self,
models: t.List[AnyModel],
data: t.List[Data],
data: t.List[DataDict],
status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK,
make_assertions: bool = True,
**kwargs,
Expand Down
2 changes: 2 additions & 0 deletions codeforlife/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@
JsonList = t.List["JsonValue"]
JsonDict = t.Dict[str, "JsonValue"]
JsonValue = t.Union[int, str, bool, JsonList, JsonDict]

DataDict = t.Dict[str, t.Any]
3 changes: 2 additions & 1 deletion codeforlife/views/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from ..permissions import Permission
from ..serializers import ModelListSerializer, ModelSerializer
from ..types import DataDict
from .api import APIView

AnyModel = t.TypeVar("AnyModel", bound=Model)
Expand Down Expand Up @@ -142,7 +143,7 @@ def bulk_partial_update(self, request: Request):
else self.lookup_field
)

data = t.cast(t.List[t.Dict[str, t.Any]], request.data)
data = t.cast(t.List[DataDict], request.data)
data.sort(key=lambda model: model[lookup_field])

queryset = model_class.objects.filter( # type: ignore[attr-defined]
Expand Down

0 comments on commit c8123f2

Please sign in to comment.