From 42c21a789feed81c56efd82ad8423bc178eacb99 Mon Sep 17 00:00:00 2001 From: Lukasz Piatkowski Date: Thu, 15 Feb 2024 13:15:55 +0100 Subject: [PATCH] Add codemod for resolver type --- strawberry/cli/commands/upgrade/__init__.py | 35 ++--- strawberry/codemods/resolver_type.py | 139 ++++++++++++++++++++ tests/codemods/test_resolver_type.py | 60 +++++++++ 3 files changed, 218 insertions(+), 16 deletions(-) create mode 100644 strawberry/codemods/resolver_type.py create mode 100644 tests/codemods/test_resolver_type.py diff --git a/strawberry/cli/commands/upgrade/__init__.py b/strawberry/cli/commands/upgrade/__init__.py index 5ea4761e23..dcf6758878 100644 --- a/strawberry/cli/commands/upgrade/__init__.py +++ b/strawberry/cli/commands/upgrade/__init__.py @@ -3,28 +3,31 @@ import glob import pathlib # noqa: TCH003 import sys +from enum import Enum from typing import List +from typing_extensions import assert_never import rich import typer -from libcst.codemod import CodemodContext +from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand from strawberry.cli.app import app from strawberry.codemods.annotated_unions import ConvertUnionToAnnotatedUnion +from strawberry.codemods.resolver_type import ConvertFieldWithResolverTypeAnnotations from ._run_codemod import run_codemod -codemods = { - "annotated-union": ConvertUnionToAnnotatedUnion, -} + +class Codemod(Enum): + ANNOTATED_UNION = "annotated-union" + RESOLVER_TYPE = "resolver-type" # TODO: add support for running all of them @app.command(help="Upgrades a Strawberry project to the latest version") def upgrade( - codemod: str = typer.Argument( + codemod: Codemod = typer.Argument( ..., - autocompletion=lambda: list(codemods.keys()), help="Name of the upgrade to run", ), paths: List[pathlib.Path] = typer.Argument(..., file_okay=True, dir_okay=True), @@ -39,18 +42,18 @@ def upgrade( help="Use typing_extensions instead of typing for newer features", ), ) -> None: - if codemod not in codemods: - rich.print(f'[red]Upgrade named "{codemod}" does not exist') - - raise typer.Exit(2) - python_target_version = tuple(int(x) for x in python_target.split(".")) - transformer = ConvertUnionToAnnotatedUnion( - CodemodContext(), - use_pipe_syntax=python_target_version >= (3, 10), - use_typing_extensions=use_typing_extensions, - ) + if codemod == Codemod.ANNOTATED_UNION: + transformer: VisitorBasedCodemodCommand = ConvertUnionToAnnotatedUnion( + CodemodContext(), + use_pipe_syntax=python_target_version >= (3, 10), + use_typing_extensions=use_typing_extensions, + ) + elif codemod == Codemod.RESOLVER_TYPE: + transformer = ConvertFieldWithResolverTypeAnnotations(CodemodContext()) + else: + return assert_never(codemod) files: list[str] = [] diff --git a/strawberry/codemods/resolver_type.py b/strawberry/codemods/resolver_type.py new file mode 100644 index 0000000000..6d1b77928a --- /dev/null +++ b/strawberry/codemods/resolver_type.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from typing import List, Optional, Union + +import libcst as cst +import libcst.matchers as m +from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand +from libcst.codemod.visitors import AddImportsVisitor + + +class ConvertFieldWithResolverTypeAnnotations(VisitorBasedCodemodCommand): + DESCRIPTION: str = ( + "Converts field: T = strawberry.field(resolver=...) to " + "field: strawberry.Resolver[T] = strawberry.field(resolver=...)" + ) + + def __init__( + self, + context: CodemodContext, + ) -> None: + self._is_using_named_import = False + + super().__init__(context) + + def visit_Module(self, node: cst.Module) -> Optional[bool]: + self._is_using_named_import = False + + return super().visit_Module(node) + + @m.visit( + m.ImportFrom( + m.Name("strawberry"), + [ + m.ZeroOrMore(), + m.ImportAlias(m.Name("field")), + m.ZeroOrMore(), + ], + ) + ) + def visit_import_from(self, original_node: cst.ImportFrom) -> None: + self._is_using_named_import = True + + def leave_ClassDef( + self, original_node: cst.ClassDef, updated_node: cst.ClassDef + ) -> cst.ClassDef: + new_body = [] + + for node in updated_node.body.body: + if not m.matches( + node, + m.SimpleStatementLine( + body=[ + m.AnnAssign(value=self._field_with_resolver_call_matcher()), + ] + ), + ): + new_body.append(node) + continue + + self._add_imports() + + node_stmt = cst.ensure_type(node, cst.SimpleStatementLine) + ann_assign = cst.ensure_type(node_stmt.body[0], cst.AnnAssign) + + new_annotation = self._new_annotation_wrapped_in_resolver( + ann_assign.annotation + ) + + new_body.append( + node_stmt.with_changes( + body=[ + ann_assign.with_changes( + annotation=new_annotation, + ) + ] + ) + ) + + return updated_node.with_changes( + body=updated_node.body.with_changes(body=new_body) + ) + + def _field_with_resolver_call_matcher(self) -> m.Call: + """Matches a call to strawberry.field with a resolver argument.""" + + args: List[Union[m.ArgMatchType, m.AtLeastN[m.DoNotCareSentinel]]] = [ + m.ZeroOrMore(), + m.Arg( + keyword=m.Name("resolver"), + ), + m.ZeroOrMore(), + ] + + if self._is_using_named_import: + return m.Call( + func=m.Name("field"), + args=args, + ) + + return m.Call( + func=m.Attribute( + value=m.Name("strawberry"), + attr=m.Name("field"), + ), + args=args, + ) + + def _add_imports(self) -> None: + """Add named import of Resolver if this module uses named import of field.""" + if self._is_using_named_import: + AddImportsVisitor.add_needed_import( + self.context, + "strawberry", + "Resolver", + ) + + def _new_annotation_wrapped_in_resolver( + self, annotation: cst.Annotation + ) -> cst.Annotation: + """Wraps the annotation in a strawberry.Resolver[] type.""" + + if self._is_using_named_import: + resolver_type: Union[cst.Name, cst.Attribute] = cst.Name("Resolver") + else: + resolver_type = cst.Attribute( + value=cst.Name("strawberry"), + attr=cst.Name("Resolver"), + ) + + return annotation.with_changes( + annotation=cst.Subscript( + value=resolver_type, + slice=[ + cst.SubscriptElement( + slice=cst.Index(value=annotation.annotation), + ), + ], + ) + ) diff --git a/tests/codemods/test_resolver_type.py b/tests/codemods/test_resolver_type.py new file mode 100644 index 0000000000..ef95739e98 --- /dev/null +++ b/tests/codemods/test_resolver_type.py @@ -0,0 +1,60 @@ +from libcst.codemod import CodemodTest + +from strawberry.codemods.resolver_type import ConvertFieldWithResolverTypeAnnotations + + +class TestConvertConstantCommand(CodemodTest): + TRANSFORM = ConvertFieldWithResolverTypeAnnotations + + def test_update_annotation(self) -> None: + before = """ + class User: + name: str = strawberry.field(description="Name") + age: int = strawberry.field(resolver=get_user_age) + """ + + after = """ + class User: + name: str = strawberry.field(description="Name") + age: strawberry.Resolver[int] = strawberry.field(resolver=get_user_age) + """ + + self.assertCodemod(before, after) + + def test_update_annotation_named_import(self) -> None: + before = """ + from strawberry import field + + class User: + name: str = field(description="Name") + age: int = field(resolver=get_user_age) + """ + + after = """ + from strawberry import Resolver, field + + class User: + name: str = field(description="Name") + age: Resolver[int] = field(resolver=get_user_age) + """ + + self.assertCodemod(before, after) + + def test_noop_no_resolvers(self) -> None: + before = """ + from strawberry import field + + class User: + name: str = field(description="Name") + age: int = field(deprecation_reason="some") + """ + + after = """ + from strawberry import field + + class User: + name: str = field(description="Name") + age: int = field(deprecation_reason="some") + """ + + self.assertCodemod(before, after)