Skip to content

Commit

Permalink
fix: custom base serializer
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Jan 29, 2024
1 parent ebe5feb commit cc5d5d6
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
1 change: 1 addition & 0 deletions codeforlife/serializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
Created on 20/01/2024 at 11:19:12(+00:00).
"""

from .base import *
from .model import *
41 changes: 41 additions & 0 deletions codeforlife/serializers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
© Ocado Group
Created on 29/01/2024 at 14:27:09(+00:00).
Base serializer.
"""

import typing as t

from django.contrib.auth.models import AnonymousUser
from rest_framework.serializers import BaseSerializer as _BaseSerializer

from ..request import Request
from ..user.models import User


# pylint: disable-next=abstract-method
class BaseSerializer(_BaseSerializer):
"""Base serializer to be inherited by all other serializers."""

@property
def request(self):
"""The HTTP request that triggered the view."""

return t.cast(Request, self.context["request"])

@property
def request_user(self):
"""
The user that made the request. Assumes the user has authenticated.
"""

return t.cast(User, self.request.user)

@property
def request_anon_user(self):
"""
The user that made the request. Assumes the user has not authenticated.
"""

return t.cast(AnonymousUser, self.request.user)
9 changes: 8 additions & 1 deletion codeforlife/serializers/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,16 @@
from rest_framework.serializers import ModelSerializer as _ModelSerializer
from rest_framework.serializers import ValidationError as _ValidationError

from .base import BaseSerializer

AnyModel = t.TypeVar("AnyModel", bound=Model)


class ModelSerializer(_ModelSerializer[AnyModel], t.Generic[AnyModel]):
class ModelSerializer(
BaseSerializer,
_ModelSerializer[AnyModel],
t.Generic[AnyModel],
):
"""Base model serializer for all model serializers."""

# pylint: disable-next=useless-parent-delegation
Expand All @@ -31,6 +37,7 @@ def validate(self, attrs: t.Dict[str, t.Any]):


class ModelListSerializer(
BaseSerializer,
t.Generic[AnyModel],
_ListSerializer[t.List[AnyModel]],
):
Expand Down

0 comments on commit cc5d5d6

Please sign in to comment.