diff --git a/docs/faq.md b/docs/faq.md index 0f67edfa..018947b0 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -11,11 +11,13 @@ def resolver(root, info: Info): ## How to access the current user object in resolvers? -The current user object is accessible via the `info.context.request.user` object. +The current user object is accessible via the `get_current_user` method. ```python +from strawberry_django.auth.queries import get_current_user + def resolver(root, info: Info): - current_user = info.context.request.user + current_user = get_current_user(info) ``` ## Autocompletion with editors diff --git a/docs/guide/subscriptions.md b/docs/guide/subscriptions.md index a4b52858..984dad24 100644 --- a/docs/guide/subscriptions.md +++ b/docs/guide/subscriptions.md @@ -1,6 +1,139 @@ -### Subscriptions +# Subscriptions Subscriptions are supported using the [Strawberry Django Channels](https://strawberry.rocks/docs/integrations/channels) integration. -Check its docs to know how to use it. +This guide will give you a minimal working example to get you going. +There are 3 parts to this guide: + +1. Making Django compatible +2. Setup local testing +3. Creating your first subscription + +## Making Django compatible + +It's important to realise that Django doesn't support websockets out of the box. +To resolve this, we can help the platform along a little. + +This implementation is based on Django Channels - this means that should you wish - there is a lot more websockets fun to be had. If you're interested, head over to [Django Channels](https://channels.readthedocs.io). + +To add the base compatibility, go to your `MyProject.asgi.py` file and replace it with the following content. +Ensure that you replace the relevant code with your setup. + +```python +# MyProject.asgi.py +import os + +from django.core.asgi import get_asgi_application +from strawberry_django.routers import AuthGraphQLProtocolTypeRouter + +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "MyProject.settings") # CHANGE the project name +django_asgi_app = get_asgi_application() + +# Import your Strawberry schema after creating the django ASGI application +# This ensures django.setup() has been called before any ORM models are imported +# for the schema. + +from .schema import schema # CHANGE path to where you housed your schema file. +application = AuthGraphQLProtocolTypeRouter( + schema, + django_application=django_asgi_app, +) +``` + +Also, ensure that you enable subscriptions on your AsgiGraphQLView in `MyProject.urls.py`: + +```python +... + +urlpatterns = [ + ... + path( + 'graphql/', + AsyncGraphQLView.as_view( + schema=schema, + graphiql=settings.DEBUG, + subscriptions_enabled=True, + ), + ), + ... +] + +``` + +Note, django-channels allows for a lot more complexity. Here we merely cover the basic framework to get subscriptions to run on Django with minimal effort. Should you be interested in discovering the far more advanced capabilities of Dango channels, head over to [channels docs](https://channels.readthedocs.io) + +## Setup local testing + +The classic `./manage.py runserver` will not support subscriptions as it runs on WSGI mode. However, Django has ASGI server support out of the box through Daphne, which will override the runserver command to support our desired ASGI support. + +There are other asgi servers available, such as Uvicorn and Hypercorn. For the sake of simplicity we'll use Daphne as it comes with the runserver override. [Django Docs](https://docs.djangoproject.com/en/4.2/howto/deployment/asgi/daphne/) This shouldn't stop you from using any of the other ASGI flavours in production or local testing like Uvicorn or Hypercorn + +To get started: Firstly, we need install Daphne to handle the workload, so let's install it: + +```bash +pip install daphne +``` + +Secondly, we need to add `daphne` to your settings.py file before 'django.contrib.staticfiles' + +```python +INSTALLED_APPS = [ + ... + 'daphne', + 'django.contrib.staticfiles', + ... +] +``` + +and add your `ASGI_APPLICATION` setting in your settings.py + +```python +# settings.py +... +ASGI_APPLICATION = 'MyProject.asgi.application' +... +``` + +Now you can run your test-server like as usual, but with ASGI support: + +```bash +./manage.py runserver +``` + +## Creating your first subscription + +Once you've taken care of those 2 setup steps, your first subscription is a breeze. +Go and edit your schema-file and add: + +```python +import asyncio +import strawberry + +@strawberry.type +class Subscription: + @strawberry.subscription + async def count(self, target: int = 100) -> int: + for i in range(target): + yield i + await asyncio.sleep(0.5) +``` + +That's pretty much it for this basic start. +See for yourself by running your test server `./manange.py runserver` and opening `http://127.0.0.1:8000/graphql/` in your browser. Now run: + +```graphql +subscription { + count(target: 10) +} +``` + +You should see something like (where the count changes every .5s to a max of 9) + +```json +{ + "data": { + "count": 9 + } +} +``` diff --git a/docs/guide/types.md b/docs/guide/types.md index f4cdae7c..fe582767 100644 --- a/docs/guide/types.md +++ b/docs/guide/types.md @@ -103,12 +103,15 @@ You can use that `info` parameter to, for example, limit access to results based on the current user in the request: ```{.python title=types.py} +from stawberry_django.auth.utils import get_current_user + @strawberry.django.type(models.Fruit) class Berry: @classmethod def get_queryset(cls, queryset, info, **kwargs): - if not info.context.request.user.is_staff: + user = get_current_user(info) + if not user.is_staff: # Restrict access to top secret berries if the user is not a staff member queryset = queryset.filter(is_top_secret=False) return queryset.filter(name__contains="berry") diff --git a/docs/guide/unit-testing.md b/docs/guide/unit-testing.md index c59eddc9..9ff87bb6 100644 --- a/docs/guide/unit-testing.md +++ b/docs/guide/unit-testing.md @@ -15,7 +15,7 @@ from strawberry_django.test.client import TestClient def test_me_unauthenticated(db): client = TestClient("/graphql") - res = gql_client.query(""" + res = client.query(""" query TestQuery { me { pk @@ -31,18 +31,20 @@ def test_me_unauthenticated(db): def test_me_authenticated(db): user = User.objects.create(...) - client = TestClient("/graphql") - res = client.query(""" - query TestQuery { - me { - pk - email - firstName - lastName - } - } - """) + + with client.login(user): + res = client.query(""" + query TestQuery { + me { + pk + email + firstName + lastName + } + } + """) + assert res.errors is None assert res.data == { "me": { @@ -53,3 +55,5 @@ def test_me_authenticated(db): }, } ``` + +For more information how to apply these tests, take a look at the (source)[https://github.com/strawberry-graphql/strawberry-graphql-django/blob/main/strawberry_django/test/client.py] and (this example)[https://github.com/strawberry-graphql/strawberry-graphql-django/blob/main/tests/test_permissions.py#L49] diff --git a/poetry.lock b/poetry.lock index 6c837b41..e1073502 100644 --- a/poetry.lock +++ b/poetry.lock @@ -96,6 +96,25 @@ files = [ {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, ] +[[package]] +name = "channels" +version = "4.0.0" +description = "Brings async, event-driven capabilities to Django 3.2 and up." +optional = false +python-versions = ">=3.7" +files = [ + {file = "channels-4.0.0-py3-none-any.whl", hash = "sha256:2253334ac76f67cba68c2072273f7e0e67dbdac77eeb7e318f511d2f9a53c5e4"}, + {file = "channels-4.0.0.tar.gz", hash = "sha256:0ce53507a7da7b148eaa454526e0e05f7da5e5d1c23440e4886cf146981d8420"}, +] + +[package.dependencies] +asgiref = ">=3.5.0,<4" +Django = ">=3.2" + +[package.extras] +daphne = ["daphne (>=4.0.0)"] +tests = ["async-timeout", "coverage (>=4.5,<5.0)", "pytest", "pytest-asyncio", "pytest-django"] + [[package]] name = "charset-normalizer" version = "3.3.0" @@ -1554,4 +1573,4 @@ enum = ["django-choices-field"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "3bcfa81251075157ab063bca118124b644f61cfa942f470697e5aa67c99b3ea1" +content-hash = "d18b33ac9945953a4030cf71cf44306a22a2ca5ca91b5ee314bf84a66f7a7c8b" diff --git a/pyproject.toml b/pyproject.toml index 82ad27f5..c822a44e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ Markdown = "^3.3.7" Pygments = "^2.15.1" factory-boy = "^3.2.1" django-guardian = "^2.4.0" +channels = { version = ">=3.0.5" } [tool.poetry.extras] debug-toolbar = ["django-debug-toolbar"] diff --git a/strawberry_django/auth/queries.py b/strawberry_django/auth/queries.py index 8ccd3872..bd947079 100644 --- a/strawberry_django/auth/queries.py +++ b/strawberry_django/auth/queries.py @@ -1,10 +1,17 @@ +from strawberry.types import Info + import strawberry_django +from .utils import get_current_user + -def resolve_current_user(info): - if not info.context.request.user.is_authenticated: +def resolve_current_user(info: Info): + user = get_current_user(info) + + if not getattr(user, "is_authenticated", False): return None - return info.context.request.user + + return user def current_user(): diff --git a/strawberry_django/auth/utils.py b/strawberry_django/auth/utils.py new file mode 100644 index 00000000..3feec201 --- /dev/null +++ b/strawberry_django/auth/utils.py @@ -0,0 +1,56 @@ +from typing import Literal, Optional, overload + +from asgiref.sync import sync_to_async +from strawberry.types import Info + +from strawberry_django.utils.typing import UserType + + +@overload +def get_current_user(info: Info, *, strict: Literal[True]) -> UserType: ... + + +@overload +def get_current_user(info: Info, *, strict: bool = False) -> Optional[UserType]: ... + + +def get_current_user(info: Info, *, strict: bool = False) -> Optional[UserType]: + """Get and return the current user based on various scenarios.""" + try: + user = info.context.request.user + except AttributeError: + try: + # queries/mutations in ASGI move the user into consumer scope + user = info.context.get("request").consumer.scope["user"] + except AttributeError: + # websockets / subscriptions move scope inside of the request + user = info.context.get("request").scope.get("user") + + if user is None: + raise ValueError("No user found in the current request") + + # Access an attribute inside the user object to force loading it in async contexts. + if user is not None: + _ = user.is_authenticated + + return user + + +@overload +async def aget_current_user( + info: Info, + *, + strict: Literal[True], +) -> UserType: ... + + +@overload +async def aget_current_user( + info: Info, + *, + strict: bool = False, +) -> Optional[UserType]: ... + + +async def aget_current_user(info: Info, *, strict: bool = False) -> Optional[UserType]: + return await sync_to_async(get_current_user)(info, strict=strict) diff --git a/strawberry_django/permissions.py b/strawberry_django/permissions.py index fa7fddc3..1c62b9d2 100644 --- a/strawberry_django/permissions.py +++ b/strawberry_django/permissions.py @@ -41,6 +41,7 @@ from strawberry.union import StrawberryUnion from typing_extensions import Literal, Self, assert_never +from strawberry_django.auth.utils import get_current_user from strawberry_django.fields.types import OperationInfo, OperationMessage from strawberry_django.resolvers import django_resolver @@ -289,16 +290,18 @@ def resolve( info: Info, **kwargs: Dict[str, Any], ) -> Any: - user = info.context.request.user + user = get_current_user(info) + try: from .integrations.guardian import get_user_or_anonymous except (ImportError, RuntimeError): # pragma: no cover pass else: - user = get_user_or_anonymous(user) + user = user and get_user_or_anonymous(user) # make sure the user is loaded - user.is_anonymous # noqa: B018 + if user is not None: + user.is_authenticated # noqa: B018 try: retval = self.resolve_for_user( @@ -319,13 +322,14 @@ async def resolve_async( info: Info, **kwargs: Dict[str, Any], ) -> Any: - user = info.context.request.user + user = get_current_user(info) + try: from .integrations.guardian import get_user_or_anonymous except (ImportError, RuntimeError): # pragma: no cover pass else: - user = await sync_to_async(get_user_or_anonymous)(user) + user = user and await sync_to_async(get_user_or_anonymous)(user) # make sure the user is loaded await sync_to_async(getattr)(user, "is_anonymous") @@ -406,7 +410,7 @@ def handle_no_permission(self, exception: BaseException, *, info: Info): def resolve_for_user( # pragma: no cover self, resolver: Callable, - user: UserType, + user: Optional[UserType], *, info: Info, source: Any, @@ -425,12 +429,12 @@ class IsAuthenticated(DjangoPermissionExtension): def resolve_for_user( self, resolver: Callable, - user: UserType, + user: Optional[UserType], *, info: Info, source: Any, ): - if not user.is_authenticated or not user.is_active: + if user is None or not user.is_authenticated or not user.is_active: raise DjangoNoPermission return resolver() @@ -448,12 +452,16 @@ class IsStaff(DjangoPermissionExtension): def resolve_for_user( self, resolver: Callable, - user: UserType, + user: Optional[UserType], *, info: Info, source: Any, ): - if not user.is_authenticated or not getattr(user, "is_staff", False): + if ( + user is None + or not user.is_authenticated + or not getattr(user, "is_staff", False) + ): raise DjangoNoPermission return resolver() @@ -471,12 +479,16 @@ class IsSuperuser(DjangoPermissionExtension): def resolve_for_user( self, resolver: Callable, - user: UserType, + user: Optional[UserType], *, info: Info, source: Any, ): - if not user.is_authenticated or not getattr(user, "is_superuser", False): + if ( + user is None + or not user.is_authenticated + or not getattr(user, "is_superuser", False) + ): raise DjangoNoPermission return resolver() @@ -694,12 +706,12 @@ class AutoDirective: def resolve_for_user( self, resolver: Callable, - user: UserType, + user: Optional[UserType], *, info: Info, source: Any, ): - if self.with_anonymous and user.is_anonymous: + if user is None or self.with_anonymous and user.is_anonymous: raise DjangoNoPermission if ( @@ -719,11 +731,14 @@ def resolve_for_user( def resolve_for_user_with_perms( self, resolver: Callable, - user: UserType, + user: Optional[UserType], *, info: Info, source: Any, ): + if user is None: + raise DjangoNoPermission + if self.target == PermTarget.GLOBAL: if not self._has_perm(source, user, info=info): raise DjangoNoPermission diff --git a/strawberry_django/routers.py b/strawberry_django/routers.py new file mode 100644 index 00000000..87978b25 --- /dev/null +++ b/strawberry_django/routers.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from channels.auth import AuthMiddlewareStack +from channels.routing import ProtocolTypeRouter, URLRouter +from channels.security.websocket import AllowedHostsOriginValidator +from django.urls import URLPattern, URLResolver, re_path +from strawberry.channels.handlers.http_handler import GraphQLHTTPConsumer +from strawberry.channels.handlers.ws_handler import GraphQLWSConsumer + +if TYPE_CHECKING: + from django.core.handlers.asgi import ASGIHandler + from strawberry.schema import BaseSchema + + +class AuthGraphQLProtocolTypeRouter(ProtocolTypeRouter): + """Convenience class to set up GraphQL on both HTTP and Websocket. + + This convenience class will include AuthMiddlewareStack and the + AllowedHostsOriginValidator to ensure you have user object available. + + ``` + from strawberry_django.routers import AuthGraphQLProtocolTypeRouter + from django.core.asgi import get_asgi_application. + + django_asgi = get_asgi_application() + + from myapi import schema + + application = AuthGraphQLProtocolTypeRouter( + schema, + django_application=django_asgi, + ) + ``` + + This will route all requests to /graphql on either HTTP or websockets to us, + and everything else to the Django application. + """ + + def __init__( + self, + schema: BaseSchema, + django_application: ASGIHandler | None = None, + url_pattern: str = "^graphql", + ): + http_urls: list[URLPattern | URLResolver] = [ + re_path(url_pattern, GraphQLHTTPConsumer.as_asgi(schema=schema)), + ] + if django_application is not None: + http_urls.append(re_path(r"^", django_application)) + + super().__init__( + { + "http": AuthMiddlewareStack( + URLRouter( + http_urls, + ), + ), + "websocket": AllowedHostsOriginValidator( + AuthMiddlewareStack( + URLRouter( + [ + re_path( + url_pattern, + GraphQLWSConsumer.as_asgi(schema=schema), + ), + ], + ), + ), + ), + }, + ) diff --git a/tests/projects/schema.py b/tests/projects/schema.py index df5f53db..5afdc005 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -24,6 +24,7 @@ import strawberry_django from strawberry_django import mutations +from strawberry_django.auth.queries import get_current_user from strawberry_django.fields.types import ListInput, NodeInput, NodeInputPartial from strawberry_django.mutations import resolvers from strawberry_django.optimizer import DjangoOptimizerExtension @@ -379,7 +380,7 @@ class Query: @strawberry_django.field def me(self, info: Info) -> Optional[UserType]: - user = info.context.request.user + user = get_current_user(info, strict=True) if not user.is_authenticated: return None