Skip to content

Commit

Permalink
gopls/internal/golang: add extract interface code action
Browse files Browse the repository at this point in the history
  • Loading branch information
martskins committed Feb 16, 2024
1 parent 7240af8 commit 3786889
Show file tree
Hide file tree
Showing 5 changed files with 380 additions and 23 deletions.
65 changes: 42 additions & 23 deletions gopls/internal/golang/codeaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,13 @@ func CodeActions(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle,
}
}

pkg, pgf, err := NarrowestPackageForFile(ctx, snapshot, fh.URI())
if err != nil {
return nil, err
}

if want[protocol.RefactorExtract] {
extractions, err := getExtractCodeActions(pgf, rng, snapshot.Options())
extractions, err := getExtractCodeActions(pkg, pgf, rng, snapshot.Options())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -179,20 +184,18 @@ func fixedByImportFix(fix *imports.ImportFix, diagnostics []protocol.Diagnostic)
}

// getExtractCodeActions returns any refactor.extract code actions for the selection.
func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *settings.Options) ([]protocol.CodeAction, error) {
if rng.Start == rng.End {
return nil, nil
}

func getExtractCodeActions(pkg *cache.Package, pgf *parsego.File, rng protocol.Range, options *settings.Options) ([]protocol.CodeAction, error) {
start, end, err := pgf.RangePos(rng)
if err != nil {
return nil, err
}

puri := pgf.URI
var commands []protocol.Command
if _, ok, methodOk, _ := CanExtractFunction(pgf.Tok, start, end, pgf.Src, pgf.File); ok {
cmd, err := command.NewApplyFixCommand("Extract function", command.ApplyFixArgs{
Fix: fixExtractFunction,

if _, _, ok, _ := CanExtractInterface(pkg, start, end, pgf.File); ok {
cmd, err := command.NewApplyFixCommand("Extract interface", command.ApplyFixArgs{
Fix: fixExtractInterface,
URI: puri,
Range: rng,
ResolveEdits: supportsResolveEdits(options),
Expand All @@ -201,9 +204,12 @@ func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *setti
return nil, err
}
commands = append(commands, cmd)
if methodOk {
cmd, err := command.NewApplyFixCommand("Extract method", command.ApplyFixArgs{
Fix: fixExtractMethod,
}

if rng.Start != rng.End {
if _, ok, methodOk, _ := CanExtractFunction(pgf.Tok, start, end, pgf.Src, pgf.File); ok {
cmd, err := command.NewApplyFixCommand("Extract function", command.ApplyFixArgs{
Fix: fixExtractFunction,
URI: puri,
Range: rng,
ResolveEdits: supportsResolveEdits(options),
Expand All @@ -212,20 +218,33 @@ func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *setti
return nil, err
}
commands = append(commands, cmd)
if methodOk {
cmd, err := command.NewApplyFixCommand("Extract method", command.ApplyFixArgs{
Fix: fixExtractMethod,
URI: puri,
Range: rng,
ResolveEdits: supportsResolveEdits(options),
})
if err != nil {
return nil, err
}
commands = append(commands, cmd)
}
}
}
if _, _, ok, _ := CanExtractVariable(start, end, pgf.File); ok {
cmd, err := command.NewApplyFixCommand("Extract variable", command.ApplyFixArgs{
Fix: fixExtractVariable,
URI: puri,
Range: rng,
ResolveEdits: supportsResolveEdits(options),
})
if err != nil {
return nil, err
if _, _, ok, _ := CanExtractVariable(start, end, pgf.File); ok {
cmd, err := command.NewApplyFixCommand("Extract variable", command.ApplyFixArgs{
Fix: fixExtractVariable,
URI: puri,
Range: rng,
ResolveEdits: supportsResolveEdits(options),
})
if err != nil {
return nil, err
}
commands = append(commands, cmd)
}
commands = append(commands, cmd)
}

var actions []protocol.CodeAction
for i := range commands {
actions = append(actions, newCodeAction(commands[i].Title, protocol.RefactorExtract, &commands[i], nil, options))
Expand Down
34 changes: 34 additions & 0 deletions gopls/internal/golang/extract.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/gopls/internal/cache"
"golang.org/x/tools/gopls/internal/util/bug"
"golang.org/x/tools/gopls/internal/util/safetoken"
"golang.org/x/tools/internal/analysisinternal"
Expand Down Expand Up @@ -127,6 +128,39 @@ func CanExtractVariable(start, end token.Pos, file *ast.File) (ast.Expr, []ast.N
return nil, nil, false, fmt.Errorf("cannot extract an %T to a variable", expr)
}

// CanExtractInterface reports whether the code in the given position is for a
// type which can be represented as an interface.
func CanExtractInterface(pkg *cache.Package, start, end token.Pos, file *ast.File) (ast.Expr, []ast.Node, bool, error) {
path, _ := astutil.PathEnclosingInterval(file, start, end)
if len(path) == 0 {
return nil, nil, false, fmt.Errorf("no path enclosing interval")
}

node := path[0]
expr, ok := node.(ast.Expr)
if !ok {
return nil, nil, false, fmt.Errorf("node is not an expression")
}

switch e := expr.(type) {
case *ast.Ident:
o, ok := pkg.GetTypesInfo().ObjectOf(e).(*types.TypeName)
if !ok {
return nil, nil, false, fmt.Errorf("cannot extract a %T to an interface", expr)
}

if _, ok := o.Type().(*types.Basic); ok {
return nil, nil, false, fmt.Errorf("cannot extract a basic type to an interface")
}

return expr, path, true, nil
case *ast.StarExpr, *ast.SelectorExpr:
return expr, path, true, nil
default:
return nil, nil, false, fmt.Errorf("cannot extract a %T to an interface", expr)
}
}

// Calculate indentation for insertion.
// When inserting lines of code, we must ensure that the lines have consistent
// formatting (i.e. the proper indentation). To do so, we observe the indentation on the
Expand Down
141 changes: 141 additions & 0 deletions gopls/internal/golang/fix.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
package golang

import (
"bytes"
"context"
"errors"
"fmt"
"go/ast"
"go/token"
"go/types"
"slices"

"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/gopls/internal/analysis/embeddirective"
"golang.org/x/tools/gopls/internal/analysis/fillstruct"
"golang.org/x/tools/gopls/internal/analysis/stubmethods"
Expand All @@ -22,6 +26,7 @@ import (
"golang.org/x/tools/gopls/internal/file"
"golang.org/x/tools/gopls/internal/protocol"
"golang.org/x/tools/gopls/internal/util/bug"
"golang.org/x/tools/gopls/internal/util/safetoken"
"golang.org/x/tools/internal/imports"
)

Expand Down Expand Up @@ -61,6 +66,7 @@ func singleFile(fixer1 singleFileFixer) fixer {
const (
fixExtractVariable = "extract_variable"
fixExtractFunction = "extract_function"
fixExtractInterface = "extract_interface"
fixExtractMethod = "extract_method"
fixInlineCall = "inline_call"
fixInvertIfCondition = "invert_if_condition"
Expand Down Expand Up @@ -110,6 +116,7 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file

// Ad-hoc fixers: these are used when the command is
// constructed directly by logic in server/code_action.
fixExtractInterface: extractInterface,
fixExtractFunction: singleFile(extractFunction),
fixExtractMethod: singleFile(extractMethod),
fixExtractVariable: singleFile(extractVariable),
Expand Down Expand Up @@ -138,6 +145,140 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file
return suggestedFixToEdits(ctx, snapshot, fixFset, suggestion)
}

func extractInterface(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (*token.FileSet, *analysis.SuggestedFix, error) {
path, _ := astutil.PathEnclosingInterval(pgf.File, start, end)

var field *ast.Field
var decl ast.Decl
for _, node := range path {
if f, ok := node.(*ast.Field); ok {
field = f
continue
}

// Record the node that starts the declaration of the type that contains
// the field we are creating the interface for.
if d, ok := node.(ast.Decl); ok {
decl = d
break // we have both the field and the declaration
}
}

if field == nil || decl == nil {
return nil, nil, nil
}

p := safetoken.StartPosition(pkg.FileSet(), field.Pos())
pos := protocol.Position{
Line: uint32(p.Line - 1), // Line is zero-based
Character: uint32(p.Column - 1), // Character is zero-based
}

fh, err := snapshot.ReadFile(ctx, pgf.URI)
if err != nil {
return nil, nil, err
}

refs, err := references(ctx, snapshot, fh, pos, false)
if err != nil {
return nil, nil, err
}

type method struct {
signature *types.Signature
name string
}

var methods []method
for _, ref := range refs {
locPkg, locPgf, err := NarrowestPackageForFile(ctx, snapshot, ref.location.URI)
if err != nil {
return nil, nil, err
}

_, end, err := locPgf.RangePos(ref.location.Range)
if err != nil {
return nil, nil, err
}

// We are interested in the method call, so we need the node after the dot
rangeEnd := end + token.Pos(len("."))
path, _ := astutil.PathEnclosingInterval(locPgf.File, rangeEnd, rangeEnd)
id, ok := path[0].(*ast.Ident)
if !ok {
continue
}

obj := locPkg.GetTypesInfo().ObjectOf(id)
if obj == nil {
continue
}

sig, ok := obj.Type().(*types.Signature)
if !ok {
return nil, nil, errors.New("cannot extract interface with non-method accesses")
}

fc := method{signature: sig, name: obj.Name()}
if !slices.Contains(methods, fc) {
methods = append(methods, fc)
}
}

interfaceName := "I" + pkg.GetTypesInfo().ObjectOf(field.Names[0]).Name()
var buf bytes.Buffer
buf.WriteString("\ntype ")
buf.WriteString(interfaceName)
buf.WriteString(" interface {\n")
for _, fc := range methods {
buf.WriteString("\t")
buf.WriteString(fc.name)
types.WriteSignature(&buf, fc.signature, relativeTo(pkg.GetTypes()))
buf.WriteByte('\n')
}
buf.WriteByte('}')
buf.WriteByte('\n')

interfacePos := decl.Pos() - 1
// Move the interface above the documentation comment if the type declaration
// includes one.
switch d := decl.(type) {
case *ast.GenDecl:
if d.Doc != nil {
interfacePos = d.Doc.Pos() - 1
}
case *ast.FuncDecl:
if d.Doc != nil {
interfacePos = d.Doc.Pos() - 1
}
}

return pkg.FileSet(), &analysis.SuggestedFix{
Message: "Extract interface",
TextEdits: []analysis.TextEdit{{
Pos: interfacePos,
End: interfacePos,
NewText: buf.Bytes(),
}, {
Pos: field.Type.Pos(),
End: field.Type.End(),
NewText: []byte(interfaceName),
}},
}, nil
}

func relativeTo(pkg *types.Package) types.Qualifier {
if pkg == nil {
return nil
}
return func(other *types.Package) string {
if pkg == other {
return "" // same package; unqualified
}
return other.Name()
}
}

// suggestedFixToEdits converts the suggestion's edits from analysis form into protocol form.
func suggestedFixToEdits(ctx context.Context, snapshot *cache.Snapshot, fset *token.FileSet, suggestion *analysis.SuggestedFix) ([]protocol.TextDocumentEdit, error) {
editsPerFile := map[protocol.DocumentURI]*protocol.TextDocumentEdit{}
Expand Down
Loading

0 comments on commit 3786889

Please sign in to comment.