diff --git a/.github/release-check-action/release.py b/.github/release-check-action/release.py index 93b7c119fd..d8d7487353 100644 --- a/.github/release-check-action/release.py +++ b/.github/release-check-action/release.py @@ -35,10 +35,10 @@ def get_release_info(file_path: Path) -> ReleaseInfo: match = RELEASE_TYPE_REGEX.match(line) if not match: - raise InvalidReleaseFileError() + raise InvalidReleaseFileError change_type_key = match.group(1) change_type = ChangeType[change_type_key.upper()] - changelog = "".join([line for line in f.readlines()]).strip() + changelog = "".join(f.readlines()).strip() return ReleaseInfo(change_type, changelog) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 51f803fcf5..77cae569e8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.4 + rev: v0.8.5 hooks: - id: ruff-format exclude: ^tests/\w+/snapshots/ diff --git a/federation-compatibility/schema.py b/federation-compatibility/schema.py index 4952977ee3..7502f59d93 100644 --- a/federation-compatibility/schema.py +++ b/federation-compatibility/schema.py @@ -279,7 +279,7 @@ def resolve_reference(cls, **data: Any) -> Optional["Product"]: return get_product_by_sku_and_variation( sku=data["sku"], variation=data["variation"] ) - elif "package" in data: + if "package" in data: return get_product_by_sku_and_package( sku=data["sku"], package=data["package"] ) diff --git a/pyproject.toml b/pyproject.toml index d03033ade7..00099cdb89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -228,125 +228,69 @@ src = ["strawberry", "tests"] select = ["ALL"] ignore = [ # https://github.com/astral-sh/ruff/pull/4427 - # equivalent to keep-runtime-typing + # equivalent to keep-runtime-typing. We might want to enable those + # after we drop support for Python 3.9 "UP006", "UP007", - "TID252", # we use asserts in tests and to hint mypy "S101", - "S102", - "S104", - "S324", - # definitely enable these, maybe not in tests + + # Allow "Any" for annotations. We have too many Any annotations and some + # are legit. Maybe reconsider in the future, except for tests? "ANN401", - "PGH003", - "PGH004", - "RET504", - "RET505", - "RET506", - "RET507", - "RET503", - "BLE001", - "B008", - "N811", - "N804", + + # Allow our exceptions to have names that don't end in "Error". Maybe refactor + # in the future? But that would be a breaking change. "N818", + + # Allow "type: ignore" without rule code. Because we support both mypy and + # pyright, and they have different codes for the same error, we can't properly + # fix those issues. + "PGH003", + # Variable `T` in function should be lowercase # this seems a potential bug or opportunity for improvement in ruff "N806", - # first argument should named self (found in tests) - "N805", - - "N815", - # shadowing builtins "A001", "A002", "A003", + # Unused arguments "ARG001", "ARG002", "ARG003", "ARG004", "ARG005", + + # Boolean positional arguments "FBT001", "FBT002", "FBT003", - "PT001", - "PT023", - - # this is pretty much handled by black - "E501", - - # enable these, we have some in tests - "B006", - "PT007", - "PT011", - "PT012", - "PT015", - "PT017", - "C414", - "N802", - - "SIM117", - "SIM102", - - "F841", - "B027", - "B905", - "ISC001", - - # same? - "S105", - "S106", + # Too many arguments/branches/return statements + "PLR0913", + "PLR0912", + "PLR0911", - "DTZ003", - "DTZ005", + # Do not force adding _co to covariant typevars + "PLC0105", - "RSE102", + # Allow private access to attributes "SLF001", - # in tests - "DTZ001", - - "EM101", - "EM102", - "EM103", - - "B904", - "B019", - - "N801", - "N807", - - # pandas - "PD", - - "RUF012", - "PLC0105", - "FA102", - # code complexity - "C", "C901", - # trailing commas - "COM812", - - "PLR", - "INP", - "TRY", - "SIM300", - "SIM114", - - "DJ008", + # Allow todo/fixme/etc comments "TD002", "TD003", "FIX001", "FIX002", + + # We don't want to add "from __future__ mport annotations" everywhere "FA100", # Docstrings, maybe to enable later @@ -359,29 +303,72 @@ ignore = [ "D106", "D107", "D412", + + # Allow to define exceptions text in the exception body + "TRY003", + "EM101", + "EM102", + "EM103", + + # Allow comparisons with magic numbers + "PLR2004", + + # Allow methods to use lru_cache + "B019", + + # Don't force if branches to be converted to "or" + "SIM114", + + # ruff formatter recommends to disable those, as they conflict with it + # we don't need to ever enable those. + "COM812", + "COM819", + "D206", + "E111", + "E114", + "E117", + "E501", + "ISC001", + "Q000", + "Q001", + "Q002", + "Q003", + "W191", ] [tool.ruff.lint.per-file-ignores] -"strawberry/schema/types/concrete_type.py" = ["TCH002"] +".github/*" = ["INP001"] +"federation-compatibility/*" = ["INP001"] +"strawberry/cli/*" = ["B008"] +"strawberry/extensions/tracing/__init__.py" = ["TCH004"] +"strawberry/fastapi/*" = ["B008"] +"strawberry/annotation.py" = ["RET505"] "tests/*" = [ - "RSE102", - "SLF001", - "TCH001", - "TCH002", - "TCH003", "ANN001", "ANN201", "ANN202", "ANN204", - "PLW0603", + "B008", + "B018", + "D", + "DTZ001", + "DTZ005", + "FA102", + "N805", "PLC1901", + "PLR2004", + "PLW0603", + "PT011", + "RUF012", + "S105", + "S106", "S603", "S607", - "B018", - "D", + "TCH001", + "TCH002", + "TCH003", + "TRY002", ] -"strawberry/extensions/tracing/__init__.py" = ["TCH004"] -"tests/http/clients/__init__.py" = ["F401"] [tool.ruff.lint.isort] known-first-party = ["strawberry"] diff --git a/strawberry/aiohttp/test/client.py b/strawberry/aiohttp/test/client.py index 86b08bf341..75b1fda4d6 100644 --- a/strawberry/aiohttp/test/client.py +++ b/strawberry/aiohttp/test/client.py @@ -57,13 +57,11 @@ async def request( headers: Optional[dict[str, object]] = None, files: Optional[dict[str, object]] = None, ) -> Any: - response = await self._client.post( + return await self._client.post( self.url, json=body if not files else None, data=body if files else None, ) - return response - __all__ = ["GraphQLTestClient"] diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index 443264ecd4..eee14dff5d 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -99,12 +99,12 @@ async def iter_json( if ws_message.type == http.WSMsgType.TEXT: try: yield self.view.decode_json(ws_message.data) - except JSONDecodeError: + except JSONDecodeError as e: if not ignore_parsing_errors: - raise NonJsonMessageReceived() + raise NonJsonMessageReceived from e elif ws_message.type == http.WSMsgType.BINARY: - raise NonTextMessageReceived() + raise NonTextMessageReceived async def send_json(self, message: Mapping[str, object]) -> None: try: diff --git a/strawberry/annotation.py b/strawberry/annotation.py index d3fb65deba..29ffa6ee85 100644 --- a/strawberry/annotation.py +++ b/strawberry/annotation.py @@ -107,9 +107,7 @@ def evaluate(self) -> type: if isinstance(annotation, str): annotation = ForwardRef(annotation) - evaled_type = eval_type(annotation, self.namespace, None) - - return evaled_type + return eval_type(annotation, self.namespace, None) def _get_type_with_args( self, evaled_type: type[Any] @@ -155,13 +153,13 @@ def _resolve(self) -> Union[StrawberryType, type]: # a StrawberryType if self._is_enum(evaled_type): return self.create_enum(evaled_type) - elif self._is_optional(evaled_type, args): + if self._is_optional(evaled_type, args): return self.create_optional(evaled_type) - elif self._is_union(evaled_type, args): + if self._is_union(evaled_type, args): return self.create_union(evaled_type, args) - elif is_type_var(evaled_type) or evaled_type is Self: + if is_type_var(evaled_type) or evaled_type is Self: return self.create_type_var(cast(TypeVar, evaled_type)) - elif self._is_strawberry_type(evaled_type): + if self._is_strawberry_type(evaled_type): # Simply return objects that are already StrawberryTypes return evaled_type diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index 2fd33a4210..1a1845d39f 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -97,11 +97,11 @@ async def iter_json( try: text = await self.ws.receive_text() yield self.view.decode_json(text) - except JSONDecodeError: # noqa: PERF203 + except JSONDecodeError as e: # noqa: PERF203 if not ignore_parsing_errors: - raise NonJsonMessageReceived() - except KeyError: - raise NonTextMessageReceived() + raise NonJsonMessageReceived from e + except KeyError as e: + raise NonTextMessageReceived from e except WebSocketDisconnect: # pragma: no cover pass diff --git a/strawberry/channels/handlers/ws_handler.py b/strawberry/channels/handlers/ws_handler.py index 54992b1d44..b76d858f79 100644 --- a/strawberry/channels/handlers/ws_handler.py +++ b/strawberry/channels/handlers/ws_handler.py @@ -45,13 +45,13 @@ async def iter_json( break if message["message"] is None: - raise NonTextMessageReceived() + raise NonTextMessageReceived try: yield self.view.decode_json(message["message"]) - except json.JSONDecodeError: + except json.JSONDecodeError as e: if not ignore_parsing_errors: - raise NonJsonMessageReceived() + raise NonJsonMessageReceived from e async def send_json(self, message: Mapping[str, object]) -> None: serialized_message = self.view.encode_json(message) diff --git a/strawberry/channels/testing.py b/strawberry/channels/testing.py index be3c276fdf..f1807ca52b 100644 --- a/strawberry/channels/testing.py +++ b/strawberry/channels/testing.py @@ -55,7 +55,7 @@ def __init__( path: str, headers: Optional[list[tuple[bytes, bytes]]] = None, protocol: str = GRAPHQL_TRANSPORT_WS_PROTOCOL, - connection_params: dict = {}, + connection_params: dict | None = None, **kwargs: Any, ) -> None: """Create a new communicator. @@ -69,6 +69,8 @@ def __init__( subprotocols: an ordered list of preferred subprotocols to be sent to the server. **kwargs: additional arguments to be passed to the `WebsocketCommunicator` constructor. """ + if connection_params is None: + connection_params = {} self.protocol = protocol subprotocols = kwargs.get("subprotocols", []) subprotocols.append(protocol) diff --git a/strawberry/cli/__init__.py b/strawberry/cli/__init__.py index 6dbaaca5f5..ec5448e0e8 100644 --- a/strawberry/cli/__init__.py +++ b/strawberry/cli/__init__.py @@ -1,10 +1,12 @@ try: from .app import app - from .commands.codegen import codegen as codegen # noqa - from .commands.export_schema import export_schema as export_schema # noqa - from .commands.schema_codegen import schema_codegen as schema_codegen # noqa - from .commands.server import server as server # noqa - from .commands.upgrade import upgrade as upgrade # noqa + from .commands.codegen import codegen as codegen # noqa: PLC0414 + from .commands.export_schema import export_schema as export_schema # noqa: PLC0414 + from .commands.schema_codegen import ( + schema_codegen as schema_codegen, # noqa: PLC0414 + ) + from .commands.server import server as server # noqa: PLC0414 + from .commands.upgrade import upgrade as upgrade # noqa: PLC0414 def run() -> None: app() diff --git a/strawberry/cli/commands/codegen.py b/strawberry/cli/commands/codegen.py index 50c3c3327b..cabe02e0f3 100644 --- a/strawberry/cli/commands/codegen.py +++ b/strawberry/cli/commands/codegen.py @@ -39,23 +39,21 @@ def _import_plugin(plugin: str) -> Optional[type[QueryCodegenPlugin]]: assert _is_codegen_plugin(obj) return obj - else: + + symbols = { + key: value for key, value in module.__dict__.items() if not key.startswith("__") + } + + if "__all__" in module.__dict__: symbols = { - key: value - for key, value in module.__dict__.items() - if not key.startswith("__") + name: symbol + for name, symbol in symbols.items() + if name in module.__dict__["__all__"] } - if "__all__" in module.__dict__: - symbols = { - name: symbol - for name, symbol in symbols.items() - if name in module.__dict__["__all__"] - } - - for obj in symbols.values(): - if _is_codegen_plugin(obj): - return obj + for obj in symbols.values(): + if _is_codegen_plugin(obj): + return obj return None diff --git a/strawberry/cli/commands/server.py b/strawberry/cli/commands/server.py index ecdd4e0f10..ba4e7e8a0a 100644 --- a/strawberry/cli/commands/server.py +++ b/strawberry/cli/commands/server.py @@ -25,7 +25,7 @@ class LogLevel(str, Enum): @app.command(help="Starts debug server") def server( schema: str, - host: str = typer.Option("0.0.0.0", "-h", "--host", show_default=True), + host: str = typer.Option("0.0.0.0", "-h", "--host", show_default=True), # noqa: S104 port: int = typer.Option(8000, "-p", "--port", show_default=True), log_level: LogLevel = typer.Option( "error", @@ -60,7 +60,7 @@ def server( "install them by running:\n" r"pip install 'strawberry-graphql\[debug-server]'" ) - raise typer.Exit(1) + raise typer.Exit(1) # noqa: B904 load_schema(schema, app_dir=app_dir) diff --git a/strawberry/cli/commands/upgrade/_run_codemod.py b/strawberry/cli/commands/upgrade/_run_codemod.py index abd6e6e8b1..6beaecf94c 100644 --- a/strawberry/cli/commands/upgrade/_run_codemod.py +++ b/strawberry/cli/commands/upgrade/_run_codemod.py @@ -41,9 +41,8 @@ def _execute_transform_wrap( additional_kwargs["scratch"] = {} # TODO: maybe capture warnings? - with open(os.devnull, "w") as null: # noqa: PTH123 - with contextlib.redirect_stderr(null): - return _execute_transform(**job, **additional_kwargs) + with open(os.devnull, "w") as null, contextlib.redirect_stderr(null): # noqa: PTH123 + return _execute_transform(**job, **additional_kwargs) def _get_progress_and_pool( diff --git a/strawberry/cli/utils/__init__.py b/strawberry/cli/utils/__init__.py index d54dfefebd..ed79738e6b 100644 --- a/strawberry/cli/utils/__init__.py +++ b/strawberry/cli/utils/__init__.py @@ -16,7 +16,7 @@ def load_schema(schema: str, app_dir: str) -> Schema: message = str(exc) rich.print(f"[red]Error: {message}") - raise typer.Exit(2) + raise typer.Exit(2) # noqa: B904 if not isinstance(schema_symbol, Schema): message = "The `schema` must be an instance of strawberry.Schema" diff --git a/strawberry/codegen/plugins/python.py b/strawberry/codegen/plugins/python.py index ab7ba5e00c..7eeb41111d 100644 --- a/strawberry/codegen/plugins/python.py +++ b/strawberry/codegen/plugins/python.py @@ -3,7 +3,7 @@ import textwrap from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, ClassVar, Optional from strawberry.codegen import CodegenFile, QueryCodegenPlugin from strawberry.codegen.types import ( @@ -35,7 +35,7 @@ class PythonType: class PythonPlugin(QueryCodegenPlugin): - SCALARS_TO_PYTHON_TYPES: dict[str, PythonType] = { + SCALARS_TO_PYTHON_TYPES: ClassVar[dict[str, PythonType]] = { "ID": PythonType("str"), "Int": PythonType("int"), "String": PythonType("str"), @@ -128,7 +128,7 @@ def _print_argument_value(self, argval: GraphQLArgumentValue) -> str: + ", ".join(self._print_argument_value(v) for v in argval.values) + "]" ) - elif isinstance(argval.values, dict): + if isinstance(argval.values, dict): return ( "{" + ", ".join( @@ -137,8 +137,7 @@ def _print_argument_value(self, argval: GraphQLArgumentValue) -> str: ) + "}" ) - else: - raise TypeError(f"Unrecognized values type: {argval}") + raise TypeError(f"Unrecognized values type: {argval}") if isinstance(argval, GraphQLEnumValue): # This is an enum. It needs the namespace alongside the name. if argval.enum_type is None: diff --git a/strawberry/codegen/plugins/typescript.py b/strawberry/codegen/plugins/typescript.py index 057afd4870..ae6359cb90 100644 --- a/strawberry/codegen/plugins/typescript.py +++ b/strawberry/codegen/plugins/typescript.py @@ -1,7 +1,7 @@ from __future__ import annotations import textwrap -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from strawberry.codegen import CodegenFile, QueryCodegenPlugin from strawberry.codegen.types import ( @@ -20,7 +20,7 @@ class TypeScriptPlugin(QueryCodegenPlugin): - SCALARS_TO_TS_TYPE = { + SCALARS_TO_TS_TYPE: ClassVar[dict[str | type, str]] = { "ID": "string", "Int": "number", "String": "string", @@ -102,6 +102,7 @@ def _print_scalar_type(self, type_: GraphQLScalar) -> str: if type_.name in self.SCALARS_TO_TS_TYPE: return "" + assert type_.python_type is not None return f"type {type_.name} = {self.SCALARS_TO_TS_TYPE[type_.python_type]}" def _print_union_type(self, type_: GraphQLUnion) -> str: diff --git a/strawberry/codegen/query_codegen.py b/strawberry/codegen/query_codegen.py index 675820e72d..8f0ecffef7 100644 --- a/strawberry/codegen/query_codegen.py +++ b/strawberry/codegen/query_codegen.py @@ -202,7 +202,7 @@ def _get_deps(t: GraphQLType) -> Iterable[GraphQLType]: yield from _get_deps(gql_type) else: # Want to make sure that all types are covered. - raise ValueError(f"Unknown GraphQLType: {t}") + raise ValueError(f"Unknown GraphQLType: {t}") # noqa: TRY004 _TYPE_TO_GRAPHQL_TYPE = { @@ -249,14 +249,15 @@ def _sort_types(self, types: list[GraphQLType]) -> list[GraphQLType]: def type_cmp(t1: GraphQLType, t2: GraphQLType) -> int: """Compare the types.""" if t1 is t2: - return 0 - - if t1 in _get_deps(t2): - return -1 + retval = 0 + elif t1 in _get_deps(t2): + retval = -1 elif t2 in _get_deps(t1): - return 1 + retval = 1 else: - return 0 + retval = 0 + + return retval return sorted(types, key=cmp_to_key(type_cmp)) @@ -311,15 +312,15 @@ def run(self, query: str) -> CodegenResult: operations = self._get_operations(ast) if not operations: - raise NoOperationProvidedError() + raise NoOperationProvidedError if len(operations) > 1: - raise MultipleOperationsProvidedError() + raise MultipleOperationsProvidedError operation = operations[0] if operation.name is None: - raise NoOperationNameProvidedError() + raise NoOperationNameProvidedError # Look for any free-floating fragments and create types out of them # These types can then be referenced and included later via the @@ -550,7 +551,7 @@ def _get_field_type( if isinstance(field_type, ScalarDefinition): return self._collect_scalar(field_type, None) - elif isinstance(field_type, EnumDefinition): + if isinstance(field_type, EnumDefinition): return self._collect_enum(field_type) raise ValueError(f"Unsupported type: {field_type}") # pragma: no cover diff --git a/strawberry/codemods/annotated_unions.py b/strawberry/codemods/annotated_unions.py index 096c78ac0a..e60ad601dd 100644 --- a/strawberry/codemods/annotated_unions.py +++ b/strawberry/codemods/annotated_unions.py @@ -50,7 +50,7 @@ def __init__( super().__init__(context) - def visit_Module(self, node: cst.Module) -> Optional[bool]: + def visit_Module(self, node: cst.Module) -> Optional[bool]: # noqa: N802 self._is_using_named_import = False return super().visit_Module(node) diff --git a/strawberry/codemods/update_imports.py b/strawberry/codemods/update_imports.py index 24e34164f2..bb58323218 100644 --- a/strawberry/codemods/update_imports.py +++ b/strawberry/codemods/update_imports.py @@ -126,11 +126,9 @@ def _update_strawberry_type_imports( return updated_node - def leave_ImportFrom( + def leave_ImportFrom( # noqa: N802 self, node: cst.ImportFrom, updated_node: cst.ImportFrom ) -> cst.ImportFrom: updated_node = self._update_imports(updated_node, updated_node) updated_node = self._update_types_types_imports(updated_node, updated_node) - updated_node = self._update_strawberry_type_imports(updated_node, updated_node) - - return updated_node + return self._update_strawberry_type_imports(updated_node, updated_node) diff --git a/strawberry/dataloader.py b/strawberry/dataloader.py index c40d274b37..ff4f313ffc 100644 --- a/strawberry/dataloader.py +++ b/strawberry/dataloader.py @@ -240,7 +240,7 @@ async def dispatch_batch(loader: DataLoader, batch: Batch) -> None: values = list(values) if len(values) != len(batch): - raise WrongNumberOfResultsReturned( + raise WrongNumberOfResultsReturned( # noqa: TRY301 expected=len(batch), received=len(values) ) @@ -254,7 +254,7 @@ async def dispatch_batch(loader: DataLoader, batch: Batch) -> None: task.future.set_exception(value) else: task.future.set_result(value) - except Exception as e: + except Exception as e: # noqa: BLE001 for task in batch.tasks: task.future.set_exception(e) diff --git a/strawberry/django/__init__.py b/strawberry/django/__init__.py index c28422e346..d6537eab39 100644 --- a/strawberry/django/__init__.py +++ b/strawberry/django/__init__.py @@ -12,9 +12,9 @@ def __getattr__(name: str) -> Any: import_symbol = f"{__name__}.{name}" try: return importlib.import_module(import_symbol) - except ModuleNotFoundError: + except ModuleNotFoundError as e: raise AttributeError( f"Attempted import of {import_symbol} failed. Make sure to install the" "'strawberry-graphql-django' package to use the Strawberry Django " "extension API." - ) + ) from e diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 97a0955b75..c7b01f6d59 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -45,8 +45,7 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE - - from ..schema import BaseSchema + from strawberry.schema import BaseSchema # TODO: remove this and unify temporal responses @@ -266,7 +265,7 @@ class AsyncGraphQLView( request_adapter_class = AsyncDjangoHTTPRequestAdapter @classonlymethod # pyright: ignore[reportIncompatibleMethodOverride] - def as_view(cls, **initkwargs: Any) -> Callable[..., HttpResponse]: + def as_view(cls, **initkwargs: Any) -> Callable[..., HttpResponse]: # noqa: N805 # This code tells django that this view is async, see docs here: # https://docs.djangoproject.com/en/3.1/topics/async/#async-views diff --git a/strawberry/exceptions/__init__.py b/strawberry/exceptions/__init__.py index acd3cc8f89..6cc21160a4 100644 --- a/strawberry/exceptions/__init__.py +++ b/strawberry/exceptions/__init__.py @@ -50,7 +50,7 @@ class UnallowedReturnTypeForUnion(Exception): def __init__( self, field_name: str, result_type: str, allowed_types: set[GraphQLObjectType] ) -> None: - formatted_allowed_types = list(sorted(type_.name for type_ in allowed_types)) + formatted_allowed_types = sorted(type_.name for type_ in allowed_types) message = ( f'The type "{result_type}" of the field "{field_name}" ' @@ -160,7 +160,9 @@ class StrawberryGraphQLError(GraphQLError): class ConnectionRejectionError(Exception): """Use it when you want to reject a WebSocket connection.""" - def __init__(self, payload: dict[str, object] = {}) -> None: + def __init__(self, payload: dict[str, object] | None = None) -> None: + if payload is None: + payload = {} self.payload = payload diff --git a/strawberry/exceptions/exception.py b/strawberry/exceptions/exception.py index f75ce488ef..7bfc8e2443 100644 --- a/strawberry/exceptions/exception.py +++ b/strawberry/exceptions/exception.py @@ -67,7 +67,7 @@ def __rich__(self) -> Optional[RenderableType]: from rich.panel import Panel if self.exception_source is None: - raise UnableToFindExceptionSource() from self + raise UnableToFindExceptionSource from self content = ( self.__rich_header__, diff --git a/strawberry/exceptions/permission_fail_silently_requires_optional.py b/strawberry/exceptions/permission_fail_silently_requires_optional.py index 9a922e2f8e..ca3a7de713 100644 --- a/strawberry/exceptions/permission_fail_silently_requires_optional.py +++ b/strawberry/exceptions/permission_fail_silently_requires_optional.py @@ -7,7 +7,8 @@ from .utils.source_finder import SourceFinder if TYPE_CHECKING: - from ..field import StrawberryField + from strawberry.field import StrawberryField + from .exception_source import ExceptionSource diff --git a/strawberry/exceptions/utils/source_finder.py b/strawberry/exceptions/utils/source_finder.py index a24d4cc776..3f98afd42d 100644 --- a/strawberry/exceptions/utils/source_finder.py +++ b/strawberry/exceptions/utils/source_finder.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Optional, cast -from ..exception_source import ExceptionSource +from strawberry.exceptions.exception_source import ExceptionSource if TYPE_CHECKING: from collections.abc import Sequence diff --git a/strawberry/experimental/pydantic/_compat.py b/strawberry/experimental/pydantic/_compat.py index b2166c2bf9..cf64d06322 100644 --- a/strawberry/experimental/pydantic/_compat.py +++ b/strawberry/experimental/pydantic/_compat.py @@ -119,7 +119,7 @@ def get_fields_map_for_v2() -> dict[Any, Any]: class PydanticV2Compat: @property - def PYDANTIC_MISSING_TYPE(self) -> Any: + def PYDANTIC_MISSING_TYPE(self) -> Any: # noqa: N802 from pydantic_core import PydanticUndefined return PydanticUndefined @@ -155,7 +155,7 @@ def get_basic_type(self, type_: Any) -> type[Any]: type_ = self.fields_map[type_] if type_ is None: - raise UnsupportedTypeError() + raise UnsupportedTypeError if is_new_type(type_): return new_type_supertype(type_) @@ -168,7 +168,7 @@ def model_dump(self, model_instance: BaseModel) -> dict[Any, Any]: class PydanticV1Compat: @property - def PYDANTIC_MISSING_TYPE(self) -> Any: + def PYDANTIC_MISSING_TYPE(self) -> Any: # noqa: N802 return dataclasses.MISSING def get_model_fields(self, model: type[BaseModel]) -> dict[str, CompatModelField]: @@ -231,7 +231,7 @@ def get_basic_type(self, type_: Any) -> type[Any]: type_ = self.fields_map[type_] if type_ is None: - raise UnsupportedTypeError() + raise UnsupportedTypeError if is_new_type(type_): return new_type_supertype(type_) diff --git a/strawberry/experimental/pydantic/conversion.py b/strawberry/experimental/pydantic/conversion.py index 5296f60acc..2f97fcf236 100644 --- a/strawberry/experimental/pydantic/conversion.py +++ b/strawberry/experimental/pydantic/conversion.py @@ -101,17 +101,17 @@ def convert_pydantic_model_to_strawberry_class( def convert_strawberry_class_to_pydantic_model(obj: type) -> Any: if hasattr(obj, "to_pydantic"): return obj.to_pydantic() - elif dataclasses.is_dataclass(obj): + if dataclasses.is_dataclass(obj): result = [] for f in dataclasses.fields(obj): value = convert_strawberry_class_to_pydantic_model(getattr(obj, f.name)) result.append((f.name, value)) return dict(result) - elif isinstance(obj, (list, tuple)): + if isinstance(obj, (list, tuple)): # Assume we can create an object of this type by passing in a # generator (which is not true for namedtuples, not supported). return type(obj)(convert_strawberry_class_to_pydantic_model(v) for v in obj) - elif isinstance(obj, dict): + if isinstance(obj, dict): return type(obj)( ( convert_strawberry_class_to_pydantic_model(k), @@ -119,5 +119,4 @@ def convert_strawberry_class_to_pydantic_model(obj: type) -> Any: ) for k, v in obj.items() ) - else: - return copy.deepcopy(obj) + return copy.deepcopy(obj) diff --git a/strawberry/experimental/pydantic/fields.py b/strawberry/experimental/pydantic/fields.py index fe6b863431..447dcd9e6a 100644 --- a/strawberry/experimental/pydantic/fields.py +++ b/strawberry/experimental/pydantic/fields.py @@ -32,8 +32,7 @@ def replace_pydantic_types(type_: Any, is_input: bool) -> Any: attr = "_strawberry_input_type" if is_input else "_strawberry_type" if hasattr(type_, attr): return getattr(type_, attr) - else: - raise UnregisteredTypeException(type_) + raise UnregisteredTypeException(type_) return type_ diff --git a/strawberry/experimental/pydantic/object_type.py b/strawberry/experimental/pydantic/object_type.py index caf8571b87..ac8958a4aa 100644 --- a/strawberry/experimental/pydantic/object_type.py +++ b/strawberry/experimental/pydantic/object_type.py @@ -115,7 +115,7 @@ def _build_dataclass_creation_fields( ) -def type( +def type( # noqa: PLR0915 model: builtins.type[PydanticModel], *, fields: Optional[list[str]] = None, @@ -127,7 +127,7 @@ def type( all_fields: bool = False, use_pydantic_alias: bool = True, ) -> Callable[..., builtins.type[StrawberryTypeFromPydantic[PydanticModel]]]: - def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]: + def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]: # noqa: PLR0915 compat = PydanticCompat.from_model(model) model_fields = compat.get_model_fields(model) original_fields_set = set(fields) if fields else set() diff --git a/strawberry/experimental/pydantic/utils.py b/strawberry/experimental/pydantic/utils.py index acc9eba635..23f80fe6ea 100644 --- a/strawberry/experimental/pydantic/utils.py +++ b/strawberry/experimental/pydantic/utils.py @@ -47,8 +47,7 @@ def normalize_type(type_: type) -> Any: def get_strawberry_type_from_model(type_: Any) -> Any: if hasattr(type_, "_strawberry_type"): return type_._strawberry_type - else: - raise UnregisteredTypeException(type_) + raise UnregisteredTypeException(type_) def get_private_fields(cls: type) -> list[dataclasses.Field]: @@ -97,9 +96,7 @@ def get_default_factory_for_field( # if we have a default_factory, we should return it if has_factory: - default_factory = cast("NoArgAnyCallable", default_factory) - - return default_factory + return cast("NoArgAnyCallable", default_factory) # if we have a default, we should return it if has_default: @@ -108,8 +105,7 @@ def get_default_factory_for_field( # printing the value. if isinstance(default, BaseModel): return lambda: compat.model_dump(default) - else: - return lambda: smart_deepcopy(default) + return lambda: smart_deepcopy(default) # if we don't have default or default_factory, but the field is not required, # we should return a factory that returns None @@ -131,5 +127,3 @@ def ensure_all_auto_fields_in_pydantic( raise AutoFieldsNotInBaseModelError( fields=non_existing_fields, cls_name=cls_name, model=model ) - else: - return diff --git a/strawberry/ext/mypy_plugin.py b/strawberry/ext/mypy_plugin.py index 797428019f..a06657b776 100644 --- a/strawberry/ext/mypy_plugin.py +++ b/strawberry/ext/mypy_plugin.py @@ -111,9 +111,7 @@ def lazy_type_analyze_callback(ctx: AnalyzeTypeContext) -> Type: return AnyType(TypeOfAny.special_form) type_name = ctx.type.args[0] - type_ = ctx.api.analyze_type(type_name) - - return type_ + return ctx.api.analyze_type(type_name) def _get_named_type(name: str, api: SemanticAnalyzerPluginInterface) -> Any: @@ -147,14 +145,12 @@ def _get_type_for_expr(expr: Expression, api: SemanticAnalyzerPluginInterface) - if isinstance(expr, MemberExpr): if expr.fullname: return _get_named_type(expr.fullname, api) - else: - raise InvalidNodeTypeException(expr) + raise InvalidNodeTypeException(expr) if isinstance(expr, CallExpr): if expr.analyzed: return _get_type_for_expr(expr.analyzed, api) - else: - raise InvalidNodeTypeException(expr) + raise InvalidNodeTypeException(expr) if isinstance(expr, CastExpr): return expr.type @@ -178,8 +174,6 @@ def create_type_hook(ctx: DynamicClassDefContext) -> None: SymbolTableNode(GDEF, type_alias, plugin_generated=True), ) - return - def union_hook(ctx: DynamicClassDefContext) -> None: try: @@ -343,13 +337,12 @@ def add_static_method_to_class( cls.defs.body.remove(sym.node) # For compat with mypy < 0.93 - if MypyVersion.VERSION < Decimal("0.93"): + if Decimal("0.93") > MypyVersion.VERSION: function_type = api.named_type("__builtins__.function") + elif isinstance(api, SemanticAnalyzerPluginInterface): + function_type = api.named_type("builtins.function") else: - if isinstance(api, SemanticAnalyzerPluginInterface): - function_type = api.named_type("builtins.function") - else: - function_type = api.named_generic_type("builtins.function", []) + function_type = api.named_generic_type("builtins.function", []) arg_types, arg_names, arg_kinds = [], [], [] for arg in args: diff --git a/strawberry/extensions/context.py b/strawberry/extensions/context.py index 040ce83143..baef50bae1 100644 --- a/strawberry/extensions/context.py +++ b/strawberry/extensions/context.py @@ -77,7 +77,7 @@ def get_hook(self, extension: SchemaExtension) -> Optional[WrappedHook]: f"{extension} defines both legacy and new style extension hooks for " "{self.HOOK_NAME}" ) - elif is_legacy: + if is_legacy: warnings.warn(self.DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=3) return self.from_legacy(extension, on_start, on_end) @@ -128,19 +128,17 @@ async def iterator() -> AsyncIterator: return WrappedHook(extension=extension, hook=iterator, is_async=True) - else: + @contextlib.contextmanager + def iterator_async() -> Iterator[None]: + if on_start: + on_start() - @contextlib.contextmanager - def iterator_async() -> Iterator[None]: - if on_start: - on_start() - - yield + yield - if on_end: - on_end() + if on_end: + on_end() - return WrappedHook(extension=extension, hook=iterator_async, is_async=False) + return WrappedHook(extension=extension, hook=iterator_async, is_async=False) @staticmethod def from_callable( @@ -155,14 +153,13 @@ async def iterator() -> AsyncIterator[None]: yield return WrappedHook(extension=extension, hook=iterator, is_async=True) - else: - @contextlib.contextmanager - def iterator() -> Iterator[None]: - func(extension) - yield + @contextlib.contextmanager # type: ignore[no-redef] + def iterator() -> Iterator[None]: + func(extension) + yield - return WrappedHook(extension=extension, hook=iterator, is_async=False) + return WrappedHook(extension=extension, hook=iterator, is_async=False) def __enter__(self) -> None: self.exit_stack = contextlib.ExitStack() @@ -175,8 +172,7 @@ def __enter__(self) -> None: f"SchemaExtension hook {hook.extension}.{self.HOOK_NAME} " "failed to complete synchronously." ) - else: - self.exit_stack.enter_context(hook.hook()) # type: ignore + self.exit_stack.enter_context(hook.hook()) # type: ignore def __exit__( self, diff --git a/strawberry/extensions/field_extension.py b/strawberry/extensions/field_extension.py index 6683247a98..7f97c07222 100644 --- a/strawberry/extensions/field_extension.py +++ b/strawberry/extensions/field_extension.py @@ -99,62 +99,61 @@ def build_field_extension_resolvers( f"Please add a resolve_async method to the extension(s)." ) return _get_async_resolvers(field.extensions) - else: - # Try to wrap all sync resolvers in async so that we can use async extensions - # on sync fields. This is not possible the other way around since - # the result of an async resolver would have to be awaited before calling - # the sync extension, making it impossible for the extension to modify - # any arguments. - non_sync_extensions = [ - extension for extension in field.extensions if not extension.supports_sync - ] - - if len(non_sync_extensions) == 0: - # Resolve everything sync - return _get_sync_resolvers(field.extensions) - - # We have async-only extensions and need to wrap the resolver - # That means we can't have sync-only extensions after the first async one - - # Check if we have a chain of sync-compatible - # extensions before the async extensions - # -> S-S-S-S-A-A-A-A - found_sync_extensions = 0 - - # All sync only extensions must be found before the first async-only one - found_sync_only_extensions = 0 - for extension in field.extensions: - # ...A, abort - if extension in non_sync_extensions: - break - # ...S - if extension in non_async_extensions: - found_sync_only_extensions += 1 - found_sync_extensions += 1 - - # Length of the chain equals length of non async extensions - # All sync extensions run first - if len(non_async_extensions) == found_sync_only_extensions: - # Prepend sync to async extension to field extensions - return list( - itertools.chain( - _get_sync_resolvers(field.extensions[:found_sync_extensions]), - [SyncToAsyncExtension().resolve_async], - _get_async_resolvers(field.extensions[found_sync_extensions:]), - ) - ) + # Try to wrap all sync resolvers in async so that we can use async extensions + # on sync fields. This is not possible the other way around since + # the result of an async resolver would have to be awaited before calling + # the sync extension, making it impossible for the extension to modify + # any arguments. + non_sync_extensions = [ + extension for extension in field.extensions if not extension.supports_sync + ] - # Some sync extensions follow the first async-only extension. Error case - async_extension_names = ",".join( - [extension.__class__.__name__ for extension in non_sync_extensions] - ) - raise TypeError( - f"Cannot mix async-only extension(s) {async_extension_names} " - f"with sync-only extension(s) {non_async_extension_names} " - f"on Field {field.name}. " - f"If possible try to change the execution order so that all sync-only " - f"extensions are executed first." + if len(non_sync_extensions) == 0: + # Resolve everything sync + return _get_sync_resolvers(field.extensions) + + # We have async-only extensions and need to wrap the resolver + # That means we can't have sync-only extensions after the first async one + + # Check if we have a chain of sync-compatible + # extensions before the async extensions + # -> S-S-S-S-A-A-A-A + found_sync_extensions = 0 + + # All sync only extensions must be found before the first async-only one + found_sync_only_extensions = 0 + for extension in field.extensions: + # ...A, abort + if extension in non_sync_extensions: + break + # ...S + if extension in non_async_extensions: + found_sync_only_extensions += 1 + found_sync_extensions += 1 + + # Length of the chain equals length of non async extensions + # All sync extensions run first + if len(non_async_extensions) == found_sync_only_extensions: + # Prepend sync to async extension to field extensions + return list( + itertools.chain( + _get_sync_resolvers(field.extensions[:found_sync_extensions]), + [SyncToAsyncExtension().resolve_async], + _get_async_resolvers(field.extensions[found_sync_extensions:]), + ) ) + # Some sync extensions follow the first async-only extension. Error case + async_extension_names = ",".join( + [extension.__class__.__name__ for extension in non_sync_extensions] + ) + raise TypeError( + f"Cannot mix async-only extension(s) {async_extension_names} " + f"with sync-only extension(s) {non_async_extension_names} " + f"on Field {field.name}. " + f"If possible try to change the execution order so that all sync-only " + f"extensions are executed first." + ) + __all__ = ["FieldExtension"] diff --git a/strawberry/extensions/query_depth_limiter.py b/strawberry/extensions/query_depth_limiter.py index ef801120ae..e80b5b1ea8 100644 --- a/strawberry/extensions/query_depth_limiter.py +++ b/strawberry/extensions/query_depth_limiter.py @@ -189,18 +189,17 @@ def resolve_field_value( ) -> FieldArgumentType: if isinstance(value, StringValueNode): return value.value - elif isinstance(value, IntValueNode): + if isinstance(value, IntValueNode): return int(value.value) - elif isinstance(value, FloatValueNode): + if isinstance(value, FloatValueNode): return float(value.value) - elif isinstance(value, BooleanValueNode): + if isinstance(value, BooleanValueNode): return value.value - elif isinstance(value, ListValueNode): + if isinstance(value, ListValueNode): return [resolve_field_value(v) for v in value.values] - elif isinstance(value, ObjectValueNode): + if isinstance(value, ObjectValueNode): return {v.name.value: resolve_field_value(v.value) for v in value.fields} - else: - return {} + return {} def get_field_arguments( @@ -250,20 +249,18 @@ def determine_depth( return 0 return 1 + max( - map( - lambda selection: determine_depth( - node=selection, - fragments=fragments, - depth_so_far=depth_so_far + 1, - max_depth=max_depth, - context=context, - operation_name=operation_name, - should_ignore=should_ignore, - ), - node.selection_set.selections, + determine_depth( + node=selection, + fragments=fragments, + depth_so_far=depth_so_far + 1, + max_depth=max_depth, + context=context, + operation_name=operation_name, + should_ignore=should_ignore, ) + for selection in node.selection_set.selections ) - elif isinstance(node, FragmentSpreadNode): + if isinstance(node, FragmentSpreadNode): return determine_depth( node=fragments[node.name.value], fragments=fragments, @@ -273,25 +270,22 @@ def determine_depth( operation_name=operation_name, should_ignore=should_ignore, ) - elif isinstance( + if isinstance( node, (InlineFragmentNode, FragmentDefinitionNode, OperationDefinitionNode) ): return max( - map( - lambda selection: determine_depth( - node=selection, - fragments=fragments, - depth_so_far=depth_so_far, - max_depth=max_depth, - context=context, - operation_name=operation_name, - should_ignore=should_ignore, - ), - node.selection_set.selections, + determine_depth( + node=selection, + fragments=fragments, + depth_so_far=depth_so_far, + max_depth=max_depth, + context=context, + operation_name=operation_name, + should_ignore=should_ignore, ) + for selection in node.selection_set.selections ) - else: - raise TypeError(f"Depth crawler cannot handle: {node.kind}") # pragma: no cover + raise TypeError(f"Depth crawler cannot handle: {node.kind}") # pragma: no cover def is_ignored(node: FieldNode, ignore: Optional[list[IgnoreType]] = None) -> bool: diff --git a/strawberry/extensions/tracing/datadog.py b/strawberry/extensions/tracing/datadog.py index 02a722f29f..1269dfdfd8 100644 --- a/strawberry/extensions/tracing/datadog.py +++ b/strawberry/extensions/tracing/datadog.py @@ -67,7 +67,7 @@ def create_span(self, lifecycle_step, name, **kwargs): ) def hash_query(self, query: str) -> str: - return hashlib.md5(query.encode("utf-8")).hexdigest() + return hashlib.md5(query.encode("utf-8")).hexdigest() # noqa: S324 def on_operation(self) -> Iterator[None]: self._operation_name = self.execution_context.operation_name diff --git a/strawberry/extensions/tracing/opentelemetry.py b/strawberry/extensions/tracing/opentelemetry.py index 686d311c96..56c301b7ca 100644 --- a/strawberry/extensions/tracing/opentelemetry.py +++ b/strawberry/extensions/tracing/opentelemetry.py @@ -45,7 +45,7 @@ def __init__( ) -> None: self._arg_filter = arg_filter self._tracer = trace.get_tracer("strawberry") - self._span_holder = dict() + self._span_holder = {} if execution_context: self.execution_context = execution_context @@ -116,18 +116,17 @@ def convert_to_allowed_types(self, value: Any) -> Any: # Put these in decreasing order of use-cases to exit as soon as possible if isinstance(value, (bool, str, bytes, int, float)): return value - elif isinstance(value, (list, tuple, range)): + if isinstance(value, (list, tuple, range)): return self.convert_list_or_tuple_to_allowed_types(value) - elif isinstance(value, dict): + if isinstance(value, dict): return self.convert_dict_to_allowed_types(value) - elif isinstance(value, (set, frozenset)): + if isinstance(value, (set, frozenset)): return self.convert_set_to_allowed_types(value) - elif isinstance(value, complex): + if isinstance(value, complex): return str(value) # Convert complex numbers to strings - elif isinstance(value, (bytearray, memoryview)): + if isinstance(value, (bytearray, memoryview)): return bytes(value) # Convert bytearray and memoryview to bytes - else: - return str(value) + return str(value) def convert_set_to_allowed_types(self, value: Union[set, frozenset]) -> str: return ( @@ -192,9 +191,7 @@ def resolve( **kwargs: Any, ) -> Any: if should_skip_tracing(_next, info): - result = _next(root, info, *args, **kwargs) - - return result + return _next(root, info, *args, **kwargs) with self._tracer.start_as_current_span( f"GraphQL Resolving: {info.field_name}", @@ -203,9 +200,7 @@ def resolve( ), ) as span: self.add_tags(span, info, kwargs) - result = _next(root, info, *args, **kwargs) - - return result + return _next(root, info, *args, **kwargs) __all__ = ["OpenTelemetryExtension", "OpenTelemetryExtensionSync"] diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py index 28d158bc4d..f43c9ecb0c 100644 --- a/strawberry/fastapi/router.py +++ b/strawberry/fastapi/router.py @@ -92,10 +92,9 @@ async def dependency( **default_context, **custom_context, } - elif custom_context is None: + if custom_context is None: return default_context - else: - raise InvalidCustomContext() + raise InvalidCustomContext # replace the signature parameters of dependency... # ...with the old parameters minus the first argument as it will be replaced... diff --git a/strawberry/federation/schema.py b/strawberry/federation/schema.py index a8acc0e289..bfd8d7663b 100644 --- a/strawberry/federation/schema.py +++ b/strawberry/federation/schema.py @@ -172,7 +172,7 @@ def entities_resolver( try: result = resolve_reference(**kwargs) - except Exception as e: + except Exception as e: # noqa: BLE001 result = e else: from strawberry.types.arguments import convert_argument @@ -187,7 +187,7 @@ def entities_resolver( scalar_registry=scalar_registry, config=config, ) - except Exception: + except Exception: # noqa: BLE001 result = TypeError(f"Unable to resolve reference for {type_name}") results.append(result) @@ -271,7 +271,7 @@ def _add_link_directives( link_directives: list[object] = [ Link( url=url, - import_=list(sorted(directives)), + import_=sorted(directives), # type: ignore[arg-type] ) for url, directives in directive_by_url.items() ] diff --git a/strawberry/flask/views.py b/strawberry/flask/views.py index f730e5408e..f33a3e44e7 100644 --- a/strawberry/flask/views.py +++ b/strawberry/flask/views.py @@ -4,6 +4,7 @@ from typing import ( TYPE_CHECKING, Any, + ClassVar, Optional, Union, cast, @@ -102,7 +103,7 @@ class GraphQLView( SyncBaseHTTPView[Request, Response, Response, Context, RootValue], View, ): - methods = ["GET", "POST"] + methods: ClassVar[list[str]] = ["GET", "POST"] allow_queries_via_get: bool = True request_adapter_class = FlaskHTTPRequestAdapter @@ -165,7 +166,7 @@ class AsyncGraphQLView( ], View, ): - methods = ["GET", "POST"] + methods: ClassVar[list[str]] = ["GET", "POST"] allow_queries_via_get: bool = True request_adapter_class = AsyncFlaskHTTPRequestAdapter diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 799fec8ba3..b73eb55181 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -306,8 +306,7 @@ async def run( await websocket.close(4406, "Subprotocol not acceptable") return websocket_response - else: - request = cast(Request, request) + request = cast(Request, request) request_adapter = self.request_adapter_class(request) sub_response = await self.get_sub_response(request) @@ -325,8 +324,7 @@ async def run( if self.should_render_graphql_ide(request_adapter): if self.graphql_ide: return await self.render_graphql_ide(request) - else: - raise HTTPException(404, "Not Found") + raise HTTPException(404, "Not Found") try: result = await self.execute_operation( diff --git a/strawberry/http/ides.py b/strawberry/http/ides.py index 63d7d4af10..be72fc1b8b 100644 --- a/strawberry/http/ides.py +++ b/strawberry/http/ides.py @@ -17,9 +17,7 @@ def get_graphql_ide_html( else: path = here / "static/graphiql.html" - template = path.read_text(encoding="utf-8") - - return template + return path.read_text(encoding="utf-8") __all__ = ["GraphQL_IDE", "get_graphql_ide_html"] diff --git a/strawberry/http/sync_base_view.py b/strawberry/http/sync_base_view.py index 555d7708d0..149d4b50e6 100644 --- a/strawberry/http/sync_base_view.py +++ b/strawberry/http/sync_base_view.py @@ -175,8 +175,7 @@ def run( if self.should_render_graphql_ide(request_adapter): if self.graphql_ide: return self.render_graphql_ide(request) - else: - raise HTTPException(404, "Not Found") + raise HTTPException(404, "Not Found") sub_response = self.get_sub_response(request) context = ( diff --git a/strawberry/litestar/controller.py b/strawberry/litestar/controller.py index 0e91aacc97..1edf2d5a89 100644 --- a/strawberry/litestar/controller.py +++ b/strawberry/litestar/controller.py @@ -9,6 +9,7 @@ TYPE_CHECKING, Any, Callable, + ClassVar, Optional, TypedDict, Union, @@ -203,13 +204,13 @@ async def iter_json( # Litestar internally defaults to an empty string for non-text messages if text == "": - raise NonTextMessageReceived() + raise NonTextMessageReceived try: yield self.view.decode_json(text) - except json.JSONDecodeError: + except json.JSONDecodeError as e: if not ignore_parsing_errors: - raise NonJsonMessageReceived() + raise NonJsonMessageReceived from e except WebSocketDisconnect: pass @@ -236,7 +237,7 @@ class GraphQLController( ], ): path: str = "" - dependencies: Dependencies = { + dependencies: ClassVar[Dependencies] = { # type: ignore[misc] "custom_context": Provide(_none_custom_context_getter), "context": Provide(_context_getter_http), "context_ws": Provide(_context_getter_ws), @@ -445,7 +446,7 @@ def make_graphql_controller( class _GraphQLController(GraphQLController): path: str = routes_path - dependencies: Dependencies = { + dependencies: ClassVar[Dependencies] = { # type: ignore[misc] "custom_context": Provide(custom_context_getter_), "context": Provide(_context_getter_http), "context_ws": Provide(_context_getter_ws), diff --git a/strawberry/permission.py b/strawberry/permission.py index 9624cce535..70622d94e4 100644 --- a/strawberry/permission.py +++ b/strawberry/permission.py @@ -101,7 +101,7 @@ def on_unauthorized(self) -> None: if self.error_extensions: # Add our extensions to the error if not error.extensions: - error.extensions = dict() + error.extensions = {} error.extensions.update(self.error_extensions) raise error diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index 93afe401aa..3a1ff28058 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -1,6 +1,6 @@ import warnings from collections.abc import AsyncGenerator, Mapping -from typing import TYPE_CHECKING, Callable, Optional, cast +from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast from typing_extensions import TypeGuard from quart import Request, Response, request @@ -52,7 +52,7 @@ class GraphQLView( ], View, ): - methods = ["GET", "POST"] + methods: ClassVar[list[str]] = ["GET", "POST"] allow_queries_via_get: bool = True request_adapter_class = QuartHTTPRequestAdapter diff --git a/strawberry/relay/types.py b/strawberry/relay/types.py index 5342d6e29a..d529cf63f0 100644 --- a/strawberry/relay/types.py +++ b/strawberry/relay/types.py @@ -758,7 +758,7 @@ class ListConnection(Connection[NodeType]): ) @classmethod - def resolve_connection( + def resolve_connection( # noqa: PLR0915 cls, nodes: NodeIterableType[NodeType], *, diff --git a/strawberry/schema/base.py b/strawberry/schema/base.py index cc580f12c5..9be05da5ab 100644 --- a/strawberry/schema/base.py +++ b/strawberry/schema/base.py @@ -1,7 +1,6 @@ from __future__ import annotations from abc import abstractmethod -from functools import lru_cache from typing import TYPE_CHECKING, Any, Optional, Union from typing_extensions import Protocol @@ -88,7 +87,6 @@ def get_type_by_name( raise NotImplementedError @abstractmethod - @lru_cache def get_directive_by_name(self, graphql_name: str) -> Optional[StrawberryDirective]: raise NotImplementedError diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index d47baa8a6a..b0e1ebf45d 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -85,7 +85,7 @@ async def _parse_and_validate_async( context: ExecutionContext, extensions_runner: SchemaExtensionsRunner ) -> Optional[PreExecutionError]: if not context.query: - raise MissingQueryError() + raise MissingQueryError async with extensions_runner.parsing(): try: @@ -96,7 +96,7 @@ async def _parse_and_validate_async( context.errors = [error] return PreExecutionError(data=None, errors=[error]) - except Exception as error: + except Exception as error: # noqa: BLE001 error = GraphQLError(str(error), original_error=error) context.errors = [error] return PreExecutionError(data=None, errors=[error]) @@ -189,9 +189,9 @@ async def execute( # only return a sanitised version to the client. process_errors(result.errors, execution_context) - except (MissingQueryError, InvalidOperationTypeError) as e: - raise e - except Exception as exc: + except (MissingQueryError, InvalidOperationTypeError): + raise + except Exception as exc: # noqa: BLE001 return await _handle_execution_result( execution_context, PreExecutionError(data=None, errors=[_coerce_error(exc)]), @@ -219,7 +219,7 @@ def execute_sync( # Note: In graphql-core the schema would be validated here but in # Strawberry we are validating it at initialisation time instead if not execution_context.query: - raise MissingQueryError() + raise MissingQueryError # noqa: TRY301 with extensions_runner.parsing(): try: @@ -238,7 +238,7 @@ def execute_sync( ) if execution_context.operation_type not in allowed_operation_types: - raise InvalidOperationTypeError(execution_context.operation_type) + raise InvalidOperationTypeError(execution_context.operation_type) # noqa: TRY301 with extensions_runner.validation(): _run_validation(execution_context) @@ -266,7 +266,7 @@ def execute_sync( if isawaitable(result): result = cast(Awaitable[GraphQLExecutionResult], result) # type: ignore[redundant-cast] ensure_future(result).cancel() - raise RuntimeError( + raise RuntimeError( # noqa: TRY301 "GraphQL execution failed to complete synchronously." ) @@ -282,9 +282,9 @@ def execute_sync( # extension). That way we can log the original errors but # only return a sanitised version to the client. process_errors(result.errors, execution_context) - except (MissingQueryError, InvalidOperationTypeError) as e: - raise e - except Exception as exc: + except (MissingQueryError, InvalidOperationTypeError): + raise + except Exception as exc: # noqa: BLE001 errors = [_coerce_error(exc)] execution_context.errors = errors process_errors(errors, execution_context) diff --git a/strawberry/schema/name_converter.py b/strawberry/schema/name_converter.py index 1dc30edd55..94f8331886 100644 --- a/strawberry/schema/name_converter.py +++ b/strawberry/schema/name_converter.py @@ -47,18 +47,17 @@ def from_type( return self.from_directive(type_) if isinstance(type_, EnumDefinition): # TODO: Replace with StrawberryEnum return self.from_enum(type_) - elif isinstance(type_, StrawberryObjectDefinition): + if isinstance(type_, StrawberryObjectDefinition): if type_.is_input: return self.from_input_object(type_) if type_.is_interface: return self.from_interface(type_) return self.from_object(type_) - elif isinstance(type_, StrawberryUnion): + if isinstance(type_, StrawberryUnion): return self.from_union(type_) - elif isinstance(type_, ScalarDefinition): # TODO: Replace with StrawberryScalar + if isinstance(type_, ScalarDefinition): # TODO: Replace with StrawberryScalar return self.from_scalar(type_) - else: - return str(type_) + return str(type_) def from_argument(self, argument: StrawberryArgument) -> str: return self.get_graphql_name(argument) diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index 2351c8d51b..a7de78c95a 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -30,6 +30,7 @@ DirectivesExtensionSync, ) from strawberry.extensions.runner import SchemaExtensionsRunner +from strawberry.printer import print_schema from strawberry.schema.schema_converter import GraphQLCoreConverter from strawberry.schema.types.scalar import DEFAULT_SCALAR_REGISTRY from strawberry.types import ExecutionContext @@ -40,7 +41,6 @@ ) from strawberry.types.graphql import OperationType -from ..printer import print_schema from . import compat from .base import BaseSchema from .config import StrawberryConfig @@ -177,9 +177,11 @@ class Query: self.schema_converter.from_schema_directive(type_) ) else: - if has_object_definition(type_): - if type_.__strawberry_definition__.is_graphql_generic: - type_ = StrawberryAnnotation(type_).resolve() # noqa: PLW2901 + if ( + has_object_definition(type_) + and type_.__strawberry_definition__.is_graphql_generic + ): + type_ = StrawberryAnnotation(type_).resolve() # noqa: PLW2901 graphql_type = self.schema_converter.from_maybe_optional(type_) if isinstance(graphql_type, GraphQLNonNull): graphql_type = graphql_type.of_type diff --git a/strawberry/schema/schema_converter.py b/strawberry/schema/schema_converter.py index bc364ef219..2648aa6ad9 100644 --- a/strawberry/schema/schema_converter.py +++ b/strawberry/schema/schema_converter.py @@ -47,6 +47,7 @@ ScalarAlreadyRegisteredError, UnresolvedFieldTypeError, ) +from strawberry.extensions.field_extension import build_field_extension_resolvers from strawberry.schema.types.scalar import _make_scalar_type from strawberry.types.arguments import StrawberryArgument, convert_arguments from strawberry.types.base import ( @@ -66,7 +67,6 @@ from strawberry.types.unset import UNSET from strawberry.utils.await_maybe import await_maybe -from ..extensions.field_extension import build_field_extension_resolvers from . import compat from .types.concrete_type import ConcreteType @@ -91,7 +91,9 @@ FieldType = TypeVar( - "FieldType", bound=Union[GraphQLField, GraphQLInputField], covariant=True + "FieldType", + bound=Union[GraphQLField, GraphQLInputField], + covariant=True, ) @@ -756,9 +758,8 @@ async def _async_resolver( if field.is_async: _async_resolver._is_default = not field.base_resolver # type: ignore return _async_resolver - else: - _resolver._is_default = not field.base_resolver # type: ignore - return _resolver + _resolver._is_default = not field.base_resolver # type: ignore + return _resolver def from_scalar(self, scalar: type) -> GraphQLScalarType: scalar_definition: ScalarDefinition @@ -808,10 +809,9 @@ def from_maybe_optional( NoneType = type(None) if type_ is None or type_ is NoneType: return self.from_type(type_) - elif isinstance(type_, StrawberryOptional): + if isinstance(type_, StrawberryOptional): return self.from_type(type_.of_type) - else: - return GraphQLNonNull(self.from_type(type_)) + return GraphQLNonNull(self.from_type(type_)) def from_type(self, type_: Union[StrawberryType, type]) -> GraphQLNullableType: if compat.is_graphql_generic(type_): @@ -819,27 +819,27 @@ def from_type(self, type_: Union[StrawberryType, type]) -> GraphQLNullableType: if isinstance(type_, EnumDefinition): # TODO: Replace with StrawberryEnum return self.from_enum(type_) - elif compat.is_input_type(type_): # TODO: Replace with StrawberryInputObject + if compat.is_input_type(type_): # TODO: Replace with StrawberryInputObject return self.from_input_object(type_) - elif isinstance(type_, StrawberryList): + if isinstance(type_, StrawberryList): return self.from_list(type_) - elif compat.is_interface_type(type_): # TODO: Replace with StrawberryInterface + if compat.is_interface_type(type_): # TODO: Replace with StrawberryInterface type_definition: StrawberryObjectDefinition = ( type_.__strawberry_definition__ # type: ignore ) return self.from_interface(type_definition) - elif has_object_definition(type_): + if has_object_definition(type_): return self.from_object(type_.__strawberry_definition__) - elif compat.is_enum(type_): # TODO: Replace with StrawberryEnum + if compat.is_enum(type_): # TODO: Replace with StrawberryEnum enum_definition: EnumDefinition = type_._enum_definition # type: ignore return self.from_enum(enum_definition) - elif isinstance(type_, StrawberryObjectDefinition): + if isinstance(type_, StrawberryObjectDefinition): return self.from_object(type_) - elif isinstance(type_, StrawberryUnion): + if isinstance(type_, StrawberryUnion): return self.from_union(type_) - elif isinstance(type_, LazyType): + if isinstance(type_, LazyType): return self.from_type(type_.resolve_type()) - elif compat.is_scalar( + if compat.is_scalar( type_, self.scalar_registry ): # TODO: Replace with StrawberryScalar return self.from_scalar(type_) diff --git a/strawberry/schema/subscribe.py b/strawberry/schema/subscribe.py index 5417958b62..8bda51a4d1 100644 --- a/strawberry/schema/subscribe.py +++ b/strawberry/schema/subscribe.py @@ -27,7 +27,7 @@ from graphql.execution.middleware import MiddlewareManager from graphql.type.schema import GraphQLSchema - from ..extensions.runner import SchemaExtensionsRunner + from strawberry.extensions.runner import SchemaExtensionsRunner SubscriptionResult: TypeAlias = Union[ PreExecutionError, AsyncGenerator[ExecutionResult, None] @@ -80,7 +80,7 @@ async def _subscribe( ) # graphql-core 3.2 doesn't handle some of the pre-execution errors. # see `test_subscription_immediate_error` - except Exception as exc: + except Exception as exc: # noqa: BLE001 aiter_or_result = OriginalExecutionResult( data=None, errors=[_coerce_error(exc)] ) @@ -103,7 +103,7 @@ async def _subscribe( process_errors, ) # graphql-core doesn't handle exceptions raised while executing. - except Exception as exc: + except Exception as exc: # noqa: BLE001 yield await _handle_execution_result( execution_context, OriginalExecutionResult(data=None, errors=[_coerce_error(exc)]), @@ -111,7 +111,7 @@ async def _subscribe( process_errors, ) # catch exceptions raised in `on_execute` hook. - except Exception as exc: + except Exception as exc: # noqa: BLE001 origin_result = OriginalExecutionResult( data=None, errors=[_coerce_error(exc)] ) diff --git a/strawberry/schema/types/base_scalars.py b/strawberry/schema/types/base_scalars.py index 4d8a66df23..58807813a7 100644 --- a/strawberry/schema/types/base_scalars.py +++ b/strawberry/schema/types/base_scalars.py @@ -15,7 +15,9 @@ def inner(value: str) -> object: try: return parser(value) except ValueError as e: - raise GraphQLError(f'Value cannot represent a {type_}: "{value}". {e}') + raise GraphQLError( # noqa: B904 + f'Value cannot represent a {type_}: "{value}". {e}' + ) return inner @@ -24,7 +26,7 @@ def parse_decimal(value: object) -> decimal.Decimal: try: return decimal.Decimal(str(value)) except decimal.DecimalException: - raise GraphQLError(f'Value cannot represent a Decimal: "{value}".') + raise GraphQLError(f'Value cannot represent a Decimal: "{value}".') # noqa: B904 isoformat = methodcaller("isoformat") diff --git a/strawberry/schema_codegen/__init__.py b/strawberry/schema_codegen/__init__.py index 92c018d948..78a443b780 100644 --- a/strawberry/schema_codegen/__init__.py +++ b/strawberry/schema_codegen/__init__.py @@ -115,7 +115,7 @@ def _get_field_type( if isinstance(field_type, NonNullTypeNode): return _get_field_type(field_type.type, was_non_nullable=True) - elif isinstance(field_type, ListTypeNode): + if isinstance(field_type, ListTypeNode): expr = cst.Subscript( value=cst.Name("list"), slice=[ @@ -262,14 +262,13 @@ def _get_field( def _get_argument_value(argument_value: ConstValueNode) -> ArgumentValue: if isinstance(argument_value, StringValueNode): return argument_value.value - elif isinstance(argument_value, EnumValueDefinitionNode): + if isinstance(argument_value, EnumValueDefinitionNode): return argument_value.name.value - elif isinstance(argument_value, ListValueNode): + if isinstance(argument_value, ListValueNode): return [_get_argument_value(arg) for arg in argument_value.values] - elif isinstance(argument_value, BooleanValueNode): + if isinstance(argument_value, BooleanValueNode): return argument_value.value - else: - raise NotImplementedError(f"Unknown argument value {argument_value}") + raise NotImplementedError(f"Unknown argument value {argument_value}") def _get_directives( diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index b4cdc9d4e8..5a0993fdcf 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -122,7 +122,7 @@ async def handle_connection_init_timeout(self) -> None: self.connection_timed_out = True reason = "Connection initialisation timeout" await self.websocket.close(code=4408, reason=reason) - except Exception as error: + except Exception as error: # noqa: BLE001 await self.handle_task_exception(error) # pragma: no cover finally: # do not clear self.connection_init_timeout_task @@ -298,7 +298,7 @@ async def run_operation(self, operation: Operation[Context, RootValue]) -> None: {"id": operation.id, "type": "complete"} ) - except BaseException as e: # pragma: no cover + except BaseException: # pragma: no cover self.operations.pop(operation.id, None) raise finally: diff --git a/strawberry/test/client.py b/strawberry/test/client.py index 7ec027f0fc..10400d06e8 100644 --- a/strawberry/test/client.py +++ b/strawberry/test/client.py @@ -188,8 +188,7 @@ def _build_multipart_file_map( # Variables can be mixed files and other data, we don't want to map non-files # vars so we need to remove them, we can't remove them before # because they can be part of a list of files or folder - map_without_vars = {k: v for k, v in map.items() if k in files} - return map_without_vars + return {k: v for k, v in map.items() if k in files} def _decode(self, response: Any, type: Literal["multipart", "json"]) -> Any: if type == "multipart": diff --git a/strawberry/types/arguments.py b/strawberry/types/arguments.py index ff5930f431..e0bc34d4f6 100644 --- a/strawberry/types/arguments.py +++ b/strawberry/types/arguments.py @@ -23,8 +23,8 @@ ) from strawberry.types.enum import EnumDefinition from strawberry.types.lazy_type import LazyType, StrawberryLazyReference -from strawberry.types.unset import UNSET as _deprecated_UNSET -from strawberry.types.unset import _deprecated_is_unset # noqa # type: ignore +from strawberry.types.unset import UNSET as _deprecated_UNSET # noqa: N811 +from strawberry.types.unset import _deprecated_is_unset # noqa: F401 if TYPE_CHECKING: from strawberry.schema.config import StrawberryConfig diff --git a/strawberry/types/auto.py b/strawberry/types/auto.py index 7ee49b3d4d..68aad315df 100644 --- a/strawberry/types/auto.py +++ b/strawberry/types/auto.py @@ -23,8 +23,8 @@ class StrawberryAutoMeta(type): """ - def __init__(self, *args: str, **kwargs: Any) -> None: - self._instance: Optional[StrawberryAuto] = None + def __init__(cls, *args: str, **kwargs: Any) -> None: + cls._instance: Optional[StrawberryAuto] = None super().__init__(*args, **kwargs) def __call__(cls, *args: str, **kwargs: Any) -> Any: @@ -34,7 +34,7 @@ def __call__(cls, *args: str, **kwargs: Any) -> Any: return cls._instance def __instancecheck__( - self, + cls, instance: Union[StrawberryAuto, StrawberryAnnotation, StrawberryType, type], ) -> bool: if isinstance(instance, StrawberryAnnotation): diff --git a/strawberry/types/base.py b/strawberry/types/base.py index 636ae3cab9..fb8831b3f6 100644 --- a/strawberry/types/base.py +++ b/strawberry/types/base.py @@ -54,12 +54,12 @@ def copy_with( str, Union[StrawberryType, type[WithStrawberryObjectDefinition]] ], ) -> Union[StrawberryType, type[WithStrawberryObjectDefinition]]: - raise NotImplementedError() + raise NotImplementedError @property @abstractmethod def is_graphql_generic(self) -> bool: - raise NotImplementedError() + raise NotImplementedError def has_generic(self, type_var: TypeVar) -> bool: return False @@ -70,17 +70,15 @@ def __eq__(self, other: object) -> bool: if isinstance(other, StrawberryType): return self is other - elif isinstance(other, StrawberryAnnotation): + if isinstance(other, StrawberryAnnotation): return self == other.resolve() - else: - # This could be simplified if StrawberryAnnotation.resolve() always returned - # a StrawberryType - resolved = StrawberryAnnotation(other).resolve() - if isinstance(resolved, StrawberryType): - return self == resolved - else: - return NotImplemented + # This could be simplified if StrawberryAnnotation.resolve() always returned + # a StrawberryType + resolved = StrawberryAnnotation(other).resolve() + if isinstance(resolved, StrawberryType): + return self == resolved + return NotImplemented def __hash__(self) -> int: # TODO: Is this a bad idea? __eq__ objects are supposed to have the same hash @@ -100,8 +98,7 @@ def __eq__(self, other: object) -> bool: if isinstance(other, StrawberryType): if isinstance(other, StrawberryContainer): return self.of_type == other.of_type - else: - return False + return False return super().__eq__(other) @@ -112,11 +109,10 @@ def type_params(self) -> list[TypeVar]: return list(parameters) if parameters else [] - elif isinstance(self.of_type, StrawberryType): + if isinstance(self.of_type, StrawberryType): return self.of_type.type_params - else: - return [] + return [] def copy_with( self, diff --git a/strawberry/types/field.py b/strawberry/types/field.py index 2c5c95440f..279611a97f 100644 --- a/strawberry/types/field.py +++ b/strawberry/types/field.py @@ -191,17 +191,21 @@ def __call__(self, resolver: _RESOLVER_TYPE) -> Self: for argument in resolver.arguments: if isinstance(argument.type_annotation.annotation, str): continue - elif isinstance(argument.type, StrawberryUnion): + + if isinstance(argument.type, StrawberryUnion): + raise InvalidArgumentTypeError( + resolver, + argument, + ) + + if ( + has_object_definition(argument.type) + and argument.type.__strawberry_definition__.is_interface + ): raise InvalidArgumentTypeError( resolver, argument, ) - elif has_object_definition(argument.type): - if argument.type.__strawberry_definition__.is_interface: - raise InvalidArgumentTypeError( - resolver, - argument, - ) self.base_resolver = resolver diff --git a/strawberry/types/fields/resolver.py b/strawberry/types/fields/resolver.py index dfa31e8a52..1a6df222ba 100644 --- a/strawberry/types/fields/resolver.py +++ b/strawberry/types/fields/resolver.py @@ -91,8 +91,7 @@ def find( if parameters: # Add compatibility for resolvers with no arguments first_parameter = parameters[0] return first_parameter if first_parameter.name == self.name else None - else: - return None + return None class ReservedType(NamedTuple): @@ -145,21 +144,19 @@ def find( ) warnings.warn(warning, stacklevel=3) return reserved_name - else: - return None + return None def is_reserved_type(self, other: builtins.type) -> bool: origin = cast(type, get_origin(other)) or other if origin is Annotated: # Handle annotated arguments such as Private[str] and DirectiveValue[str] return type_has_annotation(other, self.type) - else: - # Handle both concrete and generic types (i.e Info, and Info) - return ( - issubclass(origin, self.type) - if isinstance(origin, type) - else origin is self.type - ) + # Handle both concrete and generic types (i.e Info, and Info) + return ( + issubclass(origin, self.type) + if isinstance(origin, type) + else origin is self.type + ) SELF_PARAMSPEC = ReservedNameBoundParameter("self") @@ -309,24 +306,20 @@ def annotations(self) -> dict[str, object]: reserved_names = {p.name for p in reserved_parameters.values() if p is not None} annotations = self._unbound_wrapped_func.__annotations__ - annotations = { + return { name: annotation for name, annotation in annotations.items() if name not in reserved_names } - return annotations - @cached_property def type_annotation(self) -> Optional[StrawberryAnnotation]: return_annotation = self.signature.return_annotation if return_annotation is inspect.Signature.empty: return None - else: - type_annotation = StrawberryAnnotation( - annotation=return_annotation, namespace=self._namespace - ) - return type_annotation + return StrawberryAnnotation( + annotation=return_annotation, namespace=self._namespace + ) @property def type(self) -> Optional[Union[StrawberryType, type]]: diff --git a/strawberry/types/union.py b/strawberry/types/union.py index 003d532c90..f32356b583 100644 --- a/strawberry/types/union.py +++ b/strawberry/types/union.py @@ -54,7 +54,7 @@ class StrawberryUnion(StrawberryType): def __init__( self, name: Optional[str] = None, - type_annotations: tuple[StrawberryAnnotation, ...] = tuple(), + type_annotations: tuple[StrawberryAnnotation, ...] = (), description: Optional[str] = None, directives: Iterable[object] = (), ) -> None: diff --git a/strawberry/types/unset.py b/strawberry/types/unset.py index e1d2acea0f..2f28d65ec0 100644 --- a/strawberry/types/unset.py +++ b/strawberry/types/unset.py @@ -14,8 +14,7 @@ def __new__(cls: type["UnsetType"]) -> "UnsetType": ret = super().__new__(cls) cls.__instance = ret return ret - else: - return cls.__instance + return cls.__instance def __str__(self) -> str: return "" diff --git a/strawberry/utils/debug.py b/strawberry/utils/debug.py index 25fa0e5f7f..f56156215c 100644 --- a/strawberry/utils/debug.py +++ b/strawberry/utils/debug.py @@ -30,7 +30,7 @@ def pretty_print_graphql_operation( if operation_name == "IntrospectionQuery": return - now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") # noqa: DTZ005 print(f"[{now}]: {operation_name or 'No operation name'}") # noqa: T201 print(highlight(query, GraphQLLexer(), Terminal256Formatter())) # noqa: T201 diff --git a/strawberry/utils/deprecations.py b/strawberry/utils/deprecations.py index 646c31225a..51f85c317d 100644 --- a/strawberry/utils/deprecations.py +++ b/strawberry/utils/deprecations.py @@ -4,7 +4,7 @@ from typing import Any, Optional -class DEPRECATION_MESSAGES: +class DEPRECATION_MESSAGES: # noqa: N801 _TYPE_DEFINITION = ( "_type_definition is deprecated, use __strawberry_definition__ instead" ) diff --git a/strawberry/utils/graphql_lexer.py b/strawberry/utils/graphql_lexer.py index 4e23668dd7..9c361e778c 100644 --- a/strawberry/utils/graphql_lexer.py +++ b/strawberry/utils/graphql_lexer.py @@ -1,3 +1,5 @@ +from typing import Any, ClassVar + from pygments import token from pygments.lexer import RegexLexer @@ -6,11 +8,11 @@ class GraphQLLexer(RegexLexer): """GraphQL Lexer for Pygments, used by the debug server.""" name = "GraphQL" - aliases = ["graphql", "gql"] - filenames = ["*.graphql", "*.gql"] - mimetypes = ["application/graphql"] + aliases: ClassVar[list[str]] = ["graphql", "gql"] + filenames: ClassVar[list[str]] = ["*.graphql", "*.gql"] + mimetypes: ClassVar[list[str]] = ["application/graphql"] - tokens = { + tokens: ClassVar[dict[str, list[tuple[str, Any]]]] = { "root": [ (r"#.*", token.Comment.Singline), (r"\.\.\.", token.Operator), diff --git a/strawberry/utils/typing.py b/strawberry/utils/typing.py index e06a7f72a2..2ce7065573 100644 --- a/strawberry/utils/typing.py +++ b/strawberry/utils/typing.py @@ -186,8 +186,7 @@ def get_parameters(annotation: type) -> Union[tuple[object], tuple[()]]: and annotation is not Generic ): return annotation.__parameters__ # type: ignore[union-attr] - else: - return () # pragma: no cover + return () # pragma: no cover @overload diff --git a/tests/channels/test_layers.py b/tests/channels/test_layers.py index 1610cd178d..253be77a39 100644 --- a/tests/channels/test_layers.py +++ b/tests/channels/test_layers.py @@ -51,9 +51,11 @@ async def test_no_layers(): "Check https://channels.readthedocs.io/en/stable/topics/channel_layers.html " "for more information" ) - with pytest.deprecated_call(match="Use listen_to_channel instead"): - with pytest.raises(RuntimeError, match=msg): - await consumer.channel_listen("foobar").__anext__() + with ( + pytest.deprecated_call(match="Use listen_to_channel instead"), + pytest.raises(RuntimeError, match=msg), + ): + await consumer.channel_listen("foobar").__anext__() with pytest.raises(RuntimeError, match=msg): async with consumer.listen_to_channel("foobar"): diff --git a/tests/cli/fixtures/__init__.py b/tests/cli/fixtures/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/codegen/__init__.py b/tests/codegen/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/conftest.py b/tests/conftest.py index 99049677c0..b4477a9ddb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -50,6 +50,7 @@ def pytest_ignore_collect( ): if sys.version_info < (3, 12) and "python_312" in collection_path.parts: return True + return None def skip_if_gql_32(reason: str) -> pytest.MarkDecorator: diff --git a/tests/django/app/models.py b/tests/django/app/models.py index 77be618f73..b4bcfab4e3 100644 --- a/tests/django/app/models.py +++ b/tests/django/app/models.py @@ -1,5 +1,5 @@ from django.db import models -class Example(models.Model): +class Example(models.Model): # noqa: DJ008 name = models.CharField(max_length=100) diff --git a/tests/django/conftest.py b/tests/django/conftest.py index 288185723b..60a47556ac 100644 --- a/tests/django/conftest.py +++ b/tests/django/conftest.py @@ -8,7 +8,7 @@ from strawberry.django.test import GraphQLTestClient -@pytest.fixture() +@pytest.fixture def graphql_client() -> GraphQLTestClient: from django.test.client import Client diff --git a/tests/enums/__init__.py b/tests/enums/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/exceptions/classes/__init__.py b/tests/exceptions/classes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/exceptions/classes/test_exception_class_missing_optional_dependencies_error.py b/tests/exceptions/classes/test_exception_class_missing_optional_dependencies_error.py index 74e740a07c..7f8fa007c5 100644 --- a/tests/exceptions/classes/test_exception_class_missing_optional_dependencies_error.py +++ b/tests/exceptions/classes/test_exception_class_missing_optional_dependencies_error.py @@ -5,7 +5,7 @@ def test_missing_optional_dependencies_error(): with pytest.raises(MissingOptionalDependenciesError) as exc_info: - raise MissingOptionalDependenciesError() + raise MissingOptionalDependenciesError assert exc_info.value.message == "Some optional dependencies are missing" diff --git a/tests/experimental/pydantic/schema/test_defaults.py b/tests/experimental/pydantic/schema/test_defaults.py index 6d87c690a7..be761ae87c 100644 --- a/tests/experimental/pydantic/schema/test_defaults.py +++ b/tests/experimental/pydantic/schema/test_defaults.py @@ -247,7 +247,7 @@ def test(self, x: OwningInput) -> ExampleOutput: } """ result = schema.execute_sync( - query, variable_values=dict(input_data=dict(nonScalarType={})) + query, variable_values={"input_data": {"nonScalarType": {}}} ) assert not result.errors diff --git a/tests/experimental/pydantic/test_basic.py b/tests/experimental/pydantic/test_basic.py index 8624a13b7f..99fd440042 100644 --- a/tests/experimental/pydantic/test_basic.py +++ b/tests/experimental/pydantic/test_basic.py @@ -110,7 +110,7 @@ class UserType: def test_auto_fields_other_sentinel(): - class other_sentinel: + class OtherSentinel: pass class User(pydantic.BaseModel): @@ -122,7 +122,7 @@ class User(pydantic.BaseModel): class UserType: age: strawberry.auto password: strawberry.auto - other: other_sentinel # this should be a private field, not an auto field + other: OtherSentinel # this should be a private field, not an auto field definition: StrawberryObjectDefinition = UserType.__strawberry_definition__ assert definition.name == "UserType" @@ -140,7 +140,7 @@ class UserType: assert field3.python_name == "other" assert field3.graphql_name is None - assert field3.type is other_sentinel + assert field3.type is OtherSentinel def test_referencing_other_models_fails_when_not_registered(): @@ -668,7 +668,7 @@ class Query: @strawberry.type class Mutation: @strawberry.mutation - def updateGroup(group: GroupInput) -> GroupOutput: + def update_group(group: GroupInput) -> GroupOutput: pass # This triggers the exception from #1504 diff --git a/tests/extensions/tracing/__init__.py b/tests/extensions/tracing/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/federation/printer/test_additional_directives.py b/tests/federation/printer/test_additional_directives.py index 58aff77180..8ae4662353 100644 --- a/tests/federation/printer/test_additional_directives.py +++ b/tests/federation/printer/test_additional_directives.py @@ -20,7 +20,7 @@ class FederatedType: @strawberry.type class Query: - federatedType: FederatedType + federatedType: FederatedType # noqa: N815 expected_type = """ directive @CacheControl(max_age: Int!) on OBJECT @@ -70,7 +70,7 @@ class FederatedType: @strawberry.type class Query: - federatedType: FederatedType + federatedType: FederatedType # noqa: N815 expected_type = """ directive @CacheControl0(max_age: Int!) on OBJECT diff --git a/tests/federation/printer/test_compose_directive.py b/tests/federation/printer/test_compose_directive.py index ddbc0463c9..362dcbf056 100644 --- a/tests/federation/printer/test_compose_directive.py +++ b/tests/federation/printer/test_compose_directive.py @@ -30,7 +30,7 @@ class FederatedType: @strawberry.type class Query: - federatedType: FederatedType + federatedType: FederatedType # noqa: N815 expected_type = """ directive @cacheControl(maxAge: Int!) on OBJECT @@ -95,7 +95,7 @@ class FederatedType: @strawberry.type class Query: - federatedType: FederatedType + federatedType: FederatedType # noqa: N815 expected_type = """ directive @cacheControl(maxAge: Int!) on OBJECT diff --git a/tests/fields/__init__.py b/tests/fields/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fields/test_field_defaults.py b/tests/fields/test_field_defaults.py index cfbc0a6d8d..b3ceba3aa5 100644 --- a/tests/fields/test_field_defaults.py +++ b/tests/fields/test_field_defaults.py @@ -55,14 +55,16 @@ class Query: def test_field_with_separate_resolver_default(): - with pytest.raises(FieldWithResolverAndDefaultValueError): + def fruit_resolver() -> str: # pragma: no cover + return "banana" - def gun_resolver() -> str: - return "revolver" + with pytest.raises(FieldWithResolverAndDefaultValueError): @strawberry.type class Query: - weapon: str = strawberry.field(default="sword", resolver=gun_resolver) + weapon: str = strawberry.field( + default="strawberry", resolver=fruit_resolver + ) def test_field_with_resolver_and_default_factory(): diff --git a/tests/fields/test_field_exceptions.py b/tests/fields/test_field_exceptions.py index 7eeeb3f60b..61a6de7ec4 100644 --- a/tests/fields/test_field_exceptions.py +++ b/tests/fields/test_field_exceptions.py @@ -24,14 +24,14 @@ def fruit(self) -> str: def test_field_with_separate_resolver_default(): - with pytest.raises(FieldWithResolverAndDefaultValueError): + def fruit_resolver() -> str: # pragma: no cover + return "strawberry" - def gun_resolver() -> str: - return "revolver" + with pytest.raises(FieldWithResolverAndDefaultValueError): @strawberry.type class Query: - weapon: str = strawberry.field(default="sword", resolver=gun_resolver) + weapon: str = strawberry.field(default="banana", resolver=fruit_resolver) def test_field_with_resolver_default_factory(): diff --git a/tests/fields/test_resolvers.py b/tests/fields/test_resolvers.py index dbf9c53536..60d62039ff 100644 --- a/tests/fields/test_resolvers.py +++ b/tests/fields/test_resolvers.py @@ -389,11 +389,11 @@ def parent_and_info( @pytest.mark.parametrize( "resolver_func", - ( + [ pytest.param(self_and_info), pytest.param(root_and_info), pytest.param(parent_and_info), - ), + ], ) def test_resolver_annotations(resolver_func): """Ensure only non-reserved annotations are returned.""" @@ -417,7 +417,7 @@ def test_resolver_with_unhashable_default(): @strawberry.type class Query: @strawberry.field - def field(self, x: list[str] = ["foo"], y: JSON = {"foo": 42}) -> str: + def field(self, x: list[str] = ["foo"], y: JSON = {"foo": 42}) -> str: # noqa: B006 return f"{x} {y}" schema = strawberry.Schema(Query) diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index 421fe961a1..07f588c349 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -15,10 +15,10 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult +from tests.http.context import get_context from tests.views.schema import Query, schema from tests.websockets.views import OnWSConnectMixin -from ..context import get_context from .base import ( JSON, DebuggableGraphQLTransportWSHandler, @@ -173,9 +173,11 @@ async def ws_connect( *, protocols: list[str], ) -> AsyncGenerator[WebSocketClient, None]: - async with TestClient(TestServer(self.app)) as client: - async with client.ws_connect(url, protocols=protocols) as ws: - yield AioWebSocketClient(ws) + async with ( + TestClient(TestServer(self.app)) as client, + client.ws_connect(url, protocols=protocols) as ws, + ): + yield AioWebSocketClient(ws) class AioWebSocketClient(WebSocketClient): diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 9a8036d688..a354dcb935 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -16,10 +16,10 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult +from tests.http.context import get_context from tests.views.schema import Query, schema from tests.websockets.views import OnWSConnectMixin -from ..context import get_context from .base import ( JSON, DebuggableGraphQLTransportWSHandler, @@ -194,7 +194,7 @@ async def receive(self, timeout: Optional[float] = None) -> Message: self._close_code = m["code"] self._close_reason = m["reason"] return Message(type=m["type"], data=m["code"], extra=m["reason"]) - elif m["type"] == "websocket.send": + if m["type"] == "websocket.send": return Message(type=m["type"], data=m["text"]) return Message(type=m["type"], data=m["data"], extra=m["extra"]) diff --git a/tests/http/clients/async_django.py b/tests/http/clients/async_django.py index fe97a43c8b..870d27f6ed 100644 --- a/tests/http/clients/async_django.py +++ b/tests/http/clients/async_django.py @@ -8,9 +8,9 @@ from strawberry.django.views import AsyncGraphQLView as BaseAsyncGraphQLView from strawberry.http import GraphQLHTTPResponse from strawberry.types import ExecutionResult +from tests.http.context import get_context from tests.views.schema import Query, schema -from ..context import get_context from .base import Response, ResultOverrideFunction from .django import DjangoHttpClient diff --git a/tests/http/clients/async_flask.py b/tests/http/clients/async_flask.py index d828f7d929..1ad8cb0356 100644 --- a/tests/http/clients/async_flask.py +++ b/tests/http/clients/async_flask.py @@ -9,9 +9,9 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult +from tests.http.context import get_context from tests.views.schema import Query, schema -from ..context import get_context from .base import ResultOverrideFunction from .flask import FlaskHttpClient diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index 795dc60094..786ec4f8bd 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -270,6 +270,7 @@ async def send_bytes(self, payload: bytes) -> None: ... @abc.abstractmethod async def receive(self, timeout: Optional[float] = None) -> Message: ... + @abc.abstractmethod async def receive_json(self, timeout: Optional[float] = None) -> Any: ... @abc.abstractmethod diff --git a/tests/http/clients/chalice.py b/tests/http/clients/chalice.py index e57062bb3b..3ea6189326 100644 --- a/tests/http/clients/chalice.py +++ b/tests/http/clients/chalice.py @@ -14,9 +14,9 @@ from strawberry.http.ides import GraphQL_IDE from strawberry.http.temporal_response import TemporalResponse from strawberry.types import ExecutionResult +from tests.http.context import get_context from tests.views.schema import Query, schema -from ..context import get_context from .base import JSON, HttpClient, Response, ResultOverrideFunction diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index 6fe6a135e1..4a89324a70 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -20,10 +20,10 @@ from strawberry.http.ides import GraphQL_IDE from strawberry.http.temporal_response import TemporalResponse from strawberry.types import ExecutionResult +from tests.http.context import get_context from tests.views.schema import Query, schema from tests.websockets.views import OnWSConnectMixin -from ..context import get_context from .base import ( JSON, DebuggableGraphQLTransportWSHandler, @@ -306,7 +306,7 @@ async def receive(self, timeout: Optional[float] = None) -> Message: self._close_code = m["code"] self._close_reason = m.get("reason") return Message(type=m["type"], data=m["code"], extra=m.get("reason")) - elif m["type"] == "websocket.send": + if m["type"] == "websocket.send": return Message(type=m["type"], data=m["text"]) return Message(type=m["type"], data=m["data"], extra=m["extra"]) diff --git a/tests/http/clients/django.py b/tests/http/clients/django.py index fae0823f0c..1a2301ad07 100644 --- a/tests/http/clients/django.py +++ b/tests/http/clients/django.py @@ -14,9 +14,9 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult +from tests.http.context import get_context from tests.views.schema import Query, schema -from ..context import get_context from .base import JSON, HttpClient, Response, ResultOverrideFunction diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index 21cf010d54..c5a8da97da 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -13,10 +13,10 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult +from tests.http.context import get_context from tests.views.schema import Query, schema from tests.websockets.views import OnWSConnectMixin -from ..context import get_context from .asgi import AsgiWebSocketClient from .base import ( JSON, @@ -125,11 +125,10 @@ async def _graphql_request( if body: if method == "get": kwargs["params"] = body + elif files: + kwargs["data"] = body else: - if files: - kwargs["data"] = body - else: - kwargs["content"] = json.dumps(body) + kwargs["content"] = json.dumps(body) if files: kwargs["files"] = files diff --git a/tests/http/clients/flask.py b/tests/http/clients/flask.py index 7509d0f911..644f30095c 100644 --- a/tests/http/clients/flask.py +++ b/tests/http/clients/flask.py @@ -16,9 +16,9 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult +from tests.http.context import get_context from tests.views.schema import Query, schema -from ..context import get_context from .base import JSON, HttpClient, Response, ResultOverrideFunction diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index 931cc98297..48c4f0703d 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -15,10 +15,10 @@ from strawberry.http.ides import GraphQL_IDE from strawberry.litestar import make_graphql_controller from strawberry.types import ExecutionResult +from tests.http.context import get_context from tests.views.schema import Query, schema from tests.websockets.views import OnWSConnectMixin -from ..context import get_context from .base import ( JSON, DebuggableGraphQLTransportWSHandler, @@ -200,7 +200,7 @@ async def receive(self, timeout: Optional[float] = None) -> Message: self._close_code = m["code"] self._close_reason = m["reason"] return Message(type=m["type"], data=m["code"], extra=m["reason"]) - elif m["type"] == "websocket.send": + if m["type"] == "websocket.send": return Message(type=m["type"], data=m["text"]) assert "data" in m diff --git a/tests/http/clients/quart.py b/tests/http/clients/quart.py index a562aa0f5a..1711e58b45 100644 --- a/tests/http/clients/quart.py +++ b/tests/http/clients/quart.py @@ -12,9 +12,9 @@ from strawberry.http.ides import GraphQL_IDE from strawberry.quart.views import GraphQLView as BaseGraphQLView from strawberry.types import ExecutionResult +from tests.http.context import get_context from tests.views.schema import Query, schema -from ..context import get_context from .base import JSON, HttpClient, Response, ResultOverrideFunction @@ -112,10 +112,9 @@ async def request( headers: Optional[dict[str, str]] = None, **kwargs: Any, ) -> Response: - async with self.app.test_app() as test_app: - async with self.app.app_context(): - client = test_app.test_client() - response = await getattr(client, method)(url, headers=headers, **kwargs) + async with self.app.test_app() as test_app, self.app.app_context(): + client = test_app.test_client() + response = await getattr(client, method)(url, headers=headers, **kwargs) return Response( status_code=response.status_code, diff --git a/tests/http/clients/sanic.py b/tests/http/clients/sanic.py index 86a3346f73..f43b324afe 100644 --- a/tests/http/clients/sanic.py +++ b/tests/http/clients/sanic.py @@ -13,9 +13,9 @@ from strawberry.http.temporal_response import TemporalResponse from strawberry.sanic.views import GraphQLView as BaseGraphQLView from strawberry.types import ExecutionResult +from tests.http.context import get_context from tests.views.schema import Query, schema -from ..context import get_context from .base import JSON, HttpClient, Response, ResultOverrideFunction @@ -87,11 +87,10 @@ async def _graphql_request( if body: if method == "get": kwargs["params"] = body + elif files: + kwargs["data"] = body else: - if files: - kwargs["data"] = body - else: - kwargs["content"] = dumps(body) + kwargs["content"] = dumps(body) request, response = await self.app.asgi_client.request( method, diff --git a/tests/http/conftest.py b/tests/http/conftest.py index 3f61148830..cd8b8b69e9 100644 --- a/tests/http/conftest.py +++ b/tests/http/conftest.py @@ -32,7 +32,7 @@ def _get_http_client_classes() -> Generator[Any, None, None]: importlib.import_module(f".{module}", package="tests.http.clients"), client, ) - except ImportError as e: + except ImportError: client_class = None yield pytest.param( @@ -51,6 +51,6 @@ def http_client_class(request: Any) -> type[HttpClient]: return request.param -@pytest.fixture() +@pytest.fixture def http_client(http_client_class: type[HttpClient]) -> HttpClient: return http_client_class() diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index fa8a08248f..bb964c2898 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -8,7 +8,7 @@ from .clients.base import HttpClient -@pytest.fixture() +@pytest.fixture def http_client(http_client_class: type[HttpClient]) -> HttpClient: with contextlib.suppress(ImportError): import django diff --git a/tests/http/test_process_result.py b/tests/http/test_process_result.py index fdc852af72..e30fee5cc5 100644 --- a/tests/http/test_process_result.py +++ b/tests/http/test_process_result.py @@ -19,7 +19,7 @@ def process_result(result: ExecutionResult) -> GraphQLHTTPResponse: return {} -@pytest.fixture() +@pytest.fixture def http_client(http_client_class) -> HttpClient: return http_client_class(result_override=process_result) diff --git a/tests/http/test_upload.py b/tests/http/test_upload.py index e82f7e30b5..6d3b97d071 100644 --- a/tests/http/test_upload.py +++ b/tests/http/test_upload.py @@ -8,7 +8,7 @@ from .clients.base import HttpClient -@pytest.fixture() +@pytest.fixture def http_client(http_client_class: type[HttpClient]) -> HttpClient: with contextlib.suppress(ImportError): from .clients.chalice import ChaliceHttpClient @@ -19,7 +19,7 @@ def http_client(http_client_class: type[HttpClient]) -> HttpClient: return http_client_class() -@pytest.fixture() +@pytest.fixture def enabled_http_client(http_client_class: type[HttpClient]) -> HttpClient: with contextlib.suppress(ImportError): from .clients.chalice import ChaliceHttpClient diff --git a/tests/objects/__init__.py b/tests/objects/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/objects/generics/__init__.py b/tests/objects/generics/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/plugins/strawberry_exceptions.py b/tests/plugins/strawberry_exceptions.py index e1623f66fc..c57b895607 100644 --- a/tests/plugins/strawberry_exceptions.py +++ b/tests/plugins/strawberry_exceptions.py @@ -32,9 +32,11 @@ def suppress_output(verbosity_level: int = 0) -> Generator[None, None, None]: return - with Path(os.devnull).open("w", encoding="utf-8") as devnull: - with contextlib.redirect_stdout(devnull): - yield + with ( + Path(os.devnull).open("w", encoding="utf-8") as devnull, + contextlib.redirect_stdout(devnull), + ): + yield class StrawberryExceptionsPlugin: diff --git a/tests/python_312/test_generics_schema.py b/tests/python_312/test_generics_schema.py index d3611ac368..76da5dc8e4 100644 --- a/tests/python_312/test_generics_schema.py +++ b/tests/python_312/test_generics_schema.py @@ -844,7 +844,7 @@ def user(self) -> Union[User, Edge[User]]: assert result.data == {"user": {"__typename": "UserEdge", "nodes": []}} -@pytest.mark.xfail() +@pytest.mark.xfail def test_raises_error_when_unable_to_find_type(): @strawberry.type class User: @@ -995,7 +995,7 @@ class Book(Node[str]): class Query: @strawberry.field def books(self) -> list[Book]: - return list() + return [] schema = strawberry.Schema(query=Query) diff --git a/tests/python_312/test_python_generics.py b/tests/python_312/test_python_generics.py index 7f3a8b1b90..d7c653c07b 100644 --- a/tests/python_312/test_python_generics.py +++ b/tests/python_312/test_python_generics.py @@ -75,7 +75,7 @@ class GenericInterface[T]: @strawberry.field def value(self) -> str: - raise NotImplementedError() + raise NotImplementedError @strawberry.type class ImplementationOne(GenericInterface[str]): @@ -132,7 +132,7 @@ class GenericInterface[T]: @strawberry.field def value(self) -> str: - raise NotImplementedError() + raise NotImplementedError @strawberry.type class ImplementationOne(GenericInterface[str]): diff --git a/tests/relay/test_types.py b/tests/relay/test_types.py index 756f97822f..b849bd2a2f 100644 --- a/tests/relay/test_types.py +++ b/tests/relay/test_types.py @@ -23,13 +23,13 @@ class FakeInfo: @pytest.mark.parametrize("type_name", [None, 1, 1.1]) def test_global_id_wrong_type_name(type_name: Any): - with pytest.raises(relay.GlobalIDValueError) as exc_info: + with pytest.raises(relay.GlobalIDValueError): relay.GlobalID(type_name=type_name, node_id="foobar") @pytest.mark.parametrize("node_id", [None, 1, 1.1]) def test_global_id_wrong_type_node_id(node_id: Any): - with pytest.raises(relay.GlobalIDValueError) as exc_info: + with pytest.raises(relay.GlobalIDValueError): relay.GlobalID(type_name="foobar", node_id=node_id) @@ -41,7 +41,7 @@ def test_global_id_from_id(): @pytest.mark.parametrize("value", ["foobar", ["Zm9vYmFy"], 123]) def test_global_id_from_id_error(value: Any): - with pytest.raises(relay.GlobalIDValueError) as exc_info: + with pytest.raises(relay.GlobalIDValueError): relay.GlobalID.from_id(value) @@ -67,9 +67,9 @@ def test_global_id_resolve_node_sync_non_existing(): def test_global_id_resolve_node_sync_non_existing_but_required(): + gid = relay.GlobalID(type_name="Fruit", node_id="999") with pytest.raises(KeyError): - gid = relay.GlobalID(type_name="Fruit", node_id="999") - fruit = gid.resolve_node_sync(fake_info, required=True) + gid.resolve_node_sync(fake_info, required=True) def test_global_id_resolve_node_sync_ensure_type(): @@ -97,7 +97,7 @@ class Foo: ... gid = relay.GlobalID(type_name="Fruit", node_id="1") with pytest.raises(TypeError): - fruit = gid.resolve_node_sync(fake_info, ensure_type=Foo) + gid.resolve_node_sync(fake_info, ensure_type=Foo) async def test_global_id_resolve_node(): @@ -117,9 +117,9 @@ async def test_global_id_resolve_node_non_existing(): async def test_global_id_resolve_node_non_existing_but_required(): + gid = relay.GlobalID(type_name="FruitAsync", node_id="999") with pytest.raises(KeyError): - gid = relay.GlobalID(type_name="FruitAsync", node_id="999") - fruit = await gid.resolve_node(fake_info, required=True) + await gid.resolve_node(fake_info, required=True) async def test_global_id_resolve_node_ensure_type(): @@ -147,7 +147,7 @@ class Foo: ... gid = relay.GlobalID(type_name="FruitAsync", node_id="1") with pytest.raises(TypeError): - fruit = await gid.resolve_node(fake_info, ensure_type=Foo) + await gid.resolve_node(fake_info, ensure_type=Foo) async def test_resolve_async_list_connection(): diff --git a/tests/schema/extensions/schema_extensions/conftest.py b/tests/schema/extensions/schema_extensions/conftest.py index c859269d21..dc43e48a95 100644 --- a/tests/schema/extensions/schema_extensions/conftest.py +++ b/tests/schema/extensions/schema_extensions/conftest.py @@ -43,7 +43,7 @@ def assert_expected(cls) -> None: assert cls.called_hooks == cls.expected -@pytest.fixture() +@pytest.fixture def default_query_types_and_query() -> SchemaHelper: @strawberry.type class Person: @@ -94,7 +94,7 @@ def hook_wrap(list_: list[str], hook_name: str): list_.append(f"{hook_name} Exited") -@pytest.fixture() +@pytest.fixture def async_extension() -> type[ExampleExtension]: class MyExtension(ExampleExtension): async def on_operation(self): diff --git a/tests/schema/extensions/schema_extensions/test_extensions.py b/tests/schema/extensions/schema_extensions/test_extensions.py index f569e13fbd..2dbc6e4768 100644 --- a/tests/schema/extensions/schema_extensions/test_extensions.py +++ b/tests/schema/extensions/schema_extensions/test_extensions.py @@ -205,7 +205,7 @@ def on_operation(self): assert res.data == {"override": 20} -@pytest.fixture() +@pytest.fixture def sync_extension() -> type[ExampleExtension]: class MyExtension(ExampleExtension): def on_operation(self): @@ -646,7 +646,7 @@ def on_execute(self): @pytest.mark.parametrize( "failing_hook", - ( + [ "on_operation_start", "on_operation_end", "on_parse_start", @@ -655,7 +655,7 @@ def on_execute(self): "on_validate_end", "on_execute_start", "on_execute_end", - ), + ], ) @pytest.mark.asyncio async def test_exceptions_are_included_in_the_execution_result(failing_hook): @@ -684,7 +684,7 @@ def ping(self) -> str: @pytest.mark.parametrize( ("failing_hook", "expected_hooks"), - ( + [ ("on_operation_start", set()), ("on_parse_start", {1, 8}), ("on_parse_end", {1, 2, 8}), @@ -693,7 +693,7 @@ def ping(self) -> str: ("on_execute_start", {1, 2, 3, 4, 5, 8}), ("on_execute_end", {1, 2, 3, 4, 5, 6, 8}), ("on_operation_end", {1, 2, 3, 4, 5, 6, 7}), - ), + ], ) @pytest.mark.asyncio async def test_exceptions_abort_evaluation(failing_hook, expected_hooks): diff --git a/tests/schema/extensions/test_field_extensions.py b/tests/schema/extensions/test_field_extensions.py index c0e8bdb613..ae9bc3848d 100644 --- a/tests/schema/extensions/test_field_extensions.py +++ b/tests/schema/extensions/test_field_extensions.py @@ -284,8 +284,7 @@ def resolve( ): nonlocal field_kwargs field_kwargs = kwargs - result = next_(source, info, **kwargs) - return result + return next_(source, info, **kwargs) @strawberry.type class Query: @@ -316,8 +315,7 @@ def resolve( **kwargs: Any, ): kwargs["some_input"] += 10 - result = next_(source, info, **kwargs) - return result + return next_(source, info, **kwargs) @strawberry.type class Query: @@ -353,8 +351,7 @@ def resolve( assert argument_def is not None argument_metadata[key] = argument_def.metadata - result = next_(source, info, **kwargs) - return result + return next_(source, info, **kwargs) @strawberry.type class Query: diff --git a/tests/schema/extensions/test_imports.py b/tests/schema/extensions/test_imports.py index 13a41a7fc6..cf661ac452 100644 --- a/tests/schema/extensions/test_imports.py +++ b/tests/schema/extensions/test_imports.py @@ -8,4 +8,4 @@ def test_can_import(mocker): def test_fails_if_import_is_not_found(): with pytest.raises(ImportError): - from strawberry.extensions.tracing import Blueberry # noqa + from strawberry.extensions.tracing import Blueberry # noqa: F401 diff --git a/tests/schema/extensions/test_query_depth_limiter.py b/tests/schema/extensions/test_query_depth_limiter.py index 89c1cec115..d8ca4526ef 100644 --- a/tests/schema/extensions/test_query_depth_limiter.py +++ b/tests/schema/extensions/test_query_depth_limiter.py @@ -245,11 +245,6 @@ def test_should_catch_query_thats_too_deep(): def test_should_raise_invalid_ignore(): - query = """ - query read1 { - user { address { city } } - } - """ with pytest.raises( TypeError, match="The `should_ignore` argument to `QueryDepthLimiter` must be a callable.", @@ -272,11 +267,7 @@ def test_should_ignore_field_by_name(): """ def should_ignore(ignore: IgnoreContext) -> bool: - return ( - ignore.field_name == "user1" - or ignore.field_name == "user2" - or ignore.field_name == "user3" - ) + return ignore.field_name in ("user1", "user2", "user3") errors, result = run_query(query, 10, should_ignore=should_ignore) diff --git a/tests/schema/test_basic.py b/tests/schema/test_basic.py index 47e248f2f8..1f63c86084 100644 --- a/tests/schema/test_basic.py +++ b/tests/schema/test_basic.py @@ -463,9 +463,9 @@ class Query: def test_str_magic_method_prints_schema_sdl(): @strawberry.type class Query: - exampleBool: bool - exampleStr: str = "Example" - exampleInt: int = 1 + example_bool: bool + example_str: str = "Example" + example_int: int = 1 schema = strawberry.Schema(query=Query) expected = """ diff --git a/tests/schema/test_directives.py b/tests/schema/test_directives.py index b67391b28b..e5a83bc88c 100644 --- a/tests/schema/test_directives.py +++ b/tests/schema/test_directives.py @@ -420,7 +420,7 @@ class Locale(Enum): @strawberry.type class Query: @strawberry.field - def greetingTemplate(self, locale: Locale = Locale.EN) -> str: + def greeting_template(self, locale: Locale = Locale.EN) -> str: return greetings[locale] field = get_object_definition(Query, strict=True).fields[0] diff --git a/tests/schema/test_enum.py b/tests/schema/test_enum.py index 7adda83ea2..57b9f80717 100644 --- a/tests/schema/test_enum.py +++ b/tests/schema/test_enum.py @@ -129,7 +129,7 @@ class IceCreamFlavour(Enum): @strawberry.input class Input: flavour: IceCreamFlavour - optionalFlavour: typing.Optional[IceCreamFlavour] = None + optional_flavour: typing.Optional[IceCreamFlavour] = None @strawberry.type class Query: diff --git a/tests/schema/test_execution.py b/tests/schema/test_execution.py index 031e65b456..6c66e62cd6 100644 --- a/tests/schema/test_execution.py +++ b/tests/schema/test_execution.py @@ -10,7 +10,7 @@ from strawberry.extensions import AddValidationRules, DisableValidation -@pytest.mark.parametrize("validate_queries", (True, False)) +@pytest.mark.parametrize("validate_queries", [True, False]) @patch("strawberry.schema.execute.validate", wraps=validate) def test_enabling_query_validation_sync(mock_validate, validate_queries): @strawberry.type @@ -43,7 +43,7 @@ class Query: @pytest.mark.asyncio -@pytest.mark.parametrize("validate_queries", (True, False)) +@pytest.mark.parametrize("validate_queries", [True, False]) async def test_enabling_query_validation(validate_queries): @strawberry.type class Query: diff --git a/tests/schema/test_generics.py b/tests/schema/test_generics.py index b442d2e5cd..eb547312d2 100644 --- a/tests/schema/test_generics.py +++ b/tests/schema/test_generics.py @@ -996,7 +996,7 @@ class Book(Node[str]): class Query: @strawberry.field def books(self) -> list[Book]: - return list() + return [] schema = strawberry.Schema(query=Query) diff --git a/tests/schema/test_input.py b/tests/schema/test_input.py index 5fe43244e1..afe2dc9439 100644 --- a/tests/schema/test_input.py +++ b/tests/schema/test_input.py @@ -96,7 +96,7 @@ def example(self, data: Input) -> ExampleOutput: } """ result = schema.execute_sync( - query, variable_values=dict(input_data=dict(nonScalarField={})) + query, variable_values={"input_data": {"nonScalarField": {}}} ) assert not result.errors diff --git a/tests/schema/test_interface.py b/tests/schema/test_interface.py index bfa7b0edd4..c54d7f887c 100644 --- a/tests/schema/test_interface.py +++ b/tests/schema/test_interface.py @@ -191,7 +191,7 @@ class Anime(Entity): class Query: @strawberry.field def anime(self) -> Anime: - return dict(id=1, name="One Piece") # type: ignore + return {"id": 1, "name": "One Piece"} # type: ignore schema = strawberry.Schema(query=Query) diff --git a/tests/schema/test_lazy_types/type_a.py b/tests/schema/test_lazy_types/type_a.py index dac9ba9c1e..afcba30b30 100644 --- a/tests/schema/test_lazy_types/type_a.py +++ b/tests/schema/test_lazy_types/type_a.py @@ -15,7 +15,7 @@ class TypeA: ] = None @strawberry.field - def type_b(self) -> strawberry.LazyType["TypeB", ".type_b"]: # noqa + def type_b(self) -> strawberry.LazyType["TypeB", ".type_b"]: # noqa: F722 from .type_b import TypeB return TypeB() diff --git a/tests/schema/test_one_of.py b/tests/schema/test_one_of.py index a1139a97f9..7a0e63ab16 100644 --- a/tests/schema/test_one_of.py +++ b/tests/schema/test_one_of.py @@ -32,12 +32,12 @@ def test(self, input: ExampleInputTagged) -> ExampleResult: @pytest.mark.parametrize( ("default_value", "variables"), - ( + [ ("{a: null, b: null}", {}), ('{ a: "abc", b: 123 }', {}), ("{a: null, b: 123}", {}), ("{}", {}), - ), + ], ) def test_must_specify_at_least_one_key_default( default_value: str, variables: dict[str, Any] @@ -220,7 +220,7 @@ class Result: class Query: @strawberry.field def test(self, input: ExampleWithLongerNames) -> Result: - return Result( # noqa + return Result( # noqa: F821 a_field=None if input.a_field is strawberry.UNSET else input.a_field, b_field=None if input.b_field is strawberry.UNSET else input.b_field, ) diff --git a/tests/schema/test_permission.py b/tests/schema/test_permission.py index bd687d4f21..4f3e95175b 100644 --- a/tests/schema/test_permission.py +++ b/tests/schema/test_permission.py @@ -500,10 +500,6 @@ class Query: def name(self) -> User: # pragma: no cover return User(name="ABC") - error = re.escape( - "Cannot use fail_silently=True with a non-optional " "or non-list field" - ) - strawberry.Schema(query=Query) diff --git a/tests/schema/test_resolvers.py b/tests/schema/test_resolvers.py index 0bcf6f41d5..1ab97cea3f 100644 --- a/tests/schema/test_resolvers.py +++ b/tests/schema/test_resolvers.py @@ -514,7 +514,7 @@ def arbitrarily_named_info(icon: str, info_argument: Info) -> str: @pytest.mark.parametrize( ("resolver", "deprecation"), - ( + [ pytest.param( name_based_info, pytest.deprecated_call(match="Argument name-based matching of"), @@ -522,7 +522,7 @@ def arbitrarily_named_info(icon: str, info_argument: Info) -> str: pytest.param(type_based_info, nullcontext()), pytest.param(generic_type_based_info, nullcontext()), pytest.param(arbitrarily_named_info, nullcontext()), - ), + ], ) def test_info_argument(resolver, deprecation): with deprecation: @@ -565,10 +565,10 @@ def static_method_parent(asdf: Parent[UserLiteral]) -> str: @pytest.mark.parametrize( "resolver", - ( + [ pytest.param(parent_no_self), pytest.param(Foo.static_method_parent), - ), + ], ) def test_parent_argument(resolver): @strawberry.type @@ -610,11 +610,11 @@ def multiple_infos(root, info1: Info, info2: Info) -> str: @pytest.mark.parametrize( "resolver", - ( + [ pytest.param(parent_self_and_root), pytest.param(multiple_parents), pytest.param(multiple_infos), - ), + ], ) @pytest.mark.raises_strawberry_exception( ConflictingArgumentsError, @@ -631,7 +631,7 @@ class Query: strawberry.Schema(query=Query) -@pytest.mark.parametrize("resolver", (parent_and_self, self_and_root)) +@pytest.mark.parametrize("resolver", [parent_and_self, self_and_root]) def test_self_should_not_raise_conflicting_arguments_error(resolver): @strawberry.type class Query: diff --git a/tests/schema/test_subscription.py b/tests/schema/test_subscription.py index 4ca0f5d87d..c9ae295531 100644 --- a/tests/schema/test_subscription.py +++ b/tests/schema/test_subscription.py @@ -97,14 +97,14 @@ async def example(self, name: str) -> AsyncGenerator[str, None]: @pytest.mark.parametrize( "return_annotation", - ( + [ "AsyncGenerator[str, None]", "AsyncIterable[str]", "AsyncIterator[str]", "abc.AsyncIterator[str]", "abc.AsyncGenerator[str, None]", "abc.AsyncIterable[str]", - ), + ], ) @pytest.mark.asyncio async def test_subscription_return_annotations(return_annotation: str): diff --git a/tests/schema/test_union.py b/tests/schema/test_union.py index edd219c25e..71cc92bbf4 100644 --- a/tests/schema/test_union.py +++ b/tests/schema/test_union.py @@ -520,16 +520,15 @@ class Input: name: str something: Union[A, B] + @strawberry.type + class Query: + @strawberry.field + def user(self, data: Input) -> User: + return User(name=data.name, age=100) + with pytest.raises( TypeError, match="Union for A is not supported because it is an Input type" ): - - @strawberry.type - class Query: - @strawberry.field - def user(self, data: Input) -> User: - return User(name=data.name, age=100) - strawberry.Schema(query=Query) diff --git a/tests/schema/types/test_date.py b/tests/schema/types/test_date.py index 67eee59890..dbd75bac12 100644 --- a/tests/schema/types/test_date.py +++ b/tests/schema/types/test_date.py @@ -90,13 +90,13 @@ def date_input(self, date_input: datetime.date) -> datetime.date: @pytest.mark.parametrize( "value", - ( + [ "2012-12-01T09:00", "2012-13-01", "2012-04-9", # this might have been fixed in 3.11 # "20120411", - ), + ], ) def test_serialization_of_incorrect_date_string(value): """Test GraphQLError is raised for incorrect date. diff --git a/tests/schema/types/test_datetime.py b/tests/schema/types/test_datetime.py index a03819d08c..45dece84f3 100644 --- a/tests/schema/types/test_datetime.py +++ b/tests/schema/types/test_datetime.py @@ -136,7 +136,7 @@ def datetime_input( @pytest.mark.parametrize( "value", - ( + [ "2012-13-01", "2012-04-9", "20120411T03:30+", @@ -146,7 +146,7 @@ def datetime_input( "20120411T03:30+00:61", "20120411T033030.123456012:00" "2014-03-12T12:30:14", "2014-04-21T24:00:01", - ), + ], ) def test_serialization_of_incorrect_datetime_string(value): """Test GraphQLError is raised for incorrect datetime. diff --git a/tests/schema/types/test_time.py b/tests/schema/types/test_time.py index 2a45037335..bb7a7f8e2f 100644 --- a/tests/schema/types/test_time.py +++ b/tests/schema/types/test_time.py @@ -90,12 +90,12 @@ def time_input(self, time_input: datetime.time) -> datetime.time: @pytest.mark.parametrize( "value", - ( + [ "2012-12-01T09:00", "03:30+", "03:30+1234567", "03:30-25:40", - ), + ], ) def test_serialization_of_incorrect_time_string(value): """Test GraphQLError is raised for incorrect time. diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py index d56763a24c..2a440458e0 100644 --- a/tests/test_dataloaders.py +++ b/tests/test_dataloaders.py @@ -158,7 +158,7 @@ async def test_caches_by_id_when_loading_many(mocker: MockerFixture): assert a == b - assert [1, 1] == await asyncio.gather(a, b) + assert await asyncio.gather(a, b) == [1, 1] mock_loader.assert_called_once_with([1]) @@ -330,8 +330,8 @@ async def idx(keys: list[int]) -> list[int]: value_b = cast("Future[Any]", loader.load(2)) value_b.cancel() # value_c will be cancelled by the timeout + value_c = cast("Future[Any]", loader.load(3)) with pytest.raises(asyncio.TimeoutError): - value_c = cast("Future[Any]", loader.load(3)) await asyncio.wait_for(value_c, 0.1) value_d = await loader.load(4) @@ -392,7 +392,7 @@ def clear(self) -> None: loader.clear(1) assert len(custom_cache.cache) == 2 - assert sorted(list(custom_cache.cache.keys())) == [2, 3] + assert sorted(custom_cache.cache.keys()) == [2, 3] loader.clear_all() assert len(custom_cache.cache) == 0 @@ -429,7 +429,7 @@ def custom_cache_key(key: list[int]) -> str: loader = DataLoader(load_fn=idx, cache_key_fn=custom_cache_key) data = await loader.load([1, 2, "test"]) - assert [1, 2, "test"] == data + assert data == [1, 2, "test"] @pytest.mark.asyncio diff --git a/tests/test_printer/test_basic.py b/tests/test_printer/test_basic.py index f4b14a5f6b..bac73e9b16 100644 --- a/tests/test_printer/test_basic.py +++ b/tests/test_printer/test_basic.py @@ -276,15 +276,15 @@ class MyInput: @strawberry.type class Query: @strawberry.field - def search(self, j: JSON = {}) -> JSON: + def search(self, j: JSON = {}) -> JSON: # noqa: B006 return j @strawberry.field - def search2(self, j: JSON = {"hello": "world"}) -> JSON: + def search2(self, j: JSON = {"hello": "world"}) -> JSON: # noqa: B006 return j @strawberry.field - def search3(self, j: JSON = {"hello": {"nice": "world"}}) -> JSON: + def search3(self, j: JSON = {"hello": {"nice": "world"}}) -> JSON: # noqa: B006 return j expected_type = """ diff --git a/tests/typecheckers/utils/mypy.py b/tests/typecheckers/utils/mypy.py index 85e858c627..8fc44ad230 100644 --- a/tests/typecheckers/utils/mypy.py +++ b/tests/typecheckers/utils/mypy.py @@ -76,8 +76,8 @@ def run_mypy(code: str, strict: bool = True) -> list[Result]: column=mypy_result["column"] + 1, ) ) - except json.JSONDecodeError: - raise Exception(f"Invalid JSON: {full_output}") + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON: {full_output}") from e results.sort(key=lambda x: (x.line, x.column, x.message)) diff --git a/tests/types/__init__.py b/tests/types/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/types/cross_module_resolvers/__init__.py b/tests/types/cross_module_resolvers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/types/cross_module_resolvers/c_mod.py b/tests/types/cross_module_resolvers/c_mod.py index 9182310e8c..f9f5443e9e 100644 --- a/tests/types/cross_module_resolvers/c_mod.py +++ b/tests/types/cross_module_resolvers/c_mod.py @@ -1,11 +1,9 @@ -import a_mod -import b_mod -import x_mod -from a_mod import AObject as C_AObject -from b_mod import BObject as C_BObject - import strawberry +from . import a_mod, b_mod, x_mod +from .a_mod import AObject as C_AObject +from .b_mod import BObject as C_BObject + def c_inheritance_resolver() -> list["CInheritance"]: pass diff --git a/tests/types/cross_module_resolvers/test_cross_module_resolvers.py b/tests/types/cross_module_resolvers/test_cross_module_resolvers.py index bc9241c4ca..6e04ad7ff5 100644 --- a/tests/types/cross_module_resolvers/test_cross_module_resolvers.py +++ b/tests/types/cross_module_resolvers/test_cross_module_resolvers.py @@ -4,13 +4,10 @@ (forward reference) and can only be resolved at schema construction. """ -import a_mod -import b_mod -import c_mod -import x_mod - import strawberry +from . import a_mod, b_mod, c_mod, x_mod + def test_a(): @strawberry.type diff --git a/tests/types/resolving/__init__.py b/tests/types/resolving/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/types/resolving/test_union_pipe.py b/tests/types/resolving/test_union_pipe.py index 37e0253c80..ad016e679f 100644 --- a/tests/types/resolving/test_union_pipe.py +++ b/tests/types/resolving/test_union_pipe.py @@ -98,7 +98,7 @@ class Error: class Query: user: UserOrError | int - schema = strawberry.Schema(query=Query) + strawberry.Schema(query=Query) @pytest.mark.raises_strawberry_exception( diff --git a/tests/types/resolving/test_unions.py b/tests/types/resolving/test_unions.py index afb7efc4f9..2f28055aa3 100644 --- a/tests/types/resolving/test_unions.py +++ b/tests/types/resolving/test_unions.py @@ -123,7 +123,7 @@ def test_error_with_scalar_types(): class Query: something: Something - schema = strawberry.Schema(query=Query) + strawberry.Schema(query=Query) @pytest.mark.raises_strawberry_exception( @@ -145,4 +145,4 @@ def test_error_with_scalar_types_pipe(): class Query: something: Something2 - schema = strawberry.Schema(query=Query) + strawberry.Schema(query=Query) diff --git a/tests/types/test_annotation.py b/tests/types/test_annotation.py index 969dddf871..ee7aa5cb54 100644 --- a/tests/types/test_annotation.py +++ b/tests/types/test_annotation.py @@ -58,7 +58,7 @@ def __eq__(self, other): assert Foo() != object() assert object() != Foo() assert Foo() != 123 != Foo() - assert 123 != Foo() + assert Foo() != 123 assert Foo() == StrawberryAnnotation(int) assert StrawberryAnnotation(int) == Foo() diff --git a/tests/types/test_deferred_annotations.py b/tests/types/test_deferred_annotations.py index b752a14ac7..64d283bc8f 100644 --- a/tests/types/test_deferred_annotations.py +++ b/tests/types/test_deferred_annotations.py @@ -24,7 +24,7 @@ def test_deferred_other_module(): modules[mod.__name__] = mod try: - exec(deferred_module_source, mod.__dict__) + exec(deferred_module_source, mod.__dict__) # noqa: S102 @strawberry.type class Post(mod.UserContent): diff --git a/tests/types/test_lazy_types.py b/tests/types/test_lazy_types.py index b03fb38c3c..bcf2bc8178 100644 --- a/tests/types/test_lazy_types.py +++ b/tests/types/test_lazy_types.py @@ -38,8 +38,7 @@ class LazyEnum(enum.Enum): def test_lazy_type(): - # Module path is short and relative because of the way pytest runs the file - LazierType = LazyType("LaziestType", "test_lazy_types") + LazierType = LazyType("LaziestType", "tests.types.test_lazy_types") annotation = StrawberryAnnotation(LazierType) resolved = annotation.resolve() @@ -53,8 +52,7 @@ def test_lazy_type(): def test_lazy_type_alias(): - # Module path is short and relative because of the way pytest runs the file - LazierType = LazyType("LazyTypeAlias", "test_lazy_types") + LazierType = LazyType("LazyTypeAlias", "tests.types.test_lazy_types") annotation = StrawberryAnnotation(LazierType) resolved = annotation.resolve() @@ -69,7 +67,9 @@ def test_lazy_type_alias(): def test_lazy_type_function(): - LethargicType = Annotated["LaziestType", strawberry.lazy("test_lazy_types")] + LethargicType = Annotated[ + "LaziestType", strawberry.lazy("tests.types.test_lazy_types") + ] annotation = StrawberryAnnotation(LethargicType) resolved = annotation.resolve() @@ -79,8 +79,7 @@ def test_lazy_type_function(): def test_lazy_type_enum(): - # Module path is short and relative because of the way pytest runs the file - LazierType = LazyType("LazyEnum", "test_lazy_types") + LazierType = LazyType("LazyEnum", "tests.types.test_lazy_types") annotation = StrawberryAnnotation(LazierType) resolved = annotation.resolve() @@ -94,8 +93,7 @@ def test_lazy_type_enum(): def test_lazy_type_argument(): - # Module path is short and relative because of the way pytest runs the file - LazierType = LazyType("LaziestType", "test_lazy_types") + LazierType = LazyType("LaziestType", "tests.types.test_lazy_types") @strawberry.mutation def slack_off(emotion: LazierType) -> bool: @@ -109,8 +107,7 @@ def slack_off(emotion: LazierType) -> bool: def test_lazy_type_field(): - # Module path is short and relative because of the way pytest runs the file - LazierType = LazyType("LaziestType", "test_lazy_types") + LazierType = LazyType("LaziestType", "tests.types.test_lazy_types") annotation = StrawberryAnnotation(LazierType) field = StrawberryField(type_annotation=annotation) @@ -127,8 +124,7 @@ def test_lazy_type_generic(): class GenericType(Generic[T]): item: T - # Module path is short and relative because of the way pytest runs the file - LazierType = LazyType("LaziestType", "test_lazy_types") + LazierType = LazyType("LaziestType", "tests.types.test_lazy_types") ResolvedType = GenericType[LazierType] annotation = StrawberryAnnotation(ResolvedType) @@ -142,8 +138,7 @@ class GenericType(Generic[T]): def test_lazy_type_object(): - # Module path is short and relative because of the way pytest runs the file - LazierType = LazyType("LaziestType", "test_lazy_types") + LazierType = LazyType("LaziestType", "tests.types.test_lazy_types") @strawberry.type class WaterParkFeature: @@ -157,8 +152,7 @@ class WaterParkFeature: def test_lazy_type_resolver(): - # Module path is short and relative because of the way pytest runs the file - LazierType = LazyType("LaziestType", "test_lazy_types") + LazierType = LazyType("LaziestType", "tests.types.test_lazy_types") def slaking_pokemon() -> LazierType: raise NotImplementedError @@ -170,8 +164,8 @@ def slaking_pokemon() -> LazierType: def test_lazy_type_in_union(): - ActiveType = LazyType("LaziestType", "test_lazy_types") - ActiveEnum = LazyType("LazyEnum", "test_lazy_types") + ActiveType = LazyType("LaziestType", "tests.types.test_lazy_types") + ActiveEnum = LazyType("LazyEnum", "tests.types.test_lazy_types") something = Annotated[Union[ActiveType, ActiveEnum], union(name="CoolUnion")] annotation = StrawberryAnnotation(something) @@ -187,8 +181,10 @@ def test_lazy_type_in_union(): def test_lazy_function_in_union(): - ActiveType = Annotated["LaziestType", strawberry.lazy("test_lazy_types")] - ActiveEnum = Annotated["LazyEnum", strawberry.lazy("test_lazy_types")] + ActiveType = Annotated[ + "LaziestType", strawberry.lazy("tests.types.test_lazy_types") + ] + ActiveEnum = Annotated["LazyEnum", strawberry.lazy("tests.types.test_lazy_types")] something = Annotated[Union[ActiveType, ActiveEnum], union(name="CoolUnion")] annotation = StrawberryAnnotation(something) diff --git a/tests/websockets/conftest.py b/tests/websockets/conftest.py index e12dfa8632..7b784c2168 100644 --- a/tests/websockets/conftest.py +++ b/tests/websockets/conftest.py @@ -4,7 +4,7 @@ import pytest -from ..http.clients.base import HttpClient +from tests.http.clients.base import HttpClient def _get_http_client_classes() -> Generator[Any, None, None]: @@ -38,6 +38,6 @@ def http_client_class(request: Any) -> type[HttpClient]: return request.param -@pytest.fixture() +@pytest.fixture def http_client(http_client_class: type[HttpClient]) -> HttpClient: return http_client_class() diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 787d215aac..c075505738 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -28,7 +28,7 @@ from tests.views.schema import MyExtension, Schema if TYPE_CHECKING: - from ..http.clients.base import HttpClient, WebSocketClient + from tests.http.clients.base import HttpClient, WebSocketClient @pytest_asyncio.fixture diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index 246caf76b5..a8d7d56be0 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -23,7 +23,7 @@ from tests.views.schema import MyExtension, Schema if TYPE_CHECKING: - from ..http.clients.aiohttp import HttpClient, WebSocketClient + from tests.http.clients.aiohttp import HttpClient, WebSocketClient @pytest_asyncio.fixture @@ -608,7 +608,7 @@ async def test_task_cancellation_separation(aiohttp_app_client: HttpClient): # This only works for aiohttp, where we are using the same event loop # on the client side and server. try: - from ..http.clients.aiohttp import AioHttpClient + from tests.http.clients.aiohttp import AioHttpClient aio = aiohttp_app_client == AioHttpClient # type: ignore except ImportError: diff --git a/tests/websockets/views.py b/tests/websockets/views.py index 981ad96354..bad7e989f5 100644 --- a/tests/websockets/views.py +++ b/tests/websockets/views.py @@ -33,7 +33,7 @@ async def on_ws_connect( if connection_params.get("test-reject"): if "err-payload" in connection_params: raise ConnectionRejectionError(connection_params["err-payload"]) - raise ConnectionRejectionError() + raise ConnectionRejectionError if connection_params.get("test-accept"): if "ack-payload" in connection_params: