From e6a3a55833282964ab81fecadd3e27edeba4b1e3 Mon Sep 17 00:00:00 2001 From: Stefan Kairinos Date: Tue, 6 Feb 2024 10:24:12 +0000 Subject: [PATCH] fix: create class (#73) * fix: add new props for new user * fix: permissions for base viewsets * fix: assert helpers for create and update * type hints --- codeforlife/serializers/base.py | 60 ++++++++++++- codeforlife/tests/model_serializer.py | 81 ++++++++++++++++-- codeforlife/tests/model_view_set.py | 6 +- codeforlife/user/models/user.py | 8 -- .../tests/auth/password_validators/base.py | 18 ++-- codeforlife/user/tests/views/test_klass.py | 21 +++++ codeforlife/user/tests/views/test_school.py | 28 ++++--- codeforlife/user/views/klass.py | 21 +++-- codeforlife/user/views/school.py | 18 ++-- codeforlife/user/views/user.py | 32 +++---- codeforlife/views/api.py | 84 +++++++++++++++++++ codeforlife/views/model.py | 3 +- 12 files changed, 307 insertions(+), 73 deletions(-) create mode 100644 codeforlife/views/api.py diff --git a/codeforlife/serializers/base.py b/codeforlife/serializers/base.py index 16af791b..24083884 100644 --- a/codeforlife/serializers/base.py +++ b/codeforlife/serializers/base.py @@ -12,7 +12,13 @@ from rest_framework.serializers import BaseSerializer as _BaseSerializer from ..request import Request -from ..user.models import User +from ..user.models import ( # TODO: add IndependentUser + NonSchoolTeacherUser, + SchoolTeacherUser, + StudentUser, + TeacherUser, + User, +) # pylint: disable-next=abstract-method @@ -28,15 +34,63 @@ def request(self): @property def request_user(self): """ - The user that made the request. Assumes the user has authenticated. + The user that made the request. + Assumes the user has authenticated. """ return t.cast(User, self.request.user) + @property + def request_teacher_user(self): + """ + The teacher-user that made the request. + Assumes the user has authenticated. + """ + + return t.cast(TeacherUser, self.request.user) + + @property + def request_school_teacher_user(self): + """ + The school-teacher-user that made the request. + Assumes the user has authenticated. + """ + + return t.cast(SchoolTeacherUser, self.request.user) + + @property + def request_non_school_teacher_user(self): + """ + The non-school-teacher-user that made the request. + Assumes the user has authenticated. + """ + + return t.cast(NonSchoolTeacherUser, self.request.user) + + @property + def request_student_user(self): + """ + The student-user that made the request. + Assumes the user has authenticated. + """ + + return t.cast(StudentUser, self.request.user) + + # TODO: uncomment when moving to new data models. + # @property + # def request_indy_user(self): + # """ + # The independent-user that made the request. + # Assumes the user has authenticated. + # """ + + # return t.cast(IndependentUser, self.request.user) + @property def request_anon_user(self): """ - The user that made the request. Assumes the user has not authenticated. + The user that made the request. + Assumes the user has not authenticated. """ return t.cast(AnonymousUser, self.request.user) diff --git a/codeforlife/tests/model_serializer.py b/codeforlife/tests/model_serializer.py index 6ef0e9d4..67d26f3d 100644 --- a/codeforlife/tests/model_serializer.py +++ b/codeforlife/tests/model_serializer.py @@ -6,8 +6,10 @@ """ import typing as t +from unittest.case import _AssertRaisesContext from django.db.models import Model +from django.forms.models import model_to_dict from django.test import TestCase from rest_framework.serializers import ValidationError from rest_framework.test import APIRequestFactory @@ -26,6 +28,8 @@ class ModelSerializerTestCase(TestCase, t.Generic[AnyModel]): request_factory = APIRequestFactory() + Data = t.Dict[str, t.Any] + @classmethod def setUpClass(cls): attr_name = "model_serializer_class" @@ -56,23 +60,26 @@ def assert_raises_validation_error(self, code: str, *args, **kwargs): The assert-raises context which will auto-assert the code. """ - context = self.assertRaises(ValidationError, *args, **kwargs) - - class ContextWrapper: + class Wrapper: """Wrap context to assert code on exit.""" - def __init__(self, context): - self.context = context + def __init__(self, ctx: "_AssertRaisesContext[ValidationError]"): + self.ctx = ctx def __enter__(self, *args, **kwargs): - return self.context.__enter__(*args, **kwargs) + return self.ctx.__enter__(*args, **kwargs) def __exit__(self, *args, **kwargs): - value = self.context.__exit__(*args, **kwargs) - assert self.context.exception.detail[0].code == code + value = self.ctx.__exit__(*args, **kwargs) + assert ( + code + == self.ctx.exception.detail[ # type: ignore[union-attr] + 0 # type: ignore[index] + ].code + ) return value - return ContextWrapper(context) + return Wrapper(self.assertRaises(ValidationError, *args, **kwargs)) # pylint: disable-next=too-many-arguments def _assert_validate( @@ -166,3 +173,59 @@ def get_validate(serializer: ModelSerializer[AnyModel]): get_validate, **kwargs, ) + + def _assert_data_is_subset_of_model(self, data: Data, model): + assert isinstance(model, Model) + + for field, value in data.copy().items(): + # NOTE: A data value of type dict == a foreign object on the model. + if isinstance(value, dict): + self._assert_data_is_subset_of_model( + value, + getattr(model, field), + ) + data.pop(field) + + self.assertDictContainsSubset(data, model_to_dict(model)) + + def assert_create( + self, + validated_data: Data, + *args, + new_data: t.Optional[Data] = None, + **kwargs, + ): + """Assert that the data used to create the model is a subset of the + model's data. + + Args: + validated_data: The data used to create the model. + new_data: Any new data that the model may have after creating. + """ + + serializer = self.model_serializer_class(*args, **kwargs) + model = serializer.create(validated_data) + data = {**validated_data, **(new_data or {})} + self._assert_data_is_subset_of_model(data, model) + + def assert_update( + self, + instance: AnyModel, + validated_data: Data, + *args, + new_data: t.Optional[Data] = None, + **kwargs, + ): + """Assert that the data used to update the model is a subset of the + model's data. + + Args: + instance: The model instance to update. + validated_data: The data used to update the model. + new_data: Any new data that the model may have after updating. + """ + + serializer = self.model_serializer_class(*args, **kwargs) + model = serializer.update(instance, validated_data) + data = {**validated_data, **(new_data or {})} + self._assert_data_is_subset_of_model(data, model) diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index fd1dfd3c..2a245e4a 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -786,7 +786,7 @@ def get_other_user( other_user = other_users.first() assert other_user assert user != other_user - assert other_user.is_teacher if is_teacher else other_user.is_student + assert other_user.teacher if is_teacher else other_user.student return other_user def get_other_school_user( @@ -844,12 +844,12 @@ def get_another_school_user( # Cannot assert that 2 teachers are in the same class since a class # can only have 1 teacher. - if not (user.is_teacher and other_user.is_teacher): + if not (user.teacher and other_user.teacher): # At this point, same_class needs to be set. assert same_class is not None, "same_class must be set." # If one of the users is a teacher. - if user.is_teacher or is_teacher: + if user.teacher or is_teacher: # Get the teacher. teacher = other_user if is_teacher else user diff --git a/codeforlife/user/models/user.py b/codeforlife/user/models/user.py index 7df30615..407689a5 100644 --- a/codeforlife/user/models/user.py +++ b/codeforlife/user/models/user.py @@ -57,14 +57,6 @@ def teacher(self) -> t.Optional[Teacher]: except Teacher.DoesNotExist: return None - @property - def is_student(self): - return self.student is not None - - @property - def is_teacher(self): - return self.teacher is not None - @property def otp_secret(self): return self.userprofile.otp_secret diff --git a/codeforlife/user/tests/auth/password_validators/base.py b/codeforlife/user/tests/auth/password_validators/base.py index 27bbb466..94d141d1 100644 --- a/codeforlife/user/tests/auth/password_validators/base.py +++ b/codeforlife/user/tests/auth/password_validators/base.py @@ -5,6 +5,8 @@ Base test case for all password validators. """ +from unittest.case import _AssertRaisesContext + from django.core.exceptions import ValidationError from django.test import TestCase @@ -22,20 +24,18 @@ def assert_raises_validation_error(self, code: str, *args, **kwargs): The assert-raises context which will auto-assert the code. """ - context = self.assertRaises(ValidationError, *args, **kwargs) - - class ContextWrapper: + class Wrapper: """Wrap context to assert code on exit.""" - def __init__(self, context): - self.context = context + def __init__(self, ctx: "_AssertRaisesContext[ValidationError]"): + self.ctx = ctx def __enter__(self, *args, **kwargs): - return self.context.__enter__(*args, **kwargs) + return self.ctx.__enter__(*args, **kwargs) def __exit__(self, *args, **kwargs): - value = self.context.__exit__(*args, **kwargs) - assert self.context.exception.code == code + value = self.ctx.__exit__(*args, **kwargs) + assert self.ctx.exception.code == code return value - return ContextWrapper(context) + return Wrapper(self.assertRaises(ValidationError, *args, **kwargs)) diff --git a/codeforlife/user/tests/views/test_klass.py b/codeforlife/user/tests/views/test_klass.py index 86452206..e1f712fa 100644 --- a/codeforlife/user/tests/views/test_klass.py +++ b/codeforlife/user/tests/views/test_klass.py @@ -5,6 +5,7 @@ from ....tests import ModelViewSetTestCase from ...models import Class +from ...permissions import InSchool, IsTeacher from ...views import ClassViewSet @@ -59,3 +60,23 @@ def test_retrieve__student__same_school__in_class(self): # TODO: other retrieve and list tests # TODO: replace above tests with get_queryset() tests + + def test_get_permissions__list(self): + """ + Only school-teachers can list classes. + """ + + self.assert_get_permissions( + permissions=[IsTeacher(), InSchool()], + action="list", + ) + + def test_get_permissions__retrieve(self): + """ + Anyone in a school can retrieve a class. + """ + + self.assert_get_permissions( + permissions=[InSchool()], + action="retrieve", + ) diff --git a/codeforlife/user/tests/views/test_school.py b/codeforlife/user/tests/views/test_school.py index 222c3726..c13613a4 100644 --- a/codeforlife/user/tests/views/test_school.py +++ b/codeforlife/user/tests/views/test_school.py @@ -5,8 +5,10 @@ from rest_framework import status +from ....permissions import AllowNone from ....tests import ModelViewSetTestCase from ...models import Class, School, Student, Teacher, User, UserProfile +from ...permissions import InSchool from ...views import SchoolViewSet @@ -176,22 +178,24 @@ def test_list__indy_student(self): self.client.list([], status.HTTP_403_FORBIDDEN) - def test_list__teacher(self): + # TODO: replace above tests with get_queryset() tests + + def test_get_permissions__list(self): """ - Teacher can list only the school they are in. + No one is allowed to list schools. """ - user = self._login_teacher() - - self.client.list([user.teacher.school]) + self.assert_get_permissions( + permissions=[AllowNone()], + action="list", + ) - def test_list__student(self): + def test_get_permissions__retrieve(self): """ - Student can list only the school they are in. + Only a user in a school can retrieve a school. """ - user = self._login_student() - - self.client.list([user.student.class_field.teacher.school]) - - # TODO: replace above tests with get_queryset() tests + self.assert_get_permissions( + permissions=[InSchool()], + action="retrieve", + ) diff --git a/codeforlife/user/views/klass.py b/codeforlife/user/views/klass.py index e4d5ff06..6fbd7fb9 100644 --- a/codeforlife/user/views/klass.py +++ b/codeforlife/user/views/klass.py @@ -3,11 +3,9 @@ Created on 24/01/2024 at 13:47:53(+00:00). """ -import typing as t - from ...views import ModelViewSet -from ..models import Class, User -from ..permissions import InSchool +from ..models import Class +from ..permissions import InSchool, IsTeacher from ..serializers import ClassSerializer @@ -16,15 +14,22 @@ class ClassViewSet(ModelViewSet[Class]): http_method_names = ["get"] lookup_field = "access_code" serializer_class = ClassSerializer - permission_classes = [InSchool] + + def get_permissions(self): + # Only school-teachers can list classes. + if self.action == "list": + return [IsTeacher(), InSchool()] + + return [InSchool()] # pylint: disable-next=missing-function-docstring def get_queryset(self): - user = t.cast(User, self.request.user) - if user.is_student: + user = self.request_user + if user.student: return Class.objects.filter(students=user.student) + + user = self.request_school_teacher_user if user.teacher.is_admin: - # TODO: add school field to class object return Class.objects.filter(teacher__school=user.teacher.school) return Class.objects.filter(teacher=user.teacher) diff --git a/codeforlife/user/views/school.py b/codeforlife/user/views/school.py index cc3b903c..efbf04d5 100644 --- a/codeforlife/user/views/school.py +++ b/codeforlife/user/views/school.py @@ -3,10 +3,9 @@ Created on 24/01/2024 at 13:38:15(+00:00). """ -import typing as t - +from ...permissions import AllowNone from ...views import ModelViewSet -from ..models import School, User +from ..models import School from ..permissions import InSchool from ..serializers import SchoolSerializer @@ -15,15 +14,22 @@ class SchoolViewSet(ModelViewSet[School]): http_method_names = ["get"] serializer_class = SchoolSerializer - permission_classes = [InSchool] + + def get_permissions(self): + # No one is allowed to list schools. + if self.action == "list": + return [AllowNone()] + + return [InSchool()] # pylint: disable-next=missing-function-docstring def get_queryset(self): - user = t.cast(User, self.request.user) - if user.is_student: + user = self.request_user + if user.student: return School.objects.filter( # TODO: should be user.student.school_id id=user.student.class_field.teacher.school_id ) + user = self.request_school_teacher_user return School.objects.filter(id=user.teacher.school_id) diff --git a/codeforlife/user/views/user.py b/codeforlife/user/views/user.py index 3942b709..4331ccaf 100644 --- a/codeforlife/user/views/user.py +++ b/codeforlife/user/views/user.py @@ -19,8 +19,8 @@ class UserViewSet(ModelViewSet[User]): # pylint: disable-next=missing-function-docstring def get_queryset(self): - user = t.cast(User, self.request.user) - if user.is_student: + user = self.request_user + if user.student: if user.student.class_field is None: return User.objects.filter(id=user.id) @@ -33,18 +33,22 @@ def get_queryset(self): return teachers | students - teachers = User.objects.filter( - new_teacher__school=user.teacher.school_id - ) - students = ( - User.objects.filter( - # TODO: add school foreign key to student model. - new_student__class_field__teacher__school=user.teacher.school_id, + user = self.request_teacher_user + if user.teacher.school: + teachers = User.objects.filter( + new_teacher__school=user.teacher.school_id ) - if user.teacher.is_admin - else User.objects.filter( - new_student__class_field__teacher=user.teacher + students = ( + User.objects.filter( + # TODO: add school foreign key to student model. + new_student__class_field__teacher__school=user.teacher.school_id, + ) + if user.teacher.is_admin + else User.objects.filter( + new_student__class_field__teacher=user.teacher + ) ) - ) - return teachers | students + return teachers | students + + return User.objects.filter(pk=user.pk) diff --git a/codeforlife/views/api.py b/codeforlife/views/api.py new file mode 100644 index 00000000..c0c4c7de --- /dev/null +++ b/codeforlife/views/api.py @@ -0,0 +1,84 @@ +""" +© Ocado Group +Created on 05/02/2024 at 16:33:52(+00:00). +""" + +import typing as t + +from django.contrib.auth.models import AnonymousUser +from rest_framework.views import APIView as _APIView + +from ..user.models import ( + NonSchoolTeacherUser, + SchoolTeacherUser, + StudentUser, + TeacherUser, + User, +) + + +# pylint: disable-next=missing-class-docstring +class APIView(_APIView): + @property + def request_user(self): + """ + The user that made the request. + Assumes the user has authenticated. + """ + + return t.cast(User, self.request.user) + + @property + def request_teacher_user(self): + """ + The teacher-user that made the request. + Assumes the user has authenticated. + """ + + return t.cast(TeacherUser, self.request.user) + + @property + def request_school_teacher_user(self): + """ + The school-teacher-user that made the request. + Assumes the user has authenticated. + """ + + return t.cast(SchoolTeacherUser, self.request.user) + + @property + def request_non_school_teacher_user(self): + """ + The non-school-teacher-user that made the request. + Assumes the user has authenticated. + """ + + return t.cast(NonSchoolTeacherUser, self.request.user) + + @property + def request_student_user(self): + """ + The student-user that made the request. + Assumes the user has authenticated. + """ + + return t.cast(StudentUser, self.request.user) + + # TODO: uncomment when moving to new data models. + # @property + # def request_indy_user(self): + # """ + # The independent-user that made the request. + # Assumes the user has authenticated. + # """ + + # return t.cast(IndependentUser, self.request.user) + + @property + def request_anon_user(self): + """ + The user that made the request. + Assumes the user has not authenticated. + """ + + return t.cast(AnonymousUser, self.request.user) diff --git a/codeforlife/views/model.py b/codeforlife/views/model.py index c0173844..57bb3a98 100644 --- a/codeforlife/views/model.py +++ b/codeforlife/views/model.py @@ -16,6 +16,7 @@ from ..permissions import Permission from ..serializers import ModelListSerializer, ModelSerializer +from .api import APIView AnyModel = t.TypeVar("AnyModel", bound=Model) @@ -32,7 +33,7 @@ class _ModelViewSet(DrfModelViewSet, t.Generic[AnyModel]): # pylint: disable-next=too-many-ancestors -class ModelViewSet(_ModelViewSet[AnyModel], t.Generic[AnyModel]): +class ModelViewSet(APIView, _ModelViewSet[AnyModel], t.Generic[AnyModel]): """Base model view set for all model view sets.""" serializer_class: t.Optional[t.Type[ModelSerializer[AnyModel]]]