From 627c1e029a50e1a283763d4bc445a0576342acf1 Mon Sep 17 00:00:00 2001 From: Stefan Kairinos Date: Fri, 2 Feb 2024 18:11:51 +0000 Subject: [PATCH] fix: assertion helpers (#69) * fix: permission checking * fix: assert get query set * add TODOs * fix: permission operators --- codeforlife/permissions/__init__.py | 1 + codeforlife/permissions/allow_none.py | 3 ++ .../is_cron_request_from_google.py | 3 ++ codeforlife/permissions/operators.py | 50 +++++++++++++++++++ codeforlife/tests/model_view_set.py | 36 +++++++++++++ codeforlife/user/permissions/in_class.py | 6 +++ codeforlife/user/permissions/in_school.py | 6 +++ .../user/permissions/is_independent.py | 3 ++ codeforlife/user/permissions/is_student.py | 6 +++ codeforlife/user/permissions/is_teacher.py | 7 +++ codeforlife/user/tests/views/test_klass.py | 1 + codeforlife/user/tests/views/test_school.py | 2 + codeforlife/user/tests/views/test_user.py | 2 + codeforlife/views/model.py | 4 ++ 14 files changed, 130 insertions(+) create mode 100644 codeforlife/permissions/operators.py diff --git a/codeforlife/permissions/__init__.py b/codeforlife/permissions/__init__.py index ddb7bf5b..d3bc6f9a 100644 --- a/codeforlife/permissions/__init__.py +++ b/codeforlife/permissions/__init__.py @@ -7,3 +7,4 @@ from .allow_none import AllowNone from .is_cron_request_from_google import IsCronRequestFromGoogle +from .operators import AND, NOT, OR, Permission diff --git a/codeforlife/permissions/allow_none.py b/codeforlife/permissions/allow_none.py index 6bc21734..ef2cbb83 100644 --- a/codeforlife/permissions/allow_none.py +++ b/codeforlife/permissions/allow_none.py @@ -14,5 +14,8 @@ class AllowNone(BasePermission): https://www.django-rest-framework.org/api-guide/permissions/#allowany """ + def __eq__(self, other): + return isinstance(other, self.__class__) + def has_permission(self, request, view): return False diff --git a/codeforlife/permissions/is_cron_request_from_google.py b/codeforlife/permissions/is_cron_request_from_google.py index b8f49f7c..c6cdf254 100644 --- a/codeforlife/permissions/is_cron_request_from_google.py +++ b/codeforlife/permissions/is_cron_request_from_google.py @@ -14,6 +14,9 @@ class IsCronRequestFromGoogle(BasePermission): https://cloud.google.com/appengine/docs/flexible/scheduling-jobs-with-cron-yaml#securing_urls_for_cron """ + def __eq__(self, other): + return isinstance(other, self.__class__) + def has_permission(self, request, view): return ( settings.DEBUG diff --git a/codeforlife/permissions/operators.py b/codeforlife/permissions/operators.py new file mode 100644 index 00000000..81176d06 --- /dev/null +++ b/codeforlife/permissions/operators.py @@ -0,0 +1,50 @@ +""" +© Ocado Group +Created on 02/02/2024 at 17:52:37(+00:00). + +Extends the permission operands provided by Django REST framework. +""" + +import typing as t + +from rest_framework.permissions import AND as _AND +from rest_framework.permissions import NOT as _NOT +from rest_framework.permissions import OR as _OR +from rest_framework.permissions import BasePermission + + +# pylint: disable-next=missing-class-docstring +class AND(_AND): + op1: BasePermission + op2: BasePermission + + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self.op1 == other.op1 + and self.op2 == other.op2 + ) + + +# pylint: disable-next=missing-class-docstring +class NOT(_NOT): + op1: BasePermission + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.op1 == other.op1 + + +# pylint: disable-next=missing-class-docstring +class OR(_OR): + op1: BasePermission + op2: BasePermission + + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self.op1 == other.op1 + and self.op2 == other.op2 + ) + + +Permission = t.Union[BasePermission, AND, NOT, OR] diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index eee9dffa..cb9290c2 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -20,6 +20,7 @@ from rest_framework.response import Response from rest_framework.test import APIClient, APIRequestFactory, APITestCase +from ..permissions import Permission from ..serializers import ModelSerializer from ..user.models import AuthFactor, User from ..views import ModelViewSet @@ -689,6 +690,41 @@ def setUpClass(cls): return super().setUpClass() + def assert_get_permissions( + self, + permissions: t.List[Permission], + *args, + **kwargs, + ): + """Assert that the expected permissions are returned. + + Args: + permissions: The expected permissions. + """ + + model_view_set = self.model_view_set_class(*args, **kwargs) + actual_permissions = model_view_set.get_permissions() + self.assertListEqual(permissions, actual_permissions) + + def assert_get_queryset( + self, + values: t.Collection[AnyModel], + *args, + ordered: bool = True, + **kwargs, + ): + """Assert that the expected queryset is returned. + + Args: + values: The values we expect the queryset to contain. + ordered: Whether the queryset provides an implicit ordering. + """ + + model_view_set = self.model_view_set_class(*args, **kwargs) + queryset = model_view_set.get_queryset() + # pylint: disable-next=no-member + self.assertQuerySetEqual(queryset, values, ordered=ordered) + def get_other_user( self, user: User, diff --git a/codeforlife/user/permissions/in_class.py b/codeforlife/user/permissions/in_class.py index c030e6d0..2c5ceac6 100644 --- a/codeforlife/user/permissions/in_class.py +++ b/codeforlife/user/permissions/in_class.py @@ -26,6 +26,12 @@ def __init__(self, class_id: t.Optional[str] = None): super().__init__() self.class_id = class_id + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self.class_id == other.class_id + ) + def has_permission(self, request: Request, view: APIView): user = request.user if super().has_permission(request, view) and isinstance(user, User): diff --git a/codeforlife/user/permissions/in_school.py b/codeforlife/user/permissions/in_school.py index 1b866d21..2fde40c2 100644 --- a/codeforlife/user/permissions/in_school.py +++ b/codeforlife/user/permissions/in_school.py @@ -26,6 +26,12 @@ def __init__(self, school_id: t.Optional[int] = None): super().__init__() self.school_id = school_id + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self.school_id == other.school_id + ) + def has_permission(self, request: Request, view: APIView): def in_school(school_id: int): return self.school_id is None or self.school_id == school_id diff --git a/codeforlife/user/permissions/is_independent.py b/codeforlife/user/permissions/is_independent.py index 5d0f5a8e..0ad94565 100644 --- a/codeforlife/user/permissions/is_independent.py +++ b/codeforlife/user/permissions/is_independent.py @@ -13,6 +13,9 @@ class IsIndependent(IsAuthenticated): """Request's user must be independent.""" + def __eq__(self, other): + return isinstance(other, self.__class__) + def has_permission(self, request: Request, view: APIView): user = request.user return ( diff --git a/codeforlife/user/permissions/is_student.py b/codeforlife/user/permissions/is_student.py index a4a43e9e..2f0a1f9b 100644 --- a/codeforlife/user/permissions/is_student.py +++ b/codeforlife/user/permissions/is_student.py @@ -26,6 +26,12 @@ def __init__(self, student_id: t.Optional[int] = None): super().__init__() self.student_id = student_id + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self.student_id == other.student_id + ) + def has_permission(self, request: Request, view: APIView): user = request.user return ( diff --git a/codeforlife/user/permissions/is_teacher.py b/codeforlife/user/permissions/is_teacher.py index 255d9c6d..c7a8ec65 100644 --- a/codeforlife/user/permissions/is_teacher.py +++ b/codeforlife/user/permissions/is_teacher.py @@ -34,6 +34,13 @@ def __init__( self.teacher_id = teacher_id self.is_admin = is_admin + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self.teacher_id == other.teacher_id + and self.is_admin == other.is_admin + ) + def has_permission(self, request: Request, view: APIView): user = request.user return ( diff --git a/codeforlife/user/tests/views/test_klass.py b/codeforlife/user/tests/views/test_klass.py index eda41b04..86452206 100644 --- a/codeforlife/user/tests/views/test_klass.py +++ b/codeforlife/user/tests/views/test_klass.py @@ -58,3 +58,4 @@ def test_retrieve__student__same_school__in_class(self): self.client.retrieve(user.student.class_field) # TODO: other retrieve and list tests + # TODO: replace above tests with get_queryset() tests diff --git a/codeforlife/user/tests/views/test_school.py b/codeforlife/user/tests/views/test_school.py index d5b80db4..222c3726 100644 --- a/codeforlife/user/tests/views/test_school.py +++ b/codeforlife/user/tests/views/test_school.py @@ -193,3 +193,5 @@ def test_list__student(self): user = self._login_student() self.client.list([user.student.class_field.teacher.school]) + + # TODO: replace above tests with get_queryset() tests diff --git a/codeforlife/user/tests/views/test_user.py b/codeforlife/user/tests/views/test_user.py index 793dfc63..cacb475a 100644 --- a/codeforlife/user/tests/views/test_user.py +++ b/codeforlife/user/tests/views/test_user.py @@ -519,3 +519,5 @@ def test_all__only_http_get(self): assert [name.lower() for name in UserViewSet.http_method_names] == [ "get" ] + + # TODO: replace above tests with get_queryset() tests diff --git a/codeforlife/views/model.py b/codeforlife/views/model.py index 80560a0a..c0173844 100644 --- a/codeforlife/views/model.py +++ b/codeforlife/views/model.py @@ -14,6 +14,7 @@ from rest_framework.serializers import ListSerializer from rest_framework.viewsets import ModelViewSet as DrfModelViewSet +from ..permissions import Permission from ..serializers import ModelListSerializer, ModelSerializer AnyModel = t.TypeVar("AnyModel", bound=Model) @@ -49,6 +50,9 @@ def get_model_class(cls) -> t.Type[AnyModel]: 0 ] + def get_permissions(self): + return t.cast(t.List[Permission], super().get_permissions()) + def get_serializer(self, *args, **kwargs): serializer = super().get_serializer(*args, **kwargs)