diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..e07695618b --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,23 @@ +Release type: patch + +Adds the ability to use the `&` and `|` operators on permissions to form boolean logic. For example, if you want +a field to be accessible with either the `IsAdmin` or `IsOwner` permission you +could define the field as follows: + +```python +import strawberry +from strawberry.permission import PermissionExtension, BasePermission + + +@strawberry.type +class Query: + @strawberry.field( + extensions=[ + PermissionExtension( + permissions=[(IsAdmin() | IsOwner())], fail_silently=True + ) + ] + ) + def name(self) -> str: + return "ABC" +``` diff --git a/docs/guides/permissions.md b/docs/guides/permissions.md index 76529a6900..a0ee8c3946 100644 --- a/docs/guides/permissions.md +++ b/docs/guides/permissions.md @@ -209,6 +209,31 @@ consider if it is possible to use alternative solutions like the `@skip` or without permission. Check the GraphQL documentation for more information on [directives](https://graphql.org/learn/queries/#directives). +## Boolean Operations + +When using the `PermissionExtension`, it is possible to combine permissions +using the `&` and `|` operators to form boolean logic. For example, if you want +a field to be accessible with either the `IsAdmin` or `IsOwner` permission you +could define the field as follows: + +```python +import strawberry +from strawberry.permission import PermissionExtension, BasePermission + + +@strawberry.type +class Query: + @strawberry.field( + extensions=[ + PermissionExtension( + permissions=[(IsAdmin() | IsOwner())], fail_silently=True + ) + ] + ) + def name(self) -> str: + return "ABC" +``` + ## Customizable Error Handling To customize the error handling, the `on_unauthorized` method on the diff --git a/strawberry/permission.py b/strawberry/permission.py index 8e7dabaab1..254e3a12bb 100644 --- a/strawberry/permission.py +++ b/strawberry/permission.py @@ -8,12 +8,15 @@ TYPE_CHECKING, Any, Awaitable, - Dict, List, + Literal, Optional, + Tuple, Type, + TypedDict, Union, ) +from typing_extensions import deprecated from strawberry.exceptions import StrawberryGraphQLError from strawberry.exceptions.permission_fail_silently_requires_optional import ( @@ -35,6 +38,15 @@ from strawberry.types import Info +def unpack_maybe( + value: Union[object, Tuple[bool, object]], default: object = None +) -> Tuple[object, object]: + if isinstance(value, tuple) and len(value) == 2: + return value + else: + return value, default + + class BasePermission(abc.ABC): """ Base class for creating permissions @@ -50,18 +62,41 @@ class BasePermission(abc.ABC): @abc.abstractmethod def has_permission( - self, source: Any, info: Info, **kwargs: Any - ) -> Union[bool, Awaitable[bool]]: + self, source: Any, info: Info, **kwargs: object + ) -> Union[ + bool, + Awaitable[bool], + Tuple[Literal[False], dict], + Awaitable[Tuple[Literal[False], dict]], + ]: + """ + This method is a required override in the permission class. It checks if the user has the necessary permissions to access a specific field. + + The method should return a boolean value: + - True: The user has the necessary permissions. + - False: The user does not have the necessary permissions. In this case, the `on_unauthorized` method will be invoked. + + Avoid raising exceptions in this method. Instead, use the `on_unauthorized` method to handle errors and customize the error response. + + If there's a need to pass additional information to the `on_unauthorized` method, return a tuple. The first element should be False, and the second element should be a dictionary containing the additional information. + + Args: + source (Any): The source field that the permission check is being performed on. + info (Info): The GraphQL resolve info associated with the field. + **kwargs (Any): Additional arguments that are typically passed to the field resolver. + + Returns: + bool or tuple: Returns True if the user has the necessary permissions. Returns False or a tuple (False, additional_info) if the user does not have the necessary permissions. In the latter case, the `on_unauthorized` method will be invoked. + """ raise NotImplementedError( "Permission classes should override has_permission method" ) - def on_unauthorized(self) -> None: + def on_unauthorized(self, **kwargs: object) -> None: """ Default error raising for permissions. This can be overridden to customize the behavior. """ - # Instantiate error class error = self.error_class(self.message or "") @@ -74,6 +109,9 @@ def on_unauthorized(self) -> None: raise error @property + @deprecated( + "@schema_directive is deprecated and will be disabled by default on 31.12.2024 with future removal planned. Use the new @permissions directive instead." + ) def schema_directive(self) -> object: if not self._schema_directive: @@ -89,6 +127,111 @@ class AutoDirective: return self._schema_directive + @cached_property + def is_async(self) -> bool: + return iscoroutinefunction(self.has_permission) + + def __and__(self, other: BasePermission): + return AndPermission([self, other]) + + def __or__(self, other: BasePermission): + return OrPermission([self, other]) + + +class CompositePermissionContext(TypedDict): + failed_permissions: List[Tuple[BasePermission, dict]] + + +class CompositePermission(BasePermission, abc.ABC): + def __init__(self, child_permissions: List[BasePermission]): + self.child_permissions = child_permissions + + def on_unauthorized(self, **kwargs: object) -> Any: + failed_permissions = kwargs.get("failed_permissions", []) + for permission, context in failed_permissions: + permission.on_unauthorized(**context) + + @cached_property + def is_async(self) -> bool: + return any(x.is_async for x in self.child_permissions) + + +class AndPermission(CompositePermission): + def has_permission( + self, source: Any, info: Info, **kwargs: object + ) -> Union[ + bool, + Awaitable[bool], + Tuple[Literal[False], CompositePermissionContext], + Awaitable[Tuple[Literal[False], CompositePermissionContext]], + ]: + if self.is_async: + return self._has_permission_async(source, info, **kwargs) + + for permission in self.child_permissions: + has_permission, context = unpack_maybe( + permission.has_permission(source, info, **kwargs), {} + ) + if not has_permission: + return False, {"failed_permissions": [(permission, context)]} + return True + + async def _has_permission_async( + self, source: Any, info: Info, **kwargs: object + ) -> Union[bool, Tuple[Literal[False], CompositePermissionContext]]: + for permission in self.child_permissions: + permission_response = await await_maybe( + permission.has_permission(source, info, **kwargs) + ) + has_permission, context = unpack_maybe(permission_response, {}) + if not has_permission: + return False, {"failed_permissions": [(permission, context)]} + return True + + def __and__(self, other: BasePermission): + return AndPermission([*self.child_permissions, other]) + + +class OrPermission(CompositePermission): + def has_permission( + self, source: Any, info: Info, **kwargs: object + ) -> Union[ + bool, + Awaitable[bool], + Tuple[Literal[False], dict], + Awaitable[Tuple[Literal[False], dict]], + ]: + if self.is_async: + return self._has_permission_async(source, info, **kwargs) + failed_permissions = [] + for permission in self.child_permissions: + has_permission, context = unpack_maybe( + permission.has_permission(source, info, **kwargs), {} + ) + if has_permission: + return True + failed_permissions.append((permission, context)) + + return False, {"failed_permissions": failed_permissions} + + async def _has_permission_async( + self, source: Any, info: Info, **kwargs: object + ) -> Union[bool, Tuple[Literal[False], dict]]: + failed_permissions = [] + for permission in self.child_permissions: + permission_response = await await_maybe( + permission.has_permission(source, info, **kwargs) + ) + has_permission, context = unpack_maybe(permission_response, {}) + if has_permission: + return True + failed_permissions.append((permission, context)) + + return False, {"failed_permissions": failed_permissions} + + def __or__(self, other: BasePermission): + return OrPermission([*self.child_permissions, other]) + class PermissionExtension(FieldExtension): """ @@ -100,8 +243,8 @@ class PermissionExtension(FieldExtension): NOTE: Currently, this is automatically added to the field, when using - field.permission_classes - This is deprecated behavior, please manually add the extension to field.extensions + field.permission_classes. You are free to use whichever method you prefer. + Use PermissionExtension if you want additional customization. """ def __init__( @@ -117,12 +260,16 @@ def __init__( def apply(self, field: StrawberryField) -> None: """ - Applies all of the permission directives to the schema + Applies all the permission directives to the schema and sets up silent permissions """ if self.use_directives: field.directives.extend( - p.schema_directive for p in self.permissions if p.schema_directive + [ + p.schema_directive + for p in self.permissions + if not isinstance(p, CompositePermission) + ] ) # We can only fail silently if the field is optional or a list if self.fail_silently: @@ -132,28 +279,36 @@ def apply(self, field: StrawberryField) -> None: elif isinstance(field.type, StrawberryList): self.return_empty_list = True else: - errror = PermissionFailSilentlyRequiresOptionalError(field) - raise errror + raise PermissionFailSilentlyRequiresOptionalError(field) - def _on_unauthorized(self, permission: BasePermission) -> Any: + def _on_unauthorized(self, permission: BasePermission, **kwargs: object) -> Any: if self.fail_silently: return [] if self.return_empty_list else None - return permission.on_unauthorized() + + if kwargs in (None, {}): + return permission.on_unauthorized() + return permission.on_unauthorized(**kwargs) def resolve( self, next_: SyncExtensionResolver, source: Any, info: Info, - **kwargs: Dict[str, Any], + **kwargs: object[str, Any], ) -> Any: """ Checks if the permission should be accepted and raises an exception if not """ + for permission in self.permissions: - if not permission.has_permission(source, info, **kwargs): - return self._on_unauthorized(permission) + has_permission, context = unpack_maybe( + permission.has_permission(source, info, **kwargs), {} + ) + + if not has_permission: + return self._on_unauthorized(permission, **context) + return next_(source, info, **kwargs) async def resolve_async( @@ -161,15 +316,21 @@ async def resolve_async( next_: AsyncExtensionResolver, source: Any, info: Info, - **kwargs: Dict[str, Any], + **kwargs: object[str, Any], ) -> Any: for permission in self.permissions: - has_permission = await await_maybe( + permission_response = await await_maybe( permission.has_permission(source, info, **kwargs) ) + context = {} + if isinstance(permission_response, tuple): + has_permission, context = permission_response + else: + has_permission = permission_response + if not has_permission: - return self._on_unauthorized(permission) + return self._on_unauthorized(permission, **context) next = next_(source, info, **kwargs) if inspect.isasyncgen(next): return next @@ -179,9 +340,4 @@ async def resolve_async( def supports_sync(self) -> bool: """The Permission extension always supports async checking using await_maybe, but only supports sync checking if there are no async permissions""" - async_permissions = [ - True - for permission in self.permissions - if iscoroutinefunction(permission.has_permission) - ] - return len(async_permissions) == 0 + return all(not permission.is_async for permission in self.permissions) diff --git a/tests/schema/test_permission.py b/tests/schema/test_permission.py index 0cec4cde60..3fe5b11638 100644 --- a/tests/schema/test_permission.py +++ b/tests/schema/test_permission.py @@ -54,6 +54,522 @@ def user(self) -> str: # pragma: no cover assert result.errors[0].message == "User is not authenticated" +@pytest.mark.asyncio +async def test_no_graphql_error_when_and_permission_is_allowed(): + class TruePermission(BasePermission): + message = "True Permission Failed" + + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return True + + class TrueAsyncPermission(BasePermission): + message = "True Permission Failed" + + async def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return True + + @strawberry.type + class Query: + @strawberry.field( + extensions=[ + PermissionExtension(permissions=[(TruePermission() & TruePermission())]) + ] + ) + def user(self) -> str: # pragma: no cover + return "patrick" + + @strawberry.field( + extensions=[ + PermissionExtension( + permissions=[(TrueAsyncPermission() & TrueAsyncPermission())] + ) + ] + ) + def user_async(self) -> str: # pragma: no cover + return "patrick" + + schema = strawberry.Schema(query=Query) + + query = "{ user }" + + result = schema.execute_sync(query) + assert result.data["user"] == "patrick" + + query = "{ userAsync }" + result = await schema.execute(query) + assert result.data["userAsync"] == "patrick" + + +@pytest.mark.asyncio +async def test_raises_graphql_error_when_right_and_permission_is_denied(): + class FalsePermission(BasePermission): + message = "False Permission Failed" + + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return False + + class TruePermission(BasePermission): + message = "True Permission Failed" + + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return True + + class FalseAsyncPermission(BasePermission): + message = "False Permission Failed" + + async def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return False + + @strawberry.type + class Query: + @strawberry.field( + extensions=[ + PermissionExtension( + permissions=[(TruePermission() & FalsePermission())] + ) + ] + ) + def user(self) -> str: # pragma: no cover + return "patrick" + + @strawberry.field( + extensions=[ + PermissionExtension( + permissions=[(TruePermission() & FalseAsyncPermission())] + ) + ] + ) + def user_async(self) -> str: # pragma: no cover + return "patrick" + + schema = strawberry.Schema(query=Query) + + query = "{ user }" + + result = schema.execute_sync(query) + assert result.errors[0].message == "False Permission Failed" + + query = "{ userAsync }" + result = await schema.execute(query) + assert result.errors[0].message == "False Permission Failed" + + +@pytest.mark.asyncio +async def test_raises_graphql_error_when_nested(): + class FalsePermission(BasePermission): + message = "False Permission Failed" + + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return False + + class TruePermission(BasePermission): + message = "True Permission Failed" + + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return True + + class FalseAsyncPermission(BasePermission): + message = "False Permission Failed" + + async def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return False + + @strawberry.type + class Query: + @strawberry.field( + extensions=[ + PermissionExtension( + permissions=[ + ( + (TruePermission() & FalsePermission()) + | FalseAsyncPermission() + ) + & TruePermission() + ] + ) + ] + ) + def user(self) -> str: # pragma: no cover + return "patrick" + + schema = strawberry.Schema(query=Query) + + query = "{ user }" + + result = await schema.execute(query) + assert result.errors[0].message == "False Permission Failed" + + +@pytest.mark.asyncio +async def test_raises_graphql_error_when_left_and_permission_is_denied(): + class FalsePermission(BasePermission): + message = "False Permission Failed" + + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return False + + class TruePermission(BasePermission): + message = "True Permission Failed" + + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: # pragma: no cover + return True + + class FalseAsyncPermission(BasePermission): + message = "False Permission Failed" + + async def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return False + + @strawberry.type + class Query: + @strawberry.field( + extensions=[ + PermissionExtension( + permissions=[(FalsePermission() & TruePermission())] + ) + ] + ) + def user(self) -> str: # pragma: no cover + return "patrick" + + @strawberry.field( + extensions=[ + PermissionExtension( + permissions=[(FalseAsyncPermission() & TruePermission())] + ) + ] + ) + def user_async(self) -> str: # pragma: no cover + return "patrick" + + schema = strawberry.Schema(query=Query) + + query = "{ user }" + + result = schema.execute_sync(query) + assert result.errors[0].message == "False Permission Failed" + + query = "{ userAsync }" + result = await schema.execute(query) + assert result.errors[0].message == "False Permission Failed" + + +@pytest.mark.asyncio +async def test_raises_graphql_error_from_left_exception_when_both_and_permission_is_denied(): + class FalseLeftPermission(BasePermission): + message = "False Left Permission Failed" + + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return False + + class FalseRightPermission(BasePermission): + message = "False Right Permission Failed" + + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: # pragma: no cover + return False + + class FalseRightAsyncPermission(BasePermission): + message = "False Right Permission Failed" + + async def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: # pragma: no cover + return False + + @strawberry.type + class Query: + @strawberry.field( + extensions=[ + PermissionExtension( + permissions=[(FalseLeftPermission() & FalseRightPermission())] + ) + ] + ) + def user(self) -> str: # pragma: no cover + return "patrick" + + @strawberry.field( + extensions=[ + PermissionExtension( + permissions=[(FalseLeftPermission() & FalseRightAsyncPermission())] + ) + ] + ) + def user_async(self) -> str: # pragma: no cover + return "patrick" + + schema = strawberry.Schema(query=Query) + + query = "{ user }" + + result = schema.execute_sync(query) + assert result.errors[0].message == "False Left Permission Failed" + + query = "{ userAsync }" + + result = await schema.execute(query) + assert result.errors[0].message == "False Left Permission Failed" + + +@pytest.mark.asyncio +async def test_no_graphql_error_when_both_or_permission_is_allowed(): + class TruePermission(BasePermission): + message = "True Permission Failed" + + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return True + + class TrueAsyncPermission(BasePermission): + message = "True Permission Failed" + + async def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return True + + @strawberry.type + class Query: + @strawberry.field( + extensions=[ + PermissionExtension(permissions=[(TruePermission() | TruePermission())]) + ] + ) + def user(self) -> str: # pragma: no cover + return "patrick" + + @strawberry.field( + extensions=[ + PermissionExtension( + permissions=[(TrueAsyncPermission() | TrueAsyncPermission())] + ) + ] + ) + def user_async(self) -> str: # pragma: no cover + return "patrick" + + schema = strawberry.Schema(query=Query) + + query = "{ user }" + + result = schema.execute_sync(query) + assert result.data["user"] == "patrick" + + query = "{ userAsync }" + + result = await schema.execute(query) + assert result.data["userAsync"] == "patrick" + + +@pytest.mark.asyncio +async def test_no_graphql_error_when_left_or_permission_is_allowed(): + class FalsePermission(BasePermission): + message = "False Permission Failed" + + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: # pragma: no cover + return False + + class TruePermission(BasePermission): + message = "True Permission Failed" + + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return True + + class TrueAsyncPermission(BasePermission): + message = "True Permission Failed" + + async def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return True + + @strawberry.type + class Query: + @strawberry.field( + extensions=[ + PermissionExtension( + permissions=[(TruePermission() | FalsePermission())] + ) + ] + ) + def user(self) -> str: # pragma: no cover + return "patrick" + + @strawberry.field( + extensions=[ + PermissionExtension( + permissions=[(TrueAsyncPermission() | FalsePermission())] + ) + ] + ) + def user_async(self) -> str: # pragma: no cover + return "patrick" + + schema = strawberry.Schema(query=Query) + + query = "{ user }" + + result = schema.execute_sync(query) + assert result.data["user"] == "patrick" + + query = "{ userAsync }" + + result = await schema.execute(query) + assert result.data["userAsync"] == "patrick" + + +@pytest.mark.asyncio +async def test_no_graphql_error_when_right_or_permission_is_allowed(): + class FalsePermission(BasePermission): + message = "False Permission Failed" + + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return False + + class TruePermission(BasePermission): + message = "True Permission Failed" + + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return True + + class TrueAsyncPermission(BasePermission): + message = "True Permission Failed" + + async def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return True + + @strawberry.type + class Query: + @strawberry.field( + extensions=[ + PermissionExtension( + permissions=[(FalsePermission()) | TruePermission()] + ) + ] + ) + def user(self) -> str: # pragma: no cover + return "patrick" + + @strawberry.field( + extensions=[ + PermissionExtension( + permissions=[(FalsePermission()) | TrueAsyncPermission()] + ) + ] + ) + def user_async(self) -> str: # pragma: no cover + return "patrick" + + schema = strawberry.Schema(query=Query) + + query = "{ user }" + + result = schema.execute_sync(query) + assert result.data["user"] == "patrick" + query = "{ userAsync }" + + result = await schema.execute(query) + assert result.data["userAsync"] == "patrick" + + +@pytest.mark.asyncio +async def test_raises_graphql_error_from_left_exception_when_both_or_permission_is_denied(): + class FalseLeftPermission(BasePermission): + message = "False Left Permission Failed" + + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return False + + class FalseRightPermission(BasePermission): + message = "False Right Permission Failed" + + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: + return False + + class FalseRightAsyncPermission(BasePermission): + message = "False Right Permission Failed" + + async def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs: typing.Any + ) -> bool: # pragma: no cover + return False + + @strawberry.type + class Query: + @strawberry.field( + extensions=[ + PermissionExtension( + permissions=[(FalseLeftPermission() & FalseRightPermission())] + ) + ] + ) + def user(self) -> str: # pragma: no cover + return "patrick" + + @strawberry.field( + extensions=[ + PermissionExtension( + permissions=[(FalseLeftPermission() & FalseRightAsyncPermission())] + ) + ] + ) + def user_async(self) -> str: # pragma: no cover + return "patrick" + + schema = strawberry.Schema(query=Query) + + query = "{ user }" + + result = schema.execute_sync(query) + assert result.errors[0].message == "False Left Permission Failed" + + query = "{ userAsync }" + + result = await schema.execute(query) + assert result.errors[0].message == "False Left Permission Failed" + + @pytest.mark.asyncio async def test_raises_permission_error_for_subscription(): class IsAdmin(BasePermission):