Skip to content

Commit

Permalink
fix: custom api request factory
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Feb 8, 2024
1 parent ac9d027 commit 9930bf1
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 31 deletions.
1 change: 1 addition & 0 deletions codeforlife/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Custom test cases.
"""

from .api_request_factory import APIRequestFactory
from .cron import CronTestCase
from .model_serializer import ModelSerializerTestCase
from .model_view_set import ModelViewSetTestCase
189 changes: 189 additions & 0 deletions codeforlife/tests/api_request_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
"""
© Ocado Group
Created on 08/02/2024 at 15:42:25(+00:00).
"""

import typing as t

from django.core.handlers.wsgi import WSGIRequest
from rest_framework.parsers import (
FileUploadParser,
FormParser,
JSONParser,
MultiPartParser,
)
from rest_framework.request import Request
from rest_framework.test import APIRequestFactory as _APIRequestFactory

from ..user.models import User


class APIRequestFactory(_APIRequestFactory):
"""Custom API request factory that returns DRF's Request object."""

# pylint: disable-next=too-many-arguments
def generic(
self,
method: str,
path: t.Optional[str] = None,
data: t.Optional[str] = None,
content_type: t.Optional[str] = None,
secure: bool = True,
user: t.Optional[User] = None,
**extra
):
wsgi_request = t.cast(
WSGIRequest,
super().generic(
method,
path or "/",
data or "",
content_type or "application/json",
secure,
**extra,
),
)

request = Request(
wsgi_request,
parsers=[
JSONParser(),
FormParser(),
MultiPartParser(),
FileUploadParser(),
],
)

if user:
request.user = user

return request

def get( # type: ignore[override]
self,
path: t.Optional[str] = None,
data: t.Any = None,
user: t.Optional[User] = None,
**extra
):
return super().get(
path or "/",
data,
user=user,
**extra,
)

# pylint: disable-next=too-many-arguments
def post( # type: ignore[override]
self,
path: t.Optional[str] = None,
data: t.Any = None,
# pylint: disable-next=redefined-builtin
format: t.Optional[str] = None,
content_type: t.Optional[str] = None,
user: t.Optional[User] = None,
**extra
):
if format is None and content_type is None:
format = "json"

return super().post(
path or "/",
data,
format,
content_type,
user=user,
**extra,
)

# pylint: disable-next=too-many-arguments
def put( # type: ignore[override]
self,
path: t.Optional[str] = None,
data: t.Any = None,
# pylint: disable-next=redefined-builtin
format: t.Optional[str] = None,
content_type: t.Optional[str] = None,
user: t.Optional[User] = None,
**extra
):
if format is None and content_type is None:
format = "json"

return super().put(
path or "/",
data,
format,
content_type,
user=user,
**extra,
)

# pylint: disable-next=too-many-arguments
def patch( # type: ignore[override]
self,
path: t.Optional[str] = None,
data: t.Any = None,
# pylint: disable-next=redefined-builtin
format: t.Optional[str] = None,
content_type: t.Optional[str] = None,
user: t.Optional[User] = None,
**extra
):
if format is None and content_type is None:
format = "json"

return super().patch(
path or "/",
data,
format,
content_type,
user=user,
**extra,
)

# pylint: disable-next=too-many-arguments
def delete( # type: ignore[override]
self,
path: t.Optional[str] = None,
data: t.Any = None,
# pylint: disable-next=redefined-builtin
format: t.Optional[str] = None,
content_type: t.Optional[str] = None,
user: t.Optional[User] = None,
**extra
):
if format is None and content_type is None:
format = "json"

return super().delete(
path or "/",
data,
format,
content_type,
user=user,
**extra,
)

# pylint: disable-next=too-many-arguments
def options( # type: ignore[override]
self,
path: t.Optional[str] = None,
data: t.Optional[t.Union[t.Dict[str, str], str]] = None,
# pylint: disable-next=redefined-builtin
format: t.Optional[str] = None,
content_type: t.Optional[str] = None,
user: t.Optional[User] = None,
**extra
):
if format is None and content_type is None:
format = "json"

return super().options(
path or "/",
data or {},
format,
content_type,
user=user,
**extra,
)
31 changes: 1 addition & 30 deletions codeforlife/tests/model_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
from django.forms.models import model_to_dict
from django.test import TestCase
from rest_framework.serializers import ValidationError
from rest_framework.test import APIRequestFactory

from ..serializers import ModelSerializer
from ..types import DataDict
from ..user.models import User
from .api_request_factory import APIRequestFactory

AnyModel = t.TypeVar("AnyModel", bound=Model)

Expand Down Expand Up @@ -79,34 +78,6 @@ def __exit__(self, *args, **kwargs):

return Wrapper(self.assertRaises(ValidationError, *args, **kwargs))

def init_request(
self,
method: str,
user: t.Optional[User] = None,
**kwargs,
):
"""Initialize a generic HTTP request.
Create an instance of DRF's Request object. Note this does not send the
HTTP request.
Args:
method: The HTTP method.
user: The user making the request.
Returns:
An instance of DRF's Request object.
"""

kwargs.setdefault("path", "/")
kwargs.setdefault("content_type", "application/json")

request = self.request_factory.generic(method.upper(), **kwargs)
if user:
request.user = user

return request

def assert_validate(
self,
attrs: t.Union[DataDict, t.List[DataDict]],
Expand Down
3 changes: 2 additions & 1 deletion codeforlife/tests/model_view_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from rest_framework import status
from rest_framework.response import Response
from rest_framework.serializers import DateTimeField
from rest_framework.test import APIClient, APIRequestFactory, APITestCase
from rest_framework.test import APIClient, APITestCase

from ..permissions import Permission
from ..serializers import ModelSerializer
Expand All @@ -33,6 +33,7 @@
User,
)
from ..views import ModelViewSet
from .api_request_factory import APIRequestFactory

AnyModel = t.TypeVar("AnyModel", bound=Model)

Expand Down

0 comments on commit 9930bf1

Please sign in to comment.