Skip to content

Commit

Permalink
Add codemod for resolver type
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaspiatkowski committed Feb 15, 2024
1 parent 1560cd3 commit 42c21a7
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 16 deletions.
35 changes: 19 additions & 16 deletions strawberry/cli/commands/upgrade/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,31 @@
import glob
import pathlib # noqa: TCH003
import sys
from enum import Enum
from typing import List
from typing_extensions import assert_never

import rich
import typer
from libcst.codemod import CodemodContext
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand

from strawberry.cli.app import app
from strawberry.codemods.annotated_unions import ConvertUnionToAnnotatedUnion
from strawberry.codemods.resolver_type import ConvertFieldWithResolverTypeAnnotations

from ._run_codemod import run_codemod

codemods = {
"annotated-union": ConvertUnionToAnnotatedUnion,
}

class Codemod(Enum):
ANNOTATED_UNION = "annotated-union"
RESOLVER_TYPE = "resolver-type"


# TODO: add support for running all of them
@app.command(help="Upgrades a Strawberry project to the latest version")
def upgrade(
codemod: str = typer.Argument(
codemod: Codemod = typer.Argument(
...,
autocompletion=lambda: list(codemods.keys()),
help="Name of the upgrade to run",
),
paths: List[pathlib.Path] = typer.Argument(..., file_okay=True, dir_okay=True),
Expand All @@ -39,18 +42,18 @@ def upgrade(
help="Use typing_extensions instead of typing for newer features",
),
) -> None:
if codemod not in codemods:
rich.print(f'[red]Upgrade named "{codemod}" does not exist')

raise typer.Exit(2)

python_target_version = tuple(int(x) for x in python_target.split("."))

transformer = ConvertUnionToAnnotatedUnion(
CodemodContext(),
use_pipe_syntax=python_target_version >= (3, 10),
use_typing_extensions=use_typing_extensions,
)
if codemod == Codemod.ANNOTATED_UNION:
transformer: VisitorBasedCodemodCommand = ConvertUnionToAnnotatedUnion(
CodemodContext(),
use_pipe_syntax=python_target_version >= (3, 10),
use_typing_extensions=use_typing_extensions,
)
elif codemod == Codemod.RESOLVER_TYPE:
transformer = ConvertFieldWithResolverTypeAnnotations(CodemodContext())
else:
return assert_never(codemod)

files: list[str] = []

Expand Down
139 changes: 139 additions & 0 deletions strawberry/codemods/resolver_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from __future__ import annotations

from typing import List, Optional, Union

import libcst as cst
import libcst.matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor


class ConvertFieldWithResolverTypeAnnotations(VisitorBasedCodemodCommand):
DESCRIPTION: str = (
"Converts field: T = strawberry.field(resolver=...) to "
"field: strawberry.Resolver[T] = strawberry.field(resolver=...)"
)

def __init__(
self,
context: CodemodContext,
) -> None:
self._is_using_named_import = False

super().__init__(context)

def visit_Module(self, node: cst.Module) -> Optional[bool]:
self._is_using_named_import = False

return super().visit_Module(node)

@m.visit(
m.ImportFrom(
m.Name("strawberry"),
[
m.ZeroOrMore(),
m.ImportAlias(m.Name("field")),
m.ZeroOrMore(),
],
)
)
def visit_import_from(self, original_node: cst.ImportFrom) -> None:
self._is_using_named_import = True

def leave_ClassDef(
self, original_node: cst.ClassDef, updated_node: cst.ClassDef
) -> cst.ClassDef:
new_body = []

for node in updated_node.body.body:
if not m.matches(
node,
m.SimpleStatementLine(
body=[
m.AnnAssign(value=self._field_with_resolver_call_matcher()),
]
),
):
new_body.append(node)
continue

self._add_imports()

node_stmt = cst.ensure_type(node, cst.SimpleStatementLine)
ann_assign = cst.ensure_type(node_stmt.body[0], cst.AnnAssign)

new_annotation = self._new_annotation_wrapped_in_resolver(
ann_assign.annotation
)

new_body.append(
node_stmt.with_changes(
body=[
ann_assign.with_changes(
annotation=new_annotation,
)
]
)
)

return updated_node.with_changes(
body=updated_node.body.with_changes(body=new_body)
)

def _field_with_resolver_call_matcher(self) -> m.Call:
"""Matches a call to strawberry.field with a resolver argument."""

args: List[Union[m.ArgMatchType, m.AtLeastN[m.DoNotCareSentinel]]] = [
m.ZeroOrMore(),
m.Arg(
keyword=m.Name("resolver"),
),
m.ZeroOrMore(),
]

if self._is_using_named_import:
return m.Call(
func=m.Name("field"),
args=args,
)

return m.Call(
func=m.Attribute(
value=m.Name("strawberry"),
attr=m.Name("field"),
),
args=args,
)

def _add_imports(self) -> None:
"""Add named import of Resolver if this module uses named import of field."""
if self._is_using_named_import:
AddImportsVisitor.add_needed_import(
self.context,
"strawberry",
"Resolver",
)

def _new_annotation_wrapped_in_resolver(
self, annotation: cst.Annotation
) -> cst.Annotation:
"""Wraps the annotation in a strawberry.Resolver[] type."""

if self._is_using_named_import:
resolver_type: Union[cst.Name, cst.Attribute] = cst.Name("Resolver")
else:
resolver_type = cst.Attribute(
value=cst.Name("strawberry"),
attr=cst.Name("Resolver"),
)

return annotation.with_changes(
annotation=cst.Subscript(
value=resolver_type,
slice=[
cst.SubscriptElement(
slice=cst.Index(value=annotation.annotation),
),
],
)
)
60 changes: 60 additions & 0 deletions tests/codemods/test_resolver_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from libcst.codemod import CodemodTest

from strawberry.codemods.resolver_type import ConvertFieldWithResolverTypeAnnotations


class TestConvertConstantCommand(CodemodTest):
TRANSFORM = ConvertFieldWithResolverTypeAnnotations

def test_update_annotation(self) -> None:
before = """
class User:
name: str = strawberry.field(description="Name")
age: int = strawberry.field(resolver=get_user_age)
"""

after = """
class User:
name: str = strawberry.field(description="Name")
age: strawberry.Resolver[int] = strawberry.field(resolver=get_user_age)
"""

self.assertCodemod(before, after)

def test_update_annotation_named_import(self) -> None:
before = """
from strawberry import field
class User:
name: str = field(description="Name")
age: int = field(resolver=get_user_age)
"""

after = """
from strawberry import Resolver, field
class User:
name: str = field(description="Name")
age: Resolver[int] = field(resolver=get_user_age)
"""

self.assertCodemod(before, after)

def test_noop_no_resolvers(self) -> None:
before = """
from strawberry import field
class User:
name: str = field(description="Name")
age: int = field(deprecation_reason="some")
"""

after = """
from strawberry import field
class User:
name: str = field(description="Name")
age: int = field(deprecation_reason="some")
"""

self.assertCodemod(before, after)

0 comments on commit 42c21a7

Please sign in to comment.