diff --git a/tests/litestar/app.py b/tests/litestar/app.py index 77f222287d..3ed6945564 100644 --- a/tests/litestar/app.py +++ b/tests/litestar/app.py @@ -3,7 +3,7 @@ from litestar import Litestar, Request from litestar.di import Provide from strawberry.litestar import make_graphql_controller -from tests.litestar.schema import schema +from tests.views.schema import schema def custom_context_dependency() -> str: diff --git a/tests/litestar/schema.py b/tests/litestar/schema.py deleted file mode 100644 index 71b69ad464..0000000000 --- a/tests/litestar/schema.py +++ /dev/null @@ -1,157 +0,0 @@ -import asyncio -import typing -from enum import Enum -from typing import Any, Optional - -from graphql import GraphQLError - -import strawberry -from strawberry.file_uploads import Upload -from strawberry.permission import BasePermission -from strawberry.subscriptions.protocols.graphql_transport_ws.types import PingMessage - - -class AlwaysFailPermission(BasePermission): - message = "You are not authorized" - - def has_permission( - self, source: Any, info: strawberry.Info, **kwargs: typing.Any - ) -> bool: - return False - - -@strawberry.enum -class Flavor(Enum): - VANILLA = "vanilla" - STRAWBERRY = "strawberry" - CHOCOLATE = "chocolate" - - -@strawberry.input -class FolderInput: - files: typing.List[Upload] - - -@strawberry.type -class DebugInfo: - num_active_result_handlers: int - is_connection_init_timeout_task_done: typing.Optional[bool] - - -@strawberry.type -class Query: - @strawberry.field - def hello(self, name: typing.Optional[str] = None) -> str: - return f"Hello {name or 'world'}" - - @strawberry.field - async def async_hello(self, name: str, delay: float = 0) -> str: - await asyncio.sleep(delay) - return f"Hello {name or 'world'}" - - @strawberry.field(permission_classes=[AlwaysFailPermission]) - def always_fail(self) -> Optional[str]: - return "Hey" - - @strawberry.field - def root_name(root) -> str: - return type(root).__name__ - - @strawberry.field - async def exception(self, message: str) -> str: - raise ValueError(message) - - -@strawberry.type -class Mutation: - @strawberry.mutation - async def hello(self) -> str: - return "strawberry" - - @strawberry.mutation - async def read_text(self, text_file: Upload) -> str: - return (await text_file.read()).decode() - - @strawberry.mutation - async def read_files(self, files: typing.List[Upload]) -> typing.List[str]: - contents = [] - for file in files: - content = (await file.read()).decode() - contents.append(content) - return contents - - @strawberry.mutation - async def read_folder(self, folder: FolderInput) -> typing.List[str]: - contents = [] - for file in folder.files: - content = (await file.read()).decode() - contents.append(content) - return contents - - -@strawberry.type -class Subscription: - @strawberry.subscription - async def echo( - self, message: str, delay: float = 0 - ) -> typing.AsyncGenerator[str, None]: - await asyncio.sleep(delay) - yield message - - @strawberry.subscription - async def request_ping( - self, info: strawberry.Info - ) -> typing.AsyncGenerator[bool, None]: - ws = info.context["ws"] - await ws.send_json(PingMessage().as_dict()) - yield True - - @strawberry.subscription - async def infinity(self, message: str) -> typing.AsyncGenerator[str, None]: - while True: - yield message - await asyncio.sleep(1) - - @strawberry.subscription - async def context(self, info: strawberry.Info) -> typing.AsyncGenerator[str, None]: - yield info.context["custom_value"] - - @strawberry.subscription - async def error(self, message: str) -> typing.AsyncGenerator[str, None]: - yield GraphQLError(message) # type: ignore - - @strawberry.subscription - async def exception(self, message: str) -> typing.AsyncGenerator[str, None]: - raise ValueError(message) - - # Without this yield, the method is not recognised as an async generator - yield "Hi" - - @strawberry.subscription - async def flavors(self) -> typing.AsyncGenerator[Flavor, None]: - yield Flavor.VANILLA - yield Flavor.STRAWBERRY - yield Flavor.CHOCOLATE - - @strawberry.subscription - async def debug( - self, info: strawberry.Info - ) -> typing.AsyncGenerator[DebugInfo, None]: - active_result_handlers = [ - task for task in info.context["get_tasks"]() if not task.done() - ] - - connection_init_timeout_task = info.context["connectionInitTimeoutTask"] - is_connection_init_timeout_task_done = ( - connection_init_timeout_task.done() - if connection_init_timeout_task - else None - ) - - yield DebugInfo( - num_active_result_handlers=len(active_result_handlers), - is_connection_init_timeout_task_done=is_connection_init_timeout_task_done, - ) - - -schema = strawberry.Schema(Query, mutation=Mutation, subscription=Subscription)