From 5d5887c977fcc9bf7adb0c52b9f2f1130affa143 Mon Sep 17 00:00:00 2001 From: Stephane Latil Date: Thu, 13 Jun 2024 13:26:08 +0200 Subject: [PATCH 1/7] Adds aget_object_or_404 --- adrf/shortcuts.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 adrf/shortcuts.py diff --git a/adrf/shortcuts.py b/adrf/shortcuts.py new file mode 100644 index 0000000..1abd611 --- /dev/null +++ b/adrf/shortcuts.py @@ -0,0 +1,29 @@ +from django.http import Http404 + +def _get_queryset(klass): + """ + Return a QuerySet or a Manager. + Duck typing in action: any class with a `get()` method (for + get_object_or_404) or a `filter()` method (for get_list_or_404) might do + the job. + """ + # If it is a model class or anything else with ._default_manager + if hasattr(klass, "_default_manager"): + return klass._default_manager.all() + return klass + +async def aget_object_or_404(klass, *args, **kwargs): + """See get_object_or_404().""" + queryset = _get_queryset(klass) + if not hasattr(queryset, "aget"): + klass__name = ( + klass.__name__ if isinstance(klass, type) else klass.__class__.__name__ + ) + raise ValueError( + "First argument to aget_object_or_404() must be a Model, Manager, or " + f"QuerySet, not '{klass__name}'." + ) + try: + return await queryset.aget(*args, **kwargs) + except queryset.model.DoesNotExist: + raise Http404(f"No {queryset.model._meta.object_name} matches the given query.") \ No newline at end of file From 4a04a2b200fa6a4c63e011e0858afa1db32ed862 Mon Sep 17 00:00:00 2001 From: Stephane Latil Date: Thu, 13 Jun 2024 13:26:33 +0200 Subject: [PATCH 2/7] Adds GenericViewSet --- adrf/viewsets.py | 210 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 209 insertions(+), 1 deletion(-) diff --git a/adrf/viewsets.py b/adrf/viewsets.py index 566bc90..7d476f3 100644 --- a/adrf/viewsets.py +++ b/adrf/viewsets.py @@ -1,11 +1,17 @@ import asyncio from functools import update_wrapper +from asgiref.sync import async_to_sync, sync_to_async +from django.db.models import QuerySet +from django.shortcuts import get_object_or_404 from django.utils.decorators import classonlymethod from django.utils.functional import classproperty from adrf.views import APIView +from adrf.shortcuts import aget_object_or_404 from rest_framework.viewsets import ViewSetMixin as DRFViewSetMixin +from rest_framework.settings import api_settings +from rest_framework.response import Response class ViewSetMixin(DRFViewSetMixin): @@ -139,6 +145,8 @@ async def async_view(request, *args, **kwargs): class ViewSet(ViewSetMixin, APIView): + _ASYNC_NON_DISPATCH_METHODS = [] + @classproperty def view_is_async(cls): """ @@ -147,6 +155,206 @@ def view_is_async(cls): result = [ asyncio.iscoroutinefunction(function) for name, function in cls.__dict__.items() - if callable(function) and not name.startswith("__") + if callable(function) and not name.startswith("__") + and not name in cls._ASYNC_NON_DISPATCH_METHODS ] return any(result) + +class GenericViewSet(ViewSet): + """ + Base class for all other generic views. + """ + _ASYNC_NON_DISPATCH_METHODS = ViewSet._ASYNC_NON_DISPATCH_METHODS \ + + ['aget_object'] + + queryset = None + serializer_class = None + + # If you want to use object lookups other than pk, set 'lookup_field'. + # For more complex lookup requirements override `get_object()`. + lookup_field = 'pk' + lookup_url_kwarg = None + + # The filter backend classes to use for queryset filtering + filter_backends = api_settings.DEFAULT_FILTER_BACKENDS + + # The style to use for queryset pagination. + pagination_class = api_settings.DEFAULT_PAGINATION_CLASS + + # Allow generic typing checking for generic views. + def __class_getitem__(cls, *args, **kwargs): + return cls + + def get_queryset(self): + """ + Get the list of items for this view. + This must be an iterable, and may be a queryset. + Defaults to using `self.queryset`. + + This method should always be used rather than accessing `self.queryset` + directly, as `self.queryset` gets evaluated only once, and those results + are cached for all subsequent requests. + + You may want to override this if you need to provide different + querysets depending on the incoming request. + + (Eg. return a list of items that is specific to the user) + """ + assert self.queryset is not None, ( + "'%s' should either include a `queryset` attribute, " + "or override the `get_queryset()` method." + % self.__class__.__name__ + ) + + queryset = self.queryset + if isinstance(queryset, QuerySet): + # Ensure queryset is re-evaluated on each request. + queryset = queryset.all() + return queryset + + async def aget_object(self): + """ + Returns the object the view is displaying. + + You may want to override this if you need to provide non-standard + queryset lookups. Eg if objects are referenced using multiple + keyword arguments in the url conf. + """ + queryset = self.filter_queryset(self.get_queryset()) + + # Perform the lookup filtering. + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + + assert lookup_url_kwarg in self.kwargs, ( + 'Expected view %s to be called with a URL keyword argument ' + 'named "%s". Fix your URL conf, or set the `.lookup_field` ' + 'attribute on the view correctly.' % + (self.__class__.__name__, lookup_url_kwarg) + ) + + filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]} + obj = await aget_object_or_404(queryset, **filter_kwargs) + + # May raise a permission denied + self.check_object_permissions(self.request, obj) + + return obj + + def get_object(self): + """ + Returns the object the view is displaying. + + You may want to override this if you need to provide non-standard + queryset lookups. Eg if objects are referenced using multiple + keyword arguments in the url conf. + """ + queryset = self.filter_queryset(self.get_queryset()) + + # Perform the lookup filtering. + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + + assert lookup_url_kwarg in self.kwargs, ( + 'Expected view %s to be called with a URL keyword argument ' + 'named "%s". Fix your URL conf, or set the `.lookup_field` ' + 'attribute on the view correctly.' % + (self.__class__.__name__, lookup_url_kwarg) + ) + + filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]} + obj = get_object_or_404(queryset, **filter_kwargs) + + # May raise a permission denied + self.check_object_permissions(self.request, obj) + + return obj + + def get_serializer(self, *args, **kwargs): + """ + Return the serializer instance that should be used for validating and + deserializing input, and for serializing output. + """ + serializer_class = self.get_serializer_class() + kwargs.setdefault('context', self.get_serializer_context()) + return serializer_class(*args, **kwargs) + + def get_serializer_class(self): + """ + Return the class to use for the serializer. + Defaults to using `self.serializer_class`. + + You may want to override this if you need to provide different + serializations depending on the incoming request. + + (Eg. admins get full serialization, others get basic serialization) + """ + assert self.serializer_class is not None, ( + "'%s' should either include a `serializer_class` attribute, " + "or override the `get_serializer_class()` method." + % self.__class__.__name__ + ) + + return self.serializer_class + + def get_serializer_context(self): + """ + Extra context provided to the serializer class. + """ + return { + 'request': self.request, + 'format': self.format_kwarg, + 'view': self + } + + def filter_queryset(self, queryset): + """ + Given a queryset, filter it with whichever filter backend is in use. + + You are unlikely to want to override this method, although you may need + to call it either from a list view, or from a custom `get_object` + method if you want to apply the configured filtering backend to the + default queryset. + """ + for backend in list(self.filter_backends): + queryset = backend().filter_queryset(self.request, queryset, self) + return queryset + + @property + def paginator(self): + """ + The paginator instance associated with the view, or `None`. + """ + if not hasattr(self, '_paginator'): + if self.pagination_class is None: + self._paginator = None + else: + self._paginator = self.pagination_class() + return self._paginator + + def paginate_queryset(self, queryset): + """ + Return a single page of results, or `None` if pagination is disabled. + """ + if self.paginator is None: + return None + if asyncio.iscoroutinefunction(self.paginator.paginate_queryset): + return async_to_sync(self.paginator.paginate_queryset(queryset, self.request, view=self)) + return self.paginator.paginate_queryset(queryset, self.request, view=self) + + async def apaginate_queryset(self, queryset): + """ + Return a single page of results, or `None` if pagination is disabled. + """ + if self.paginator is None: + return None + if asyncio.iscoroutinefunction(self.paginator.paginate_queryset): + return await self.paginator.paginate_queryset(queryset, self.request, view=self) + return self.paginator.paginate_queryset(queryset, self.request, view=self) + + def get_paginated_response(self, data): + """ + Return a paginated style `Response` object for the given output data. + """ + assert self.paginator is not None + if asyncio.iscoroutinefunction(self.paginator.get_paginated_response): + return async_to_sync(self.paginator.get_paginated_response(data)) + return self.paginator.get_paginated_response(data) From af0e3bd059a179689b72a54539fa33b8a25034f9 Mon Sep 17 00:00:00 2001 From: Stephane Latil Date: Thu, 13 Jun 2024 13:36:51 +0200 Subject: [PATCH 3/7] Adds some Mixins --- adrf/viewsets.py | 52 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/adrf/viewsets.py b/adrf/viewsets.py index 7d476f3..3202102 100644 --- a/adrf/viewsets.py +++ b/adrf/viewsets.py @@ -12,6 +12,7 @@ from rest_framework.viewsets import ViewSetMixin as DRFViewSetMixin from rest_framework.settings import api_settings from rest_framework.response import Response +from rest_framework import status class ViewSetMixin(DRFViewSetMixin): @@ -165,7 +166,7 @@ class GenericViewSet(ViewSet): Base class for all other generic views. """ _ASYNC_NON_DISPATCH_METHODS = ViewSet._ASYNC_NON_DISPATCH_METHODS \ - + ['aget_object'] + + ['aget_object', 'perform_create', 'apaginate_queryset'] queryset = None serializer_class = None @@ -358,3 +359,52 @@ def get_paginated_response(self, data): if asyncio.iscoroutinefunction(self.paginator.get_paginated_response): return async_to_sync(self.paginator.get_paginated_response(data)) return self.paginator.get_paginated_response(data) + +class CreateModelMixin: + """ + Create a model instance. + """ + async def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + await self.perform_create(serializer) + data = await serializer.adata + headers = self.get_success_headers(data) + return Response(data, status=status.HTTP_201_CREATED, headers=headers) + + async def perform_create(self, serializer): + await serializer.asave() + + def get_success_headers(self, data): + try: + return {'Location': str(data[api_settings.URL_FIELD_NAME])} + except (TypeError, KeyError): + return {} + +class RetrieveModelMixin: + """ + Retrieve a model instance. + """ + async def retrieve(self, request, *args, **kwargs): + instance = await self.aget_object() + serializer = self.get_serializer(instance, many=False) + #try to serialize async is the serializer supports it. Sync otherwise + data = await serializer.adata if hasattr(serializer, 'adata') else serializer.data + return Response(data, status=status.HTTP_200_OK) + +class ListModelMixin: + """ + List a queryset. + """ + async def list(self, *args, **kwargs): + queryset = self.filter_queryset(self.get_queryset()) + + page = await self.apaginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + data = await serializer.adata if hasattr(serializer, 'adata') else serializer.data + return await self.aget_paginated_response(data) + + serializer = self.get_serializer(queryset, many=True) + data = await serializer.adata if hasattr(serializer, 'adata') else serializer.data + return Response(data, status=status.HTTP_200_OK) From cc2db45dd0c939b1116fa42626a50ed7a74af4c5 Mon Sep 17 00:00:00 2001 From: Stephane Latil Date: Thu, 13 Jun 2024 14:07:28 +0200 Subject: [PATCH 4/7] Adds Update and Delete model Mixins --- adrf/viewsets.py | 43 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/adrf/viewsets.py b/adrf/viewsets.py index 3202102..55346f0 100644 --- a/adrf/viewsets.py +++ b/adrf/viewsets.py @@ -166,7 +166,8 @@ class GenericViewSet(ViewSet): Base class for all other generic views. """ _ASYNC_NON_DISPATCH_METHODS = ViewSet._ASYNC_NON_DISPATCH_METHODS \ - + ['aget_object', 'perform_create', 'apaginate_queryset'] + + ['aget_object', 'perform_create', 'apaginate_queryset', + 'perform_update'] queryset = None serializer_class = None @@ -408,3 +409,43 @@ async def list(self, *args, **kwargs): serializer = self.get_serializer(queryset, many=True) data = await serializer.adata if hasattr(serializer, 'adata') else serializer.data return Response(data, status=status.HTTP_200_OK) + + +class UpdateModelMixin: + """ + Update a model instance. + """ + async def update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) + instance = await self.aget_object() + serializer = self.get_serializer(instance, data=request.data, partial=partial) + serializer.is_valid(raise_exception=True) + data = serializer.adata + await self.perform_update(serializer) + + if getattr(instance, '_prefetched_objects_cache', None): + # If 'prefetch_related' has been applied to a queryset, we need to + # forcibly invalidate the prefetch cache on the instance. + instance._prefetched_objects_cache = {} + + return Response(await data, status=status.HTTP_200_OK) + + async def perform_update(self, serializer): + await serializer.asave() + + async def partial_update(self, request, *args, **kwargs): + kwargs['partial'] = True + return await self.update(request, *args, **kwargs) + + +class DestroyModelMixin: + """ + Destroy a model instance. + """ + async def destroy(self, request, *args, **kwargs): + instance = await self.aget_object() + await self.perform_destroy(instance) + return Response(status=status.HTTP_204_NO_CONTENT) + + async def perform_destroy(self, instance): + await instance.adelete() From 340dc39e5beb65fd221077b1533413dea542159e Mon Sep 17 00:00:00 2001 From: Stephane Latil Date: Thu, 13 Jun 2024 14:10:51 +0200 Subject: [PATCH 5/7] Adds ModelViewSet and ReadOnly version --- adrf/viewsets.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/adrf/viewsets.py b/adrf/viewsets.py index 55346f0..9739e12 100644 --- a/adrf/viewsets.py +++ b/adrf/viewsets.py @@ -449,3 +449,23 @@ async def destroy(self, request, *args, **kwargs): async def perform_destroy(self, instance): await instance.adelete() + +class ReadOnlyModelViewSet(RetrieveModelMixin, + ListModelMixin, + GenericViewSet): + """ + A viewset that provides default asynchronous `list()` and `retrieve()` actions. + """ + pass + +class ModelViewSet(CreateModelMixin, + ListModelMixin, + RetrieveModelMixin, + UpdateModelMixin, + DestroyModelMixin, + GenericViewSet): + """ + A viewset that provides default asynchronous `create()`, `retrieve()`, `update()`, + `partial_update()`, `destroy()` and `list()` actions. + """ + pass \ No newline at end of file From 026d4e4bb8c265b8955b2058becd38a79437b00f Mon Sep 17 00:00:00 2001 From: Enrico Massa Date: Sat, 3 Aug 2024 20:20:20 +0800 Subject: [PATCH 6/7] Added async viewset support and generics --- .github/workflows/main.yml | 2 +- .pre-commit-config.yaml | 4 +- adrf/generics.py | 212 +++++++++++++++++ adrf/mixins.py | 98 ++++++++ adrf/requests.py | 1 - adrf/serializers.py | 16 +- adrf/shortcuts.py | 50 ++-- adrf/test.py | 1 - adrf/utils.py | 49 ++++ adrf/views.py | 4 +- adrf/viewsets.py | 346 +++------------------------- pyproject.toml | 18 ++ setup.cfg | 28 --- tests/conftest.py | 14 +- tests/{test_models.py => models.py} | 2 +- tests/test_authentication.py | 8 +- tests/test_generics.py | 122 ++++++++++ tests/test_object_permissions.py | 4 +- tests/test_permissions.py | 4 +- tests/test_serializers.py | 7 +- tests/test_shortcuts.py | 28 +++ tests/test_testmodule.py | 4 +- tests/test_throttling.py | 4 +- tests/test_views.py | 6 +- tests/test_viewsets.py | 58 ++++- tests/urls.py | 1 + tox.ini | 1 - 27 files changed, 688 insertions(+), 404 deletions(-) create mode 100644 adrf/generics.py create mode 100644 adrf/mixins.py create mode 100644 adrf/utils.py delete mode 100644 setup.cfg rename tests/{test_models.py => models.py} (100%) create mode 100644 tests/test_generics.py create mode 100644 tests/test_shortcuts.py create mode 100644 tests/urls.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8e98548..03fc443 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -27,7 +27,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install poetry - run: python -m pip install poetry==1.8.2 + run: python -m pip install poetry==1.8.3 - name: Install dependencies run: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7cb8a39..100e0ae 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.6.0 hooks: - id: check-ast - id: check-added-large-files @@ -12,7 +12,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.3 + rev: v0.5.5 hooks: # Run the linter. - id: ruff diff --git a/adrf/generics.py b/adrf/generics.py new file mode 100644 index 0000000..d57b703 --- /dev/null +++ b/adrf/generics.py @@ -0,0 +1,212 @@ +import asyncio + +from asgiref.sync import async_to_sync +from django.http import Http404 +from rest_framework.exceptions import ValidationError +from rest_framework.generics import GenericAPIView as DRFGenericAPIView + +from adrf import mixins, views +from adrf.shortcuts import aget_object_or_404 as _aget_object_or_404 + + +def aget_object_or_404(queryset, *filter_args, **filter_kwargs): + """ + Same as Django's standard shortcut, but make sure to also raise 404 + if the filter_kwargs don't match the required types. + """ + try: + return _aget_object_or_404(queryset, *filter_args, **filter_kwargs) + except (TypeError, ValueError, ValidationError): + raise Http404 + + +class GenericAPIView(views.APIView, DRFGenericAPIView): + """This generic API view supports async pagination.""" + + async def aget_object(self): + """ + Returns the object the view is displaying. + + You may want to override this if you need to provide non-standard + queryset lookups. Eg if objects are referenced using multiple + keyword arguments in the url conf. + """ + queryset = self.filter_queryset(self.get_queryset()) + + # Perform the lookup filtering. + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + + assert lookup_url_kwarg in self.kwargs, ( + "Expected view %s to be called with a URL keyword argument " + 'named "%s". Fix your URL conf, or set the `.lookup_field` ' + "attribute on the view correctly." + % (self.__class__.__name__, lookup_url_kwarg) + ) + + filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]} + obj = await aget_object_or_404(queryset, **filter_kwargs) + + # May raise a permission denied + self.check_object_permissions(self.request, obj) + + return obj + + def paginate_queryset(self, queryset): + """ + Return a single page of results, or `None` if pagination is disabled. + """ + if self.paginator is None: + return None + if asyncio.iscoroutinefunction(self.paginator.paginate_queryset): + return async_to_sync(self.paginator.paginate_queryset)( + queryset, self.request, view=self + ) + return self.paginator.paginate_queryset(queryset, self.request, view=self) + + def get_paginated_response(self, data): + """ + Return a paginated style `Response` object for the given output data. + """ + assert self.paginator is not None + if asyncio.iscoroutinefunction(self.paginator.get_paginated_response): + return async_to_sync(self.paginator.get_paginated_response)(data) + return self.paginator.get_paginated_response(data) + + async def apaginate_queryset(self, queryset): + """ + Return a single page of results, or `None` if pagination is disabled. + """ + if self.paginator is None: + return None + if asyncio.iscoroutinefunction(self.paginator.paginate_queryset): + return await self.paginator.paginate_queryset( + queryset, self.request, view=self + ) + return self.paginator.paginate_queryset(queryset, self.request, view=self) + + async def get_apaginated_response(self, data): + """ + Return a paginated style `Response` object for the given output data. + """ + assert self.paginator is not None + if asyncio.iscoroutinefunction(self.paginator.get_paginated_response): + return await self.paginator.get_paginated_response(data) + return self.paginator.get_paginated_response(data) + + +# Concrete view classes that provide method handlers +# by composing the mixin classes with the base view. + + +class CreateAPIView(mixins.CreateModelMixin, GenericAPIView): + """ + Concrete view for creating a model instance. + """ + + async def post(self, request, *args, **kwargs): + return await self.acreate(request, *args, **kwargs) + + +class ListAPIView(mixins.ListModelMixin, GenericAPIView): + """ + Concrete view for listing a queryset. + """ + + async def get(self, request, *args, **kwargs): + return await self.alist(request, *args, **kwargs) + + +class RetrieveAPIView(mixins.RetrieveModelMixin, GenericAPIView): + """ + Concrete view for retrieving a model instance. + """ + + async def get(self, request, *args, **kwargs): + return await self.aretrieve(request, *args, **kwargs) + + +class DestroyAPIView(mixins.DestroyModelMixin, GenericAPIView): + """ + Concrete view for deleting a model instance. + """ + + async def delete(self, request, *args, **kwargs): + return await self.adestroy(request, *args, **kwargs) + + +class UpdateAPIView(mixins.UpdateModelMixin, GenericAPIView): + """ + Concrete view for updating a model instance. + """ + + async def put(self, request, *args, **kwargs): + return await self.aupdate(request, *args, **kwargs) + + async def patch(self, request, *args, **kwargs): + return await self.partial_aupdate(request, *args, **kwargs) + + +class ListCreateAPIView(mixins.ListModelMixin, mixins.CreateModelMixin, GenericAPIView): + """ + Concrete view for listing a queryset or creating a model instance. + """ + + async def get(self, request, *args, **kwargs): + return await self.alist(request, *args, **kwargs) + + async def post(self, request, *args, **kwargs): + return await self.acreate(request, *args, **kwargs) + + +class RetrieveUpdateAPIView( + mixins.RetrieveModelMixin, mixins.UpdateModelMixin, GenericAPIView +): + """ + Concrete view for retrieving, updating a model instance. + """ + + async def get(self, request, *args, **kwargs): + return await self.aretrieve(request, *args, **kwargs) + + async def put(self, request, *args, **kwargs): + return await self.aupdate(request, *args, **kwargs) + + async def patch(self, request, *args, **kwargs): + return await self.partial_aupdate(request, *args, **kwargs) + + +class RetrieveDestroyAPIView( + mixins.RetrieveModelMixin, mixins.DestroyModelMixin, GenericAPIView +): + """ + Concrete view for retrieving or deleting a model instance. + """ + + async def get(self, request, *args, **kwargs): + return await self.aretrieve(request, *args, **kwargs) + + async def delete(self, request, *args, **kwargs): + return await self.adestroy(request, *args, **kwargs) + + +class RetrieveUpdateDestroyAPIView( + mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + mixins.DestroyModelMixin, + GenericAPIView, +): + """ + Concrete view for retrieving, updating or deleting a model instance. + """ + + async def get(self, request, *args, **kwargs): + return await self.aretrieve(request, *args, **kwargs) + + async def put(self, request, *args, **kwargs): + return await self.aupdate(request, *args, **kwargs) + + async def patch(self, request, *args, **kwargs): + return await self.partial_aupdate(request, *args, **kwargs) + + async def delete(self, request, *args, **kwargs): + return await self.adestroy(request, *args, **kwargs) diff --git a/adrf/mixins.py b/adrf/mixins.py new file mode 100644 index 0000000..5f4f3a1 --- /dev/null +++ b/adrf/mixins.py @@ -0,0 +1,98 @@ +from asgiref.sync import sync_to_async +from rest_framework import mixins, status +from rest_framework.response import Response + + +async def get_data(serializer): + """Use adata if the serializer supports it, data otherwise.""" + return await serializer.adata if hasattr(serializer, "adata") else serializer.data + + +class CreateModelMixin(mixins.CreateModelMixin): + """ + Create a model instance. + """ + + async def acreate(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + await sync_to_async(serializer.is_valid)(raise_exception=True) + await self.perform_acreate(serializer) + data = await get_data(serializer) + headers = self.get_success_headers(data) + return Response(data, status=status.HTTP_201_CREATED, headers=headers) + + async def perform_acreate(self, serializer): + await serializer.asave() + + +class ListModelMixin(mixins.ListModelMixin): + """ + List a queryset. + """ + + async def alist(self, *args, **kwargs): + queryset = self.filter_queryset(self.get_queryset()) + + page = await self.apaginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + data = await get_data(serializer) + return await self.get_apaginated_response(data) + + serializer = self.get_serializer(queryset, many=True) + data = await get_data(serializer) + return Response(data, status=status.HTTP_200_OK) + + +class RetrieveModelMixin(mixins.RetrieveModelMixin): + """ + Retrieve a model instance. + """ + + async def aretrieve(self, request, *args, **kwargs): + instance = await self.aget_object() + serializer = self.get_serializer(instance, many=False) + data = await get_data(serializer) + return Response(data, status=status.HTTP_200_OK) + + +class UpdateModelMixin(mixins.UpdateModelMixin): + """ + Update a model instance. + """ + + async def aupdate(self, request, *args, **kwargs): + partial = kwargs.pop("partial", False) + instance = await self.aget_object() + serializer = self.get_serializer(instance, data=request.data, partial=partial) + await sync_to_async(serializer.is_valid)(raise_exception=True) + await self.perform_aupdate(serializer) + + if getattr(instance, "_prefetched_objects_cache", None): + # If 'prefetch_related' has been applied to a queryset, we need to + # forcibly invalidate the prefetch cache on the instance. + instance._prefetched_objects_cache = {} + data = await get_data(serializer) + + return Response(data, status=status.HTTP_200_OK) + + async def perform_aupdate(self, serializer): + await serializer.asave() + + async def partial_aupdate(self, request, *args, **kwargs): + kwargs["partial"] = True + return await self.aupdate(request, *args, **kwargs) + + +class DestroyModelMixin(mixins.DestroyModelMixin): + """ + Destroy a model instance. + """ + + async def adestroy(self, request, *args, **kwargs): + instance = await self.aget_object() + await self.perform_adestroy(instance) + return Response(status=status.HTTP_204_NO_CONTENT) + + async def perform_adestroy(self, instance): + await instance.adelete() diff --git a/adrf/requests.py b/adrf/requests.py index bf9a5c4..13f641e 100644 --- a/adrf/requests.py +++ b/adrf/requests.py @@ -1,7 +1,6 @@ import asyncio from asgiref.sync import async_to_sync - from rest_framework import exceptions from rest_framework.request import Request, wrap_attributeerrors diff --git a/adrf/serializers.py b/adrf/serializers.py index 2884db2..433925d 100644 --- a/adrf/serializers.py +++ b/adrf/serializers.py @@ -3,18 +3,18 @@ from async_property import async_property from django.db import models - from rest_framework.fields import SkipField -from rest_framework.serializers import LIST_SERIALIZER_KWARGS +from rest_framework.serializers import ( + LIST_SERIALIZER_KWARGS, + model_meta, + raise_errors_on_nested_writes, +) from rest_framework.serializers import BaseSerializer as DRFBaseSerializer from rest_framework.serializers import ListSerializer as DRFListSerializer from rest_framework.serializers import ModelSerializer as DRFModelSerializer from rest_framework.serializers import Serializer as DRFSerializer -from rest_framework.serializers import SerializerMetaclass as DRFSerializerMetaclass -from rest_framework.serializers import model_meta, raise_errors_on_nested_writes from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList - # NOTE This is the list of fields defined by DRF for which we need to call to_rapresentation. DRF_FIELDS = list(DRFModelSerializer.serializer_field_mapping.values()) + [ DRFModelSerializer.serializer_related_field, @@ -128,11 +128,7 @@ async def asave(self, **kwargs): return self.instance -class _Serializer(metaclass=DRFSerializerMetaclass): - pass - - -class Serializer(BaseSerializer, _Serializer, DRFSerializer): +class Serializer(BaseSerializer, DRFSerializer): @async_property async def adata(self): """ diff --git a/adrf/shortcuts.py b/adrf/shortcuts.py index 1abd611..cc1a58e 100644 --- a/adrf/shortcuts.py +++ b/adrf/shortcuts.py @@ -1,29 +1,25 @@ from django.http import Http404 +from django.shortcuts import _get_queryset -def _get_queryset(klass): - """ - Return a QuerySet or a Manager. - Duck typing in action: any class with a `get()` method (for - get_object_or_404) or a `filter()` method (for get_list_or_404) might do - the job. - """ - # If it is a model class or anything else with ._default_manager - if hasattr(klass, "_default_manager"): - return klass._default_manager.all() - return klass - -async def aget_object_or_404(klass, *args, **kwargs): - """See get_object_or_404().""" - queryset = _get_queryset(klass) - if not hasattr(queryset, "aget"): - klass__name = ( - klass.__name__ if isinstance(klass, type) else klass.__class__.__name__ - ) - raise ValueError( - "First argument to aget_object_or_404() must be a Model, Manager, or " - f"QuerySet, not '{klass__name}'." - ) - try: - return await queryset.aget(*args, **kwargs) - except queryset.model.DoesNotExist: - raise Http404(f"No {queryset.model._meta.object_name} matches the given query.") \ No newline at end of file +try: + from django.shortcuts import aget_object_or_404 +except ImportError: + # NOTE aget_object_or_404 defined since Django 5. + # This function will be removed when support for Django 4 is dropped. + async def aget_object_or_404(klass, *args, **kwargs): + """See get_object_or_404().""" + queryset = _get_queryset(klass) + if not hasattr(queryset, "aget"): + klass__name = ( + klass.__name__ if isinstance(klass, type) else klass.__class__.__name__ + ) + raise ValueError( + "First argument to aget_object_or_404() must be a Model, Manager, or " + f"QuerySet, not '{klass__name}'." + ) + try: + return await queryset.aget(*args, **kwargs) + except queryset.model.DoesNotExist: + raise Http404( + f"No {queryset.model._meta.object_name} matches the given query." + ) diff --git a/adrf/test.py b/adrf/test.py index 40533d3..a8e93ad 100644 --- a/adrf/test.py +++ b/adrf/test.py @@ -4,7 +4,6 @@ from django.test.client import AsyncRequestFactory as DjangoAsyncRequestFactory from django.utils.encoding import force_bytes from django.utils.http import urlencode - from rest_framework.settings import api_settings from rest_framework.test import force_authenticate diff --git a/adrf/utils.py b/adrf/utils.py new file mode 100644 index 0000000..a8a49c8 --- /dev/null +++ b/adrf/utils.py @@ -0,0 +1,49 @@ +import inspect + + +# NOTE This function has been taken from the python library and modified +# to allow an exclusion list and avoid recursion errors. +def getmembers(object, predicate, exclude_names=[]): + results = [] + processed = set() + names = [x for x in dir(object) if x not in exclude_names] + if inspect.isclass(object): + mro = inspect.getmro(object) + # add any DynamicClassAttributes to the list of names if object is a class; + # this may result in duplicate entries if, for example, a virtual + # attribute with the same name as a DynamicClassAttribute exists + try: + for base in object.__bases__: + for k, v in base.__dict__.items(): + if ( + isinstance(v, inspect.types.DynamicClassAttribute) + and k not in exclude_names + ): + names.append(k) + except AttributeError: + pass + else: + mro = () + for key in names: + # First try to get the value via getattr. Some descriptors don't + # like calling their __get__ (see bug #1785), so fall back to + # looking in the __dict__. + try: + value = getattr(object, key) + # handle the duplicate key + if key in processed: + raise AttributeError + except AttributeError: + for base in mro: + if key in base.__dict__: + value = base.__dict__[key] + break + else: + # could be a (currently) missing slot member, or a buggy + # __dir__; discard and move on + continue + if not predicate or predicate(value): + results.append((key, value)) + processed.add(key) + results.sort(key=lambda pair: pair[0]) + return results diff --git a/adrf/views.py b/adrf/views.py index 3c24898..27fe8af 100755 --- a/adrf/views.py +++ b/adrf/views.py @@ -2,13 +2,13 @@ from typing import List, Optional from asgiref.sync import async_to_sync, sync_to_async - -from adrf.requests import AsyncRequest from rest_framework.permissions import BasePermission from rest_framework.request import Request from rest_framework.throttling import BaseThrottle from rest_framework.views import APIView as DRFAPIView +from adrf.requests import AsyncRequest + class APIView(DRFAPIView): def sync_dispatch(self, request, *args, **kwargs): diff --git a/adrf/viewsets.py b/adrf/viewsets.py index 9739e12..be1c4f7 100644 --- a/adrf/viewsets.py +++ b/adrf/viewsets.py @@ -1,18 +1,15 @@ import asyncio +import inspect from functools import update_wrapper -from asgiref.sync import async_to_sync, sync_to_async -from django.db.models import QuerySet -from django.shortcuts import get_object_or_404 from django.utils.decorators import classonlymethod from django.utils.functional import classproperty +from rest_framework.viewsets import ViewSetMixin as DRFViewSetMixin +from adrf import mixins +from adrf.generics import GenericAPIView +from adrf.utils import getmembers from adrf.views import APIView -from adrf.shortcuts import aget_object_or_404 -from rest_framework.viewsets import ViewSetMixin as DRFViewSetMixin -from rest_framework.settings import api_settings -from rest_framework.response import Response -from rest_framework import status class ViewSetMixin(DRFViewSetMixin): @@ -146,8 +143,13 @@ async def async_view(request, *args, **kwargs): class ViewSet(ViewSetMixin, APIView): - _ASYNC_NON_DISPATCH_METHODS = [] - + _ASYNC_NON_DISPATCH_METHODS = [ + "check_async_object_permissions", + "async_dispatch", + "check_async_permissions", + "check_async_throttles", + ] + @classproperty def view_is_async(cls): """ @@ -155,317 +157,43 @@ def view_is_async(cls): """ result = [ asyncio.iscoroutinefunction(function) - for name, function in cls.__dict__.items() - if callable(function) and not name.startswith("__") - and not name in cls._ASYNC_NON_DISPATCH_METHODS + for name, function in getmembers( + cls, inspect.iscoroutinefunction, exclude_names=["view_is_async"] + ) + if not name.startswith("__") and name not in cls._ASYNC_NON_DISPATCH_METHODS ] return any(result) -class GenericViewSet(ViewSet): - """ - Base class for all other generic views. - """ - _ASYNC_NON_DISPATCH_METHODS = ViewSet._ASYNC_NON_DISPATCH_METHODS \ - + ['aget_object', 'perform_create', 'apaginate_queryset', - 'perform_update'] - - queryset = None - serializer_class = None - - # If you want to use object lookups other than pk, set 'lookup_field'. - # For more complex lookup requirements override `get_object()`. - lookup_field = 'pk' - lookup_url_kwarg = None - - # The filter backend classes to use for queryset filtering - filter_backends = api_settings.DEFAULT_FILTER_BACKENDS - - # The style to use for queryset pagination. - pagination_class = api_settings.DEFAULT_PAGINATION_CLASS - - # Allow generic typing checking for generic views. - def __class_getitem__(cls, *args, **kwargs): - return cls - - def get_queryset(self): - """ - Get the list of items for this view. - This must be an iterable, and may be a queryset. - Defaults to using `self.queryset`. - - This method should always be used rather than accessing `self.queryset` - directly, as `self.queryset` gets evaluated only once, and those results - are cached for all subsequent requests. - - You may want to override this if you need to provide different - querysets depending on the incoming request. - - (Eg. return a list of items that is specific to the user) - """ - assert self.queryset is not None, ( - "'%s' should either include a `queryset` attribute, " - "or override the `get_queryset()` method." - % self.__class__.__name__ - ) - - queryset = self.queryset - if isinstance(queryset, QuerySet): - # Ensure queryset is re-evaluated on each request. - queryset = queryset.all() - return queryset - - async def aget_object(self): - """ - Returns the object the view is displaying. - - You may want to override this if you need to provide non-standard - queryset lookups. Eg if objects are referenced using multiple - keyword arguments in the url conf. - """ - queryset = self.filter_queryset(self.get_queryset()) - - # Perform the lookup filtering. - lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - - assert lookup_url_kwarg in self.kwargs, ( - 'Expected view %s to be called with a URL keyword argument ' - 'named "%s". Fix your URL conf, or set the `.lookup_field` ' - 'attribute on the view correctly.' % - (self.__class__.__name__, lookup_url_kwarg) - ) - - filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]} - obj = await aget_object_or_404(queryset, **filter_kwargs) - - # May raise a permission denied - self.check_object_permissions(self.request, obj) - - return obj - - def get_object(self): - """ - Returns the object the view is displaying. - - You may want to override this if you need to provide non-standard - queryset lookups. Eg if objects are referenced using multiple - keyword arguments in the url conf. - """ - queryset = self.filter_queryset(self.get_queryset()) - - # Perform the lookup filtering. - lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - - assert lookup_url_kwarg in self.kwargs, ( - 'Expected view %s to be called with a URL keyword argument ' - 'named "%s". Fix your URL conf, or set the `.lookup_field` ' - 'attribute on the view correctly.' % - (self.__class__.__name__, lookup_url_kwarg) - ) - - filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]} - obj = get_object_or_404(queryset, **filter_kwargs) - - # May raise a permission denied - self.check_object_permissions(self.request, obj) - - return obj - - def get_serializer(self, *args, **kwargs): - """ - Return the serializer instance that should be used for validating and - deserializing input, and for serializing output. - """ - serializer_class = self.get_serializer_class() - kwargs.setdefault('context', self.get_serializer_context()) - return serializer_class(*args, **kwargs) - - def get_serializer_class(self): - """ - Return the class to use for the serializer. - Defaults to using `self.serializer_class`. - - You may want to override this if you need to provide different - serializations depending on the incoming request. - - (Eg. admins get full serialization, others get basic serialization) - """ - assert self.serializer_class is not None, ( - "'%s' should either include a `serializer_class` attribute, " - "or override the `get_serializer_class()` method." - % self.__class__.__name__ - ) - - return self.serializer_class - def get_serializer_context(self): - """ - Extra context provided to the serializer class. - """ - return { - 'request': self.request, - 'format': self.format_kwarg, - 'view': self - } - - def filter_queryset(self, queryset): - """ - Given a queryset, filter it with whichever filter backend is in use. - - You are unlikely to want to override this method, although you may need - to call it either from a list view, or from a custom `get_object` - method if you want to apply the configured filtering backend to the - default queryset. - """ - for backend in list(self.filter_backends): - queryset = backend().filter_queryset(self.request, queryset, self) - return queryset - - @property - def paginator(self): - """ - The paginator instance associated with the view, or `None`. - """ - if not hasattr(self, '_paginator'): - if self.pagination_class is None: - self._paginator = None - else: - self._paginator = self.pagination_class() - return self._paginator - - def paginate_queryset(self, queryset): - """ - Return a single page of results, or `None` if pagination is disabled. - """ - if self.paginator is None: - return None - if asyncio.iscoroutinefunction(self.paginator.paginate_queryset): - return async_to_sync(self.paginator.paginate_queryset(queryset, self.request, view=self)) - return self.paginator.paginate_queryset(queryset, self.request, view=self) - - async def apaginate_queryset(self, queryset): - """ - Return a single page of results, or `None` if pagination is disabled. - """ - if self.paginator is None: - return None - if asyncio.iscoroutinefunction(self.paginator.paginate_queryset): - return await self.paginator.paginate_queryset(queryset, self.request, view=self) - return self.paginator.paginate_queryset(queryset, self.request, view=self) - - def get_paginated_response(self, data): - """ - Return a paginated style `Response` object for the given output data. - """ - assert self.paginator is not None - if asyncio.iscoroutinefunction(self.paginator.get_paginated_response): - return async_to_sync(self.paginator.get_paginated_response(data)) - return self.paginator.get_paginated_response(data) - -class CreateModelMixin: - """ - Create a model instance. - """ - async def create(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.data) - serializer.is_valid(raise_exception=True) - await self.perform_create(serializer) - data = await serializer.adata - headers = self.get_success_headers(data) - return Response(data, status=status.HTTP_201_CREATED, headers=headers) - - async def perform_create(self, serializer): - await serializer.asave() - - def get_success_headers(self, data): - try: - return {'Location': str(data[api_settings.URL_FIELD_NAME])} - except (TypeError, KeyError): - return {} - -class RetrieveModelMixin: - """ - Retrieve a model instance. - """ - async def retrieve(self, request, *args, **kwargs): - instance = await self.aget_object() - serializer = self.get_serializer(instance, many=False) - #try to serialize async is the serializer supports it. Sync otherwise - data = await serializer.adata if hasattr(serializer, 'adata') else serializer.data - return Response(data, status=status.HTTP_200_OK) - -class ListModelMixin: - """ - List a queryset. - """ - async def list(self, *args, **kwargs): - queryset = self.filter_queryset(self.get_queryset()) +class GenericViewSet(ViewSet, GenericAPIView): + _ASYNC_NON_DISPATCH_METHODS = ViewSet._ASYNC_NON_DISPATCH_METHODS + [ + "aget_object", + "apaginate_queryset", + "get_apaginated_response", + ] - page = await self.apaginate_queryset(queryset) - if page is not None: - serializer = self.get_serializer(page, many=True) - data = await serializer.adata if hasattr(serializer, 'adata') else serializer.data - return await self.aget_paginated_response(data) - serializer = self.get_serializer(queryset, many=True) - data = await serializer.adata if hasattr(serializer, 'adata') else serializer.data - return Response(data, status=status.HTTP_200_OK) - - -class UpdateModelMixin: - """ - Update a model instance. - """ - async def update(self, request, *args, **kwargs): - partial = kwargs.pop('partial', False) - instance = await self.aget_object() - serializer = self.get_serializer(instance, data=request.data, partial=partial) - serializer.is_valid(raise_exception=True) - data = serializer.adata - await self.perform_update(serializer) - - if getattr(instance, '_prefetched_objects_cache', None): - # If 'prefetch_related' has been applied to a queryset, we need to - # forcibly invalidate the prefetch cache on the instance. - instance._prefetched_objects_cache = {} - - return Response(await data, status=status.HTTP_200_OK) - - async def perform_update(self, serializer): - await serializer.asave() - - async def partial_update(self, request, *args, **kwargs): - kwargs['partial'] = True - return await self.update(request, *args, **kwargs) - - -class DestroyModelMixin: - """ - Destroy a model instance. - """ - async def destroy(self, request, *args, **kwargs): - instance = await self.aget_object() - await self.perform_destroy(instance) - return Response(status=status.HTTP_204_NO_CONTENT) - - async def perform_destroy(self, instance): - await instance.adelete() - -class ReadOnlyModelViewSet(RetrieveModelMixin, - ListModelMixin, - GenericViewSet): +class ReadOnlyModelViewSet( + mixins.RetrieveModelMixin, mixins.ListModelMixin, GenericViewSet +): """ A viewset that provides default asynchronous `list()` and `retrieve()` actions. """ + pass -class ModelViewSet(CreateModelMixin, - ListModelMixin, - RetrieveModelMixin, - UpdateModelMixin, - DestroyModelMixin, - GenericViewSet): + +class ModelViewSet( + mixins.CreateModelMixin, + mixins.ListModelMixin, + mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + mixins.DestroyModelMixin, + GenericViewSet, +): """ A viewset that provides default asynchronous `create()`, `retrieve()`, `update()`, `partial_update()`, `destroy()` and `list()` actions. """ - pass \ No newline at end of file + + pass diff --git a/pyproject.toml b/pyproject.toml index 54110ca..7ac220b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,3 +27,21 @@ ruff = "^0.5.5" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.ruff.lint] +extend-select = ["I"] + +[tool.coverage.run] +source = ["adrf"] +branch = true +parallel = true + +[tool.coverage.report] +show_missing = true +exclude_lines = [ + "pragma: no cover", + "raise NotImplementedError" +] + +[tool.pytest] +addopts="--tb=short --strict-markers -ra" diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 59a320a..0000000 --- a/setup.cfg +++ /dev/null @@ -1,28 +0,0 @@ -[metadata] -license_files = LICENSE - -[tool:pytest] -addopts=--tb=short --strict-markers -ra - -[flake8] -ignore = E501,W503,W504 - -[isort] -skip=.tox -atomic=true -multi_line_output=5 -extra_standard_library=types -known_third_party=pytest,_pytest,django -known_first_party=rest_framework,tests - -[coverage:run] -# NOTE: source is ignored with pytest-cov (but uses the same). -source = . -include = adrf/*,tests/* -branch = 1 - -[coverage:report] -include = adrf/*,tests/* -exclude_lines = - pragma: no cover - raise NotImplementedError diff --git a/tests/conftest.py b/tests/conftest.py index bfa3699..4c442be 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,6 @@ import django +from django.core.management.commands.makemigrations import Command as MakeMigrations +from django.core.management.commands.migrate import Command as Migrate def pytest_configure(config): @@ -22,7 +24,13 @@ def pytest_configure(config): "BACKEND": "django.template.backends.django.DjangoTemplates", "APP_DIRS": True, "OPTIONS": { - "debug": True, # We want template errors to raise + "debug": True, # We want template errors to raise, + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", + ], }, }, ], @@ -36,6 +44,7 @@ def pytest_configure(config): "django.contrib.admin", "django.contrib.auth", "django.contrib.contenttypes", + "django.contrib.messages", "django.contrib.sessions", "django.contrib.sites", "django.contrib.staticfiles", @@ -44,6 +53,9 @@ def pytest_configure(config): "tests", ), PASSWORD_HASHERS=("django.contrib.auth.hashers.MD5PasswordHasher",), + DEFAULT_AUTO_FIELD="django.db.models.AutoField", ) django.setup() + MakeMigrations().run_from_argv(["python", "manage.py"]) + Migrate().run_from_argv(["python", "manage.py"]) diff --git a/tests/test_models.py b/tests/models.py similarity index 100% rename from tests/test_models.py rename to tests/models.py index c29ba63..df7e687 100644 --- a/tests/test_models.py +++ b/tests/models.py @@ -1,5 +1,5 @@ -from django.db import models from django.contrib.auth.models import User +from django.db import models class Order(models.Model): diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 144fb37..73ad047 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -2,15 +2,15 @@ from django.contrib.auth.models import User from django.http import HttpResponse from django.test import TestCase, override_settings - -from adrf.views import APIView -from adrf.decorators import api_view from rest_framework import permissions, status -from rest_framework.decorators import permission_classes, authentication_classes from rest_framework.authentication import BaseAuthentication +from rest_framework.decorators import authentication_classes, permission_classes from rest_framework.exceptions import AuthenticationFailed from rest_framework.test import APIRequestFactory +from adrf.decorators import api_view +from adrf.views import APIView + fake = faker.Faker() faked_user = User( diff --git a/tests/test_generics.py b/tests/test_generics.py new file mode 100644 index 0000000..8e1ed83 --- /dev/null +++ b/tests/test_generics.py @@ -0,0 +1,122 @@ +from asgiref.sync import async_to_sync +from django.test import TestCase +from rest_framework import status +from rest_framework.test import APIRequestFactory + +from adrf import generics, serializers + +from .models import Order, User + +factory = APIRequestFactory() + + +class UserSerializer(serializers.ModelSerializer): + class Meta: + model = User + fields = ("username",) + + +class CreateUserView(generics.CreateAPIView): + queryset = User.objects.all() + serializer_class = UserSerializer + + +class ListUserView(generics.ListAPIView): + queryset = User.objects.all() + serializer_class = UserSerializer + + +class RetrieveUserView(generics.RetrieveAPIView): + queryset = User.objects.all() + serializer_class = UserSerializer + + +class DestroyUserView(generics.DestroyAPIView): + queryset = User.objects.all() + serializer_class = UserSerializer + + +class UpdateUserView(generics.UpdateAPIView): + queryset = User.objects.all() + serializer_class = UserSerializer + + +class TestCreateUserView(TestCase): + def setUp(self): + self.view = CreateUserView.as_view() + + def test_post_succeeds(self): + request = factory.post("/", {"username": "test"}) + response = async_to_sync(self.view)(request) + expected = {"username": "test"} + assert response.status_code == status.HTTP_201_CREATED + assert response.data == expected + + +class TestListUserView(TestCase): + def setUp(self): + self.view = ListUserView.as_view() + + def test_get_no_users(self): + request = factory.get("/") + response = async_to_sync(self.view)(request) + expected = [] + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + def test_get_one_user(self): + User.objects.create(username="test") + request = factory.get("/") + response = async_to_sync(self.view)(request) + expected = [{"username": "test"}] + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + +class TestRetrieveUserView(TestCase): + def setUp(self): + self.view = RetrieveUserView.as_view() + + def test_get_no_users(self): + request = factory.get("/") + response = async_to_sync(self.view)(request, pk=1) + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_get_one_user(self): + user = User.objects.create(username="test") + request = factory.get("/") + response = async_to_sync(self.view)(request, pk=user.id) + expected = {"username": "test"} + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + +class TestDestroyUserView(TestCase): + def setUp(self): + self.view = DestroyUserView.as_view() + + def test_delete_no_users(self): + request = factory.delete("/") + response = async_to_sync(self.view)(request, pk=1) + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_delete_one_user(self): + user = User.objects.create(username="test") + Order.objects.create(name="Test order", user=user) + request = factory.delete("/") + response = async_to_sync(self.view)(request, pk=user.id) + assert response.status_code == status.HTTP_204_NO_CONTENT + assert not Order.objects.exists() + + +class TestUpdateUserView(TestCase): + def setUp(self): + self.view = UpdateUserView.as_view() + + def test_update_user(self): + user = User.objects.create(username="test") + request = factory.put("/", data={"username": "not-test"}) + response = async_to_sync(self.view)(request, pk=user.id) + assert response.status_code == status.HTTP_200_OK + user.refresh_from_db() + assert user.username == "not-test" diff --git a/tests/test_object_permissions.py b/tests/test_object_permissions.py index 02d01bd..8d049d0 100755 --- a/tests/test_object_permissions.py +++ b/tests/test_object_permissions.py @@ -1,11 +1,11 @@ from asgiref.sync import sync_to_async from django.http import HttpResponse from django.test import TestCase, override_settings - -from adrf.views import APIView from rest_framework.permissions import BasePermission from rest_framework.test import APIRequestFactory +from adrf.views import APIView + factory = APIRequestFactory() diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 000c23a..51ebfe7 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -1,10 +1,10 @@ from django.http import HttpResponse from django.test import TestCase, override_settings - -from adrf.views import APIView from rest_framework.permissions import BasePermission from rest_framework.test import APIRequestFactory +from adrf.views import APIView + factory = APIRequestFactory() diff --git a/tests/test_serializers.py b/tests/test_serializers.py index 6563dc1..266573f 100644 --- a/tests/test_serializers.py +++ b/tests/test_serializers.py @@ -2,11 +2,12 @@ from asgiref.sync import sync_to_async from django.test import TestCase - -from adrf.serializers import ModelSerializer, Serializer from rest_framework import serializers from rest_framework.test import APIRequestFactory -from .test_models import User, Order + +from adrf.serializers import ModelSerializer, Serializer + +from .models import Order, User factory = APIRequestFactory() diff --git a/tests/test_shortcuts.py b/tests/test_shortcuts.py new file mode 100644 index 0000000..02701ff --- /dev/null +++ b/tests/test_shortcuts.py @@ -0,0 +1,28 @@ +from django.http import Http404 +from django.test import TestCase + +from adrf.shortcuts import aget_object_or_404 + +from .models import User + + +class TestAGetObject(TestCase): + async def test_aget_object_or_404_not_a_model_raises(self): + with self.assertRaises(ValueError): + await aget_object_or_404(None, id=1) + + async def test_aget_object_or_404_raises(self): + with self.assertRaises(Http404): + await aget_object_or_404(User, id=1) + + async def test_aget_object_or_404_with_model_succeeds(self): + username = "test" + user = await User.objects.acreate(username=username) + obj = await aget_object_or_404(User, username=username) + assert user == obj + + async def test_aget_object_or_404_with_queryset_succeeds(self): + username = "test" + user = await User.objects.acreate(username=username) + obj = await aget_object_or_404(User.objects.all(), username=username) + assert user == obj diff --git a/tests/test_testmodule.py b/tests/test_testmodule.py index ee3cbb5..ee9539d 100644 --- a/tests/test_testmodule.py +++ b/tests/test_testmodule.py @@ -1,11 +1,11 @@ from django.core.handlers.asgi import ASGIRequest from django.test import TestCase, override_settings from django.urls import path, reverse +from rest_framework import status +from rest_framework.response import Response from adrf.decorators import api_view from adrf.test import AsyncAPIClient, AsyncAPIRequestFactory -from rest_framework import status -from rest_framework.response import Response @api_view(["GET", "POST", "PUT", "PATCH"]) diff --git a/tests/test_throttling.py b/tests/test_throttling.py index 8819f1f..63aab9b 100644 --- a/tests/test_throttling.py +++ b/tests/test_throttling.py @@ -1,10 +1,10 @@ from django.http import HttpResponse from django.test import TestCase, override_settings - -from adrf.views import APIView from rest_framework.test import APIRequestFactory from rest_framework.throttling import BaseThrottle +from adrf.views import APIView + factory = APIRequestFactory() diff --git a/tests/test_views.py b/tests/test_views.py index 0b25a4d..0e2b1cc 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -3,13 +3,13 @@ from asgiref.sync import async_to_sync from django.contrib.auth.models import User from django.test import TestCase - -from adrf.decorators import api_view -from adrf.views import APIView from rest_framework import status from rest_framework.response import Response from rest_framework.test import APIRequestFactory +from adrf.decorators import api_view +from adrf.views import APIView + factory = APIRequestFactory() diff --git a/tests/test_viewsets.py b/tests/test_viewsets.py index 2f1de2b..90c260b 100644 --- a/tests/test_viewsets.py +++ b/tests/test_viewsets.py @@ -1,11 +1,12 @@ from asgiref.sync import async_to_sync from django.contrib.auth.models import User from django.test import TestCase - -from adrf.viewsets import ViewSet from rest_framework import status from rest_framework.response import Response from rest_framework.test import APIRequestFactory + +from adrf.serializers import ModelSerializer +from adrf.viewsets import ModelViewSet, ViewSet from tests.test_views import JSON_ERROR, sanitise_json_error factory = APIRequestFactory() @@ -107,3 +108,56 @@ def test_400_parse_error(self): expected = {"detail": JSON_ERROR} assert response.status_code == status.HTTP_400_BAD_REQUEST assert sanitise_json_error(response.data) == expected + + +class UserSerializer(ModelSerializer): + class Meta: + model = User + fields = ("username",) + + +class UserViewSet(ModelViewSet): + queryset = User.objects.all() + serializer_class = UserSerializer + + +class ModelViewSetIntegrationTests(TestCase): + def setUp(self): + self.list_create = UserViewSet.as_view({"get": "alist", "post": "acreate"}) + self.retrieve_update = UserViewSet.as_view( + {"get": "aretrieve", "put": "aupdate"} + ) + self.destroy = UserViewSet.as_view({"delete": "adestroy"}) + + def test_list_succeeds(self): + User.objects.create(username="test") + request = factory.get("/") + response = async_to_sync(self.list_create)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == [{"username": "test"}] + + def test_create_succeeds(self): + request = factory.post("/", data={"username": "test"}) + response = async_to_sync(self.list_create)(request) + assert response.status_code == status.HTTP_201_CREATED + assert response.data == {"username": "test"} + + def test_retrieve_succeeds(self): + user = User.objects.create(username="test") + request = factory.get("/") + response = async_to_sync(self.retrieve_update)(request, pk=user.id) + assert response.status_code == status.HTTP_200_OK + assert response.data == {"username": "test"} + + def test_update_succeeds(self): + user = User.objects.create(username="test") + request = factory.put("/", data={"username": "not-test"}) + response = async_to_sync(self.retrieve_update)(request, pk=user.id) + assert response.status_code == status.HTTP_200_OK + assert response.data == {"username": "not-test"} + + def test_destroy_succeeds(self): + user = User.objects.create(username="test") + request = factory.delete("/") + response = async_to_sync(self.destroy)(request, pk=user.id) + assert response.status_code == status.HTTP_204_NO_CONTENT diff --git a/tests/urls.py b/tests/urls.py new file mode 100644 index 0000000..637600f --- /dev/null +++ b/tests/urls.py @@ -0,0 +1 @@ +urlpatterns = [] diff --git a/tox.ini b/tox.ini index 7c2835f..7dc16e8 100644 --- a/tox.ini +++ b/tox.ini @@ -24,7 +24,6 @@ ignore_outcome = true [testenv:py312-djangomain] ignore_outcome = true - [testenv:lint] deps = ruff commands = ruff check adrf From 462ceed8448e8c417c05136051c90661cf4d22f0 Mon Sep 17 00:00:00 2001 From: Enrico Massa Date: Sat, 3 Aug 2024 20:38:23 +0800 Subject: [PATCH 7/7] Updated readme with docs on generics --- README.md | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7eb6ac5..76eee05 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,7 @@ class AsyncSerializer(Serializer): views.py ```python -from . import serializers +from .serializers import AsyncSerializer from adrf.views import APIView class AsyncView(APIView): @@ -149,7 +149,42 @@ class AsyncView(APIView): "password": "test", "age": 10, } - serializer = serializers.AsyncSerializer(data=data) + serializer = AsyncSerializer(data=data) serializer.is_valid() return await serializer.adata ``` + +# Async generics + +models.py + +```python +from django.db import models + +class Order(models.Model): + name = models.TextField() +``` + +serializers.py + +```python +from adrf.serializers import ModelSerializer +from .models import Order + +class OrderSerializer(ModelSerializer): + class Meta: + model = Order + fields = ('name', ) +``` + +views.py + +```python +from adrf.generics import ListCreateAPIView +from .models import Order +from .serializers import OrderSerializer + +class ListCreateOrderView(ListCreateAPIView): + queryset = Order.objects.all() + serializer_class = OrderSerializer +```