diff --git a/gopls/internal/golang/codeaction.go b/gopls/internal/golang/codeaction.go index 99bfa3d75cf..acc746128fc 100644 --- a/gopls/internal/golang/codeaction.go +++ b/gopls/internal/golang/codeaction.go @@ -85,8 +85,13 @@ func CodeActions(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, } } + pkg, pgf, err := NarrowestPackageForFile(ctx, snapshot, fh.URI()) + if err != nil { + return nil, err + } + if want[protocol.RefactorExtract] { - extractions, err := getExtractCodeActions(pgf, rng, snapshot.Options()) + extractions, err := getExtractCodeActions(pkg, pgf, rng, snapshot.Options()) if err != nil { return nil, err } @@ -198,20 +203,18 @@ func fixedByImportFix(fix *imports.ImportFix, diagnostics []protocol.Diagnostic) } // getExtractCodeActions returns any refactor.extract code actions for the selection. -func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *settings.Options) ([]protocol.CodeAction, error) { - if rng.Start == rng.End { - return nil, nil - } - +func getExtractCodeActions(pkg *cache.Package, pgf *parsego.File, rng protocol.Range, options *settings.Options) ([]protocol.CodeAction, error) { start, end, err := pgf.RangePos(rng) if err != nil { return nil, err } + puri := pgf.URI var commands []protocol.Command - if _, ok, methodOk, _ := CanExtractFunction(pgf.Tok, start, end, pgf.Src, pgf.File); ok { - cmd, err := command.NewApplyFixCommand("Extract function", command.ApplyFixArgs{ - Fix: fixExtractFunction, + + if _, _, ok, _ := CanExtractInterface(pkg, start, end, pgf.File); ok { + cmd, err := command.NewApplyFixCommand("Extract interface", command.ApplyFixArgs{ + Fix: fixExtractInterface, URI: puri, Range: rng, ResolveEdits: supportsResolveEdits(options), @@ -220,9 +223,12 @@ func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *setti return nil, err } commands = append(commands, cmd) - if methodOk { - cmd, err := command.NewApplyFixCommand("Extract method", command.ApplyFixArgs{ - Fix: fixExtractMethod, + } + + if rng.Start != rng.End { + if _, ok, methodOk, _ := CanExtractFunction(pgf.Tok, start, end, pgf.Src, pgf.File); ok { + cmd, err := command.NewApplyFixCommand("Extract function", command.ApplyFixArgs{ + Fix: fixExtractFunction, URI: puri, Range: rng, ResolveEdits: supportsResolveEdits(options), @@ -231,20 +237,33 @@ func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *setti return nil, err } commands = append(commands, cmd) + if methodOk { + cmd, err := command.NewApplyFixCommand("Extract method", command.ApplyFixArgs{ + Fix: fixExtractMethod, + URI: puri, + Range: rng, + ResolveEdits: supportsResolveEdits(options), + }) + if err != nil { + return nil, err + } + commands = append(commands, cmd) + } } - } - if _, _, ok, _ := CanExtractVariable(start, end, pgf.File); ok { - cmd, err := command.NewApplyFixCommand("Extract variable", command.ApplyFixArgs{ - Fix: fixExtractVariable, - URI: puri, - Range: rng, - ResolveEdits: supportsResolveEdits(options), - }) - if err != nil { - return nil, err + if _, _, ok, _ := CanExtractVariable(start, end, pgf.File); ok { + cmd, err := command.NewApplyFixCommand("Extract variable", command.ApplyFixArgs{ + Fix: fixExtractVariable, + URI: puri, + Range: rng, + ResolveEdits: supportsResolveEdits(options), + }) + if err != nil { + return nil, err + } + commands = append(commands, cmd) } - commands = append(commands, cmd) } + var actions []protocol.CodeAction for i := range commands { actions = append(actions, newCodeAction(commands[i].Title, protocol.RefactorExtract, &commands[i], nil, options)) diff --git a/gopls/internal/golang/extract.go b/gopls/internal/golang/extract.go index c07faec1b7a..46ac3b29a33 100644 --- a/gopls/internal/golang/extract.go +++ b/gopls/internal/golang/extract.go @@ -18,6 +18,7 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/ast/astutil" + "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/util/bug" "golang.org/x/tools/gopls/internal/util/safetoken" "golang.org/x/tools/internal/analysisinternal" @@ -127,6 +128,39 @@ func CanExtractVariable(start, end token.Pos, file *ast.File) (ast.Expr, []ast.N return nil, nil, false, fmt.Errorf("cannot extract an %T to a variable", expr) } +// CanExtractInterface reports whether the code in the given position is for a +// type which can be represented as an interface. +func CanExtractInterface(pkg *cache.Package, start, end token.Pos, file *ast.File) (ast.Expr, []ast.Node, bool, error) { + path, _ := astutil.PathEnclosingInterval(file, start, end) + if len(path) == 0 { + return nil, nil, false, fmt.Errorf("no path enclosing interval") + } + + node := path[0] + expr, ok := node.(ast.Expr) + if !ok { + return nil, nil, false, fmt.Errorf("node is not an expression") + } + + switch e := expr.(type) { + case *ast.Ident: + o, ok := pkg.TypesInfo().ObjectOf(e).(*types.TypeName) + if !ok { + return nil, nil, false, fmt.Errorf("cannot extract a %T to an interface", expr) + } + + if _, ok := o.Type().(*types.Basic); ok { + return nil, nil, false, fmt.Errorf("cannot extract a basic type to an interface") + } + + return expr, path, true, nil + case *ast.StarExpr, *ast.SelectorExpr: + return expr, path, true, nil + default: + return nil, nil, false, fmt.Errorf("cannot extract a %T to an interface", expr) + } +} + // Calculate indentation for insertion. // When inserting lines of code, we must ensure that the lines have consistent // formatting (i.e. the proper indentation). To do so, we observe the indentation on the diff --git a/gopls/internal/golang/fix.go b/gopls/internal/golang/fix.go index 2215da9b65e..a2d7748983f 100644 --- a/gopls/internal/golang/fix.go +++ b/gopls/internal/golang/fix.go @@ -5,13 +5,17 @@ package golang import ( + "bytes" "context" + "errors" "fmt" "go/ast" "go/token" "go/types" + "slices" "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/gopls/internal/analysis/embeddirective" "golang.org/x/tools/gopls/internal/analysis/fillstruct" "golang.org/x/tools/gopls/internal/analysis/stubmethods" @@ -22,6 +26,7 @@ import ( "golang.org/x/tools/gopls/internal/file" "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/util/bug" + "golang.org/x/tools/gopls/internal/util/safetoken" "golang.org/x/tools/internal/imports" ) @@ -61,6 +66,7 @@ func singleFile(fixer1 singleFileFixer) fixer { const ( fixExtractVariable = "extract_variable" fixExtractFunction = "extract_function" + fixExtractInterface = "extract_interface" fixExtractMethod = "extract_method" fixInlineCall = "inline_call" fixInvertIfCondition = "invert_if_condition" @@ -112,6 +118,7 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file // Ad-hoc fixers: these are used when the command is // constructed directly by logic in server/code_action. + fixExtractInterface: extractInterface, fixExtractFunction: singleFile(extractFunction), fixExtractMethod: singleFile(extractMethod), fixExtractVariable: singleFile(extractVariable), @@ -142,6 +149,140 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file return suggestedFixToEdits(ctx, snapshot, fixFset, suggestion) } +func extractInterface(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (*token.FileSet, *analysis.SuggestedFix, error) { + path, _ := astutil.PathEnclosingInterval(pgf.File, start, end) + + var field *ast.Field + var decl ast.Decl + for _, node := range path { + if f, ok := node.(*ast.Field); ok { + field = f + continue + } + + // Record the node that starts the declaration of the type that contains + // the field we are creating the interface for. + if d, ok := node.(ast.Decl); ok { + decl = d + break // we have both the field and the declaration + } + } + + if field == nil || decl == nil { + return nil, nil, nil + } + + p := safetoken.StartPosition(pkg.FileSet(), field.Pos()) + pos := protocol.Position{ + Line: uint32(p.Line - 1), // Line is zero-based + Character: uint32(p.Column - 1), // Character is zero-based + } + + fh, err := snapshot.ReadFile(ctx, pgf.URI) + if err != nil { + return nil, nil, err + } + + refs, err := references(ctx, snapshot, fh, pos, false) + if err != nil { + return nil, nil, err + } + + type method struct { + signature *types.Signature + name string + } + + var methods []method + for _, ref := range refs { + locPkg, locPgf, err := NarrowestPackageForFile(ctx, snapshot, ref.location.URI) + if err != nil { + return nil, nil, err + } + + _, end, err := locPgf.RangePos(ref.location.Range) + if err != nil { + return nil, nil, err + } + + // We are interested in the method call, so we need the node after the dot + rangeEnd := end + token.Pos(len(".")) + path, _ := astutil.PathEnclosingInterval(locPgf.File, rangeEnd, rangeEnd) + id, ok := path[0].(*ast.Ident) + if !ok { + continue + } + + obj := locPkg.TypesInfo().ObjectOf(id) + if obj == nil { + continue + } + + sig, ok := obj.Type().(*types.Signature) + if !ok { + return nil, nil, errors.New("cannot extract interface with non-method accesses") + } + + fc := method{signature: sig, name: obj.Name()} + if !slices.Contains(methods, fc) { + methods = append(methods, fc) + } + } + + interfaceName := "I" + pkg.TypesInfo().ObjectOf(field.Names[0]).Name() + var buf bytes.Buffer + buf.WriteString("\ntype ") + buf.WriteString(interfaceName) + buf.WriteString(" interface {\n") + for _, fc := range methods { + buf.WriteString("\t") + buf.WriteString(fc.name) + types.WriteSignature(&buf, fc.signature, relativeTo(pkg.Types())) + buf.WriteByte('\n') + } + buf.WriteByte('}') + buf.WriteByte('\n') + + interfacePos := decl.Pos() - 1 + // Move the interface above the documentation comment if the type declaration + // includes one. + switch d := decl.(type) { + case *ast.GenDecl: + if d.Doc != nil { + interfacePos = d.Doc.Pos() - 1 + } + case *ast.FuncDecl: + if d.Doc != nil { + interfacePos = d.Doc.Pos() - 1 + } + } + + return pkg.FileSet(), &analysis.SuggestedFix{ + Message: "Extract interface", + TextEdits: []analysis.TextEdit{{ + Pos: interfacePos, + End: interfacePos, + NewText: buf.Bytes(), + }, { + Pos: field.Type.Pos(), + End: field.Type.End(), + NewText: []byte(interfaceName), + }}, + }, nil +} + +func relativeTo(pkg *types.Package) types.Qualifier { + if pkg == nil { + return nil + } + return func(other *types.Package) string { + if pkg == other { + return "" // same package; unqualified + } + return other.Name() + } +} + // suggestedFixToEdits converts the suggestion's edits from analysis form into protocol form. func suggestedFixToEdits(ctx context.Context, snapshot *cache.Snapshot, fset *token.FileSet, suggestion *analysis.SuggestedFix) ([]protocol.TextDocumentEdit, error) { editsPerFile := map[protocol.DocumentURI]*protocol.TextDocumentEdit{} diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_interface.txt b/gopls/internal/test/marker/testdata/codeaction/extract_interface.txt new file mode 100644 index 00000000000..485c1bbde7e --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/extract_interface.txt @@ -0,0 +1,82 @@ +This test checks the behavior of the 'extract interface' code action. +See extract_interface_resolve.txt for the same test with resolve support. + +-- flags -- +-ignore_extra_diags + +-- go.mod -- +module golang.org/lsptests/extract + +go 1.18 + +-- b/b.go -- +package b + +type BFoo struct {} + +func (b BFoo) Bar() string { + return "" +} + +func (b BFoo) Baz() int { + return 0 +} + +-- a.go -- +package extract + +import ( + "golang.org/lsptests/extract/b" +) + +// foo doc comment +type foo struct { + fieldOne bar //@codeactionedit("bar", "refactor.extract", a1) + fieldTwo b.BFoo //@codeactionedit("BFoo", "refactor.extract", a2) +} + +type bar struct {} + +func (b bar) baz() error {} +func (b bar) qux(a string, b int, c func() string) {} + +func (f foo) quux() { + f.fieldTwo.Bar() + f.fieldOne.baz() +} + +func (f foo) corge() { + f.fieldOne.qux("someString", 3, func() string { return "" }) +} + +func FuncThatUsesBar(b *bar) { //@codeactionedit("bar", "refactor.extract", a3) + b.qux() +} + +-- @a1/a.go -- +@@ -7 +7,5 @@ ++type IfieldOne interface { ++ baz() error ++ qux(a string, b int, c func() string) ++} ++ +@@ -9 +14 @@ +- fieldOne bar //@codeactionedit("bar", "refactor.extract", a1) ++ fieldOne IfieldOne //@codeactionedit("bar", "refactor.extract", a1) +-- @a2/a.go -- +@@ -7 +7,4 @@ ++type IfieldTwo interface { ++ Bar() string ++} ++ +@@ -10 +14 @@ +- fieldTwo b.BFoo //@codeactionedit("BFoo", "refactor.extract", a2) ++ fieldTwo IfieldTwo //@codeactionedit("BFoo", "refactor.extract", a2) +-- @a3/a.go -- +@@ -27 +27,5 @@ +-func FuncThatUsesBar(b *bar) { //@codeactionedit("bar", "refactor.extract", a3) ++type Ib interface { ++ qux(a string, b int, c func() string) ++} ++ ++func FuncThatUsesBar(b Ib) { //@codeactionedit("bar", "refactor.extract", a3) diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_interface_resolve.txt b/gopls/internal/test/marker/testdata/codeaction/extract_interface_resolve.txt new file mode 100644 index 00000000000..d4ee4099138 --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/extract_interface_resolve.txt @@ -0,0 +1,81 @@ +This test checks the behavior of the 'extract interface' code action. +See extract_interface_resolve.txt for the same test with resolve support. + +-- capabilities.json -- +{ + "textDocument": { + "codeAction": { + "dataSupport": true, + "resolveSupport": { + "properties": ["edit"] + } + } + } +} +-- flags -- +-ignore_extra_diags + +-- go.mod -- +module golang.org/lsptests/extract + +go 1.18 + +-- b/b.go -- +package b + +type BFoo struct {} + +func (b BFoo) Bar() string { + return "" +} + +func (b BFoo) Baz() int { + return 0 +} + +-- a.go -- +package extract + +import ( + "golang.org/lsptests/extract/b" +) + +// foo doc comment +type foo struct { + fieldOne bar //@codeactionedit("bar", "refactor.extract", a1) + fieldTwo b.BFoo //@codeactionedit("BFoo", "refactor.extract", a2) +} + +type bar struct {} + +func (b bar) baz() error {} +func (b bar) qux(a string, b int, c func() string) {} + +func (f foo) quux() { + f.fieldTwo.Bar() + f.fieldOne.baz() +} + +func (f foo) corge() { + f.fieldOne.qux("someString", 3, func() string { return "" }) +} + +-- @a1/a.go -- +@@ -7 +7,5 @@ ++type IfieldOne interface { ++ baz() error ++ qux(a string, b int, c func() string) ++} ++ +@@ -9 +14 @@ +- fieldOne bar //@codeactionedit("bar", "refactor.extract", a1) ++ fieldOne IfieldOne //@codeactionedit("bar", "refactor.extract", a1) +-- @a2/a.go -- +@@ -7 +7,4 @@ ++type IfieldTwo interface { ++ Bar() string ++} ++ +@@ -10 +14 @@ +- fieldTwo b.BFoo //@codeactionedit("BFoo", "refactor.extract", a2) ++ fieldTwo IfieldTwo //@codeactionedit("BFoo", "refactor.extract", a2)