diff --git a/codeforlife/serializers/__init__.py b/codeforlife/serializers/__init__.py index 2c63875b..30c7968d 100644 --- a/codeforlife/serializers/__init__.py +++ b/codeforlife/serializers/__init__.py @@ -3,4 +3,5 @@ Created on 20/01/2024 at 11:19:12(+00:00). """ +from .base import * from .model import * diff --git a/codeforlife/serializers/base.py b/codeforlife/serializers/base.py new file mode 100644 index 00000000..aff2db43 --- /dev/null +++ b/codeforlife/serializers/base.py @@ -0,0 +1,41 @@ +""" +© Ocado Group +Created on 29/01/2024 at 14:27:09(+00:00). + +Base serializer. +""" + +import typing as t + +from django.contrib.auth.models import AnonymousUser +from rest_framework.serializers import BaseSerializer as _BaseSerializer + +from ..request import Request +from ..user.models import User + + +# pylint: disable-next=abstract-method +class BaseSerializer(_BaseSerializer): + """Base serializer to be inherited by all other serializers.""" + + @property + def request(self): + """The HTTP request that triggered the view.""" + + return t.cast(Request, self.context["request"]) + + @property + def request_user(self): + """ + The user that made the request. Assumes the user has authenticated. + """ + + return t.cast(User, self.request.user) + + @property + def request_anon_user(self): + """ + The user that made the request. Assumes the user has not authenticated. + """ + + return t.cast(AnonymousUser, self.request.user) diff --git a/codeforlife/serializers/model.py b/codeforlife/serializers/model.py index 4aec39d7..49283531 100644 --- a/codeforlife/serializers/model.py +++ b/codeforlife/serializers/model.py @@ -12,10 +12,16 @@ from rest_framework.serializers import ModelSerializer as _ModelSerializer from rest_framework.serializers import ValidationError as _ValidationError +from .base import BaseSerializer + AnyModel = t.TypeVar("AnyModel", bound=Model) -class ModelSerializer(_ModelSerializer[AnyModel], t.Generic[AnyModel]): +class ModelSerializer( + BaseSerializer, + _ModelSerializer[AnyModel], + t.Generic[AnyModel], +): """Base model serializer for all model serializers.""" # pylint: disable-next=useless-parent-delegation @@ -31,6 +37,7 @@ def validate(self, attrs: t.Dict[str, t.Any]): class ModelListSerializer( + BaseSerializer, t.Generic[AnyModel], _ListSerializer[t.List[AnyModel]], ):