Skip to content

Commit

Permalink
fix: assertion helpers (#69)
Browse files Browse the repository at this point in the history
* fix: permission checking

* fix: assert get query set

* add TODOs

* fix: permission operators
  • Loading branch information
SKairinos authored Feb 2, 2024
1 parent ab29e5a commit 627c1e0
Show file tree
Hide file tree
Showing 14 changed files with 130 additions and 0 deletions.
1 change: 1 addition & 0 deletions codeforlife/permissions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@

from .allow_none import AllowNone
from .is_cron_request_from_google import IsCronRequestFromGoogle
from .operators import AND, NOT, OR, Permission
3 changes: 3 additions & 0 deletions codeforlife/permissions/allow_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,8 @@ class AllowNone(BasePermission):
https://www.django-rest-framework.org/api-guide/permissions/#allowany
"""

def __eq__(self, other):
return isinstance(other, self.__class__)

def has_permission(self, request, view):
return False
3 changes: 3 additions & 0 deletions codeforlife/permissions/is_cron_request_from_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ class IsCronRequestFromGoogle(BasePermission):
https://cloud.google.com/appengine/docs/flexible/scheduling-jobs-with-cron-yaml#securing_urls_for_cron
"""

def __eq__(self, other):
return isinstance(other, self.__class__)

def has_permission(self, request, view):
return (
settings.DEBUG
Expand Down
50 changes: 50 additions & 0 deletions codeforlife/permissions/operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
© Ocado Group
Created on 02/02/2024 at 17:52:37(+00:00).
Extends the permission operands provided by Django REST framework.
"""

import typing as t

from rest_framework.permissions import AND as _AND
from rest_framework.permissions import NOT as _NOT
from rest_framework.permissions import OR as _OR
from rest_framework.permissions import BasePermission


# pylint: disable-next=missing-class-docstring
class AND(_AND):
op1: BasePermission
op2: BasePermission

def __eq__(self, other):
return (
isinstance(other, self.__class__)
and self.op1 == other.op1
and self.op2 == other.op2
)


# pylint: disable-next=missing-class-docstring
class NOT(_NOT):
op1: BasePermission

def __eq__(self, other):
return isinstance(other, self.__class__) and self.op1 == other.op1


# pylint: disable-next=missing-class-docstring
class OR(_OR):
op1: BasePermission
op2: BasePermission

def __eq__(self, other):
return (
isinstance(other, self.__class__)
and self.op1 == other.op1
and self.op2 == other.op2
)


Permission = t.Union[BasePermission, AND, NOT, OR]
36 changes: 36 additions & 0 deletions codeforlife/tests/model_view_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from rest_framework.response import Response
from rest_framework.test import APIClient, APIRequestFactory, APITestCase

from ..permissions import Permission
from ..serializers import ModelSerializer
from ..user.models import AuthFactor, User
from ..views import ModelViewSet
Expand Down Expand Up @@ -689,6 +690,41 @@ def setUpClass(cls):

return super().setUpClass()

def assert_get_permissions(
self,
permissions: t.List[Permission],
*args,
**kwargs,
):
"""Assert that the expected permissions are returned.
Args:
permissions: The expected permissions.
"""

model_view_set = self.model_view_set_class(*args, **kwargs)
actual_permissions = model_view_set.get_permissions()
self.assertListEqual(permissions, actual_permissions)

def assert_get_queryset(
self,
values: t.Collection[AnyModel],
*args,
ordered: bool = True,
**kwargs,
):
"""Assert that the expected queryset is returned.
Args:
values: The values we expect the queryset to contain.
ordered: Whether the queryset provides an implicit ordering.
"""

model_view_set = self.model_view_set_class(*args, **kwargs)
queryset = model_view_set.get_queryset()
# pylint: disable-next=no-member
self.assertQuerySetEqual(queryset, values, ordered=ordered)

def get_other_user(
self,
user: User,
Expand Down
6 changes: 6 additions & 0 deletions codeforlife/user/permissions/in_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def __init__(self, class_id: t.Optional[str] = None):
super().__init__()
self.class_id = class_id

def __eq__(self, other):
return (
isinstance(other, self.__class__)
and self.class_id == other.class_id
)

def has_permission(self, request: Request, view: APIView):
user = request.user
if super().has_permission(request, view) and isinstance(user, User):
Expand Down
6 changes: 6 additions & 0 deletions codeforlife/user/permissions/in_school.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def __init__(self, school_id: t.Optional[int] = None):
super().__init__()
self.school_id = school_id

def __eq__(self, other):
return (
isinstance(other, self.__class__)
and self.school_id == other.school_id
)

def has_permission(self, request: Request, view: APIView):
def in_school(school_id: int):
return self.school_id is None or self.school_id == school_id
Expand Down
3 changes: 3 additions & 0 deletions codeforlife/user/permissions/is_independent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
class IsIndependent(IsAuthenticated):
"""Request's user must be independent."""

def __eq__(self, other):
return isinstance(other, self.__class__)

def has_permission(self, request: Request, view: APIView):
user = request.user
return (
Expand Down
6 changes: 6 additions & 0 deletions codeforlife/user/permissions/is_student.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def __init__(self, student_id: t.Optional[int] = None):
super().__init__()
self.student_id = student_id

def __eq__(self, other):
return (
isinstance(other, self.__class__)
and self.student_id == other.student_id
)

def has_permission(self, request: Request, view: APIView):
user = request.user
return (
Expand Down
7 changes: 7 additions & 0 deletions codeforlife/user/permissions/is_teacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def __init__(
self.teacher_id = teacher_id
self.is_admin = is_admin

def __eq__(self, other):
return (
isinstance(other, self.__class__)
and self.teacher_id == other.teacher_id
and self.is_admin == other.is_admin
)

def has_permission(self, request: Request, view: APIView):
user = request.user
return (
Expand Down
1 change: 1 addition & 0 deletions codeforlife/user/tests/views/test_klass.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,4 @@ def test_retrieve__student__same_school__in_class(self):
self.client.retrieve(user.student.class_field)

# TODO: other retrieve and list tests
# TODO: replace above tests with get_queryset() tests
2 changes: 2 additions & 0 deletions codeforlife/user/tests/views/test_school.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,5 @@ def test_list__student(self):
user = self._login_student()

self.client.list([user.student.class_field.teacher.school])

# TODO: replace above tests with get_queryset() tests
2 changes: 2 additions & 0 deletions codeforlife/user/tests/views/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,3 +519,5 @@ def test_all__only_http_get(self):
assert [name.lower() for name in UserViewSet.http_method_names] == [
"get"
]

# TODO: replace above tests with get_queryset() tests
4 changes: 4 additions & 0 deletions codeforlife/views/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from rest_framework.serializers import ListSerializer
from rest_framework.viewsets import ModelViewSet as DrfModelViewSet

from ..permissions import Permission
from ..serializers import ModelListSerializer, ModelSerializer

AnyModel = t.TypeVar("AnyModel", bound=Model)
Expand Down Expand Up @@ -49,6 +50,9 @@ def get_model_class(cls) -> t.Type[AnyModel]:
0
]

def get_permissions(self):
return t.cast(t.List[Permission], super().get_permissions())

def get_serializer(self, *args, **kwargs):
serializer = super().get_serializer(*args, **kwargs)

Expand Down

0 comments on commit 627c1e0

Please sign in to comment.