diff --git a/gopls/internal/golang/completion/completion.go b/gopls/internal/golang/completion/completion.go index 5e42b726403..8c8b0f1e0a9 100644 --- a/gopls/internal/golang/completion/completion.go +++ b/gopls/internal/golang/completion/completion.go @@ -339,6 +339,10 @@ type completionContext struct { // packageCompletion is true if we are completing a package name. packageCompletion bool + + // syntaxError is information about when the source code contains + // syntax errors. Only triggered if the completion candidate has period. + syntaxError syntaxErrorContext } // A Selection represents the cursor position and surrounding identifier. @@ -765,6 +769,10 @@ func (c *completer) containingIdent(src []byte) *ast.Ident { // is a keyword. This improves completion after an "accidental // keyword", e.g. completing to "variance" in "someFunc(var<>)". return fakeIdent + } else if tkn == token.IDENT { + // Use manually extracted token when the source contains + // syntax errors. This provides better developer experience. + return fakeIdent } return nil @@ -772,9 +780,13 @@ func (c *completer) containingIdent(src []byte) *ast.Ident { // scanToken scans pgh's contents for the token containing pos. func (c *completer) scanToken(contents []byte) (token.Pos, token.Token, string) { - tok := c.pkg.FileSet().File(c.pos) + var ( + lastLit string + prdPos token.Pos + s scanner.Scanner + ) - var s scanner.Scanner + tok := c.pkg.FileSet().File(c.pos) s.Init(tok, contents, nil, 0) for { tknPos, tkn, lit := s.Scan() @@ -782,6 +794,20 @@ func (c *completer) scanToken(contents []byte) (token.Pos, token.Token, string) return token.NoPos, token.ILLEGAL, "" } + if tkn == token.PERIOD { + prdPos = tknPos + // Save the last lit declared just before the period. + c.completionContext.syntaxError.lit = lastLit + } + // Set hasPeriod to true if cursor is: + // - Right after the period (e.g., "foo.<>"). + // - One or more characters after the period (e.g., "foo.b<>", "foo.bar<>"). + c.completionContext.syntaxError.hasPeriod = tknPos == prdPos || tknPos == prdPos+1 + + if len(lit) > 0 { + lastLit = lit + } + if len(lit) > 0 && tknPos <= c.pos && c.pos <= tknPos+token.Pos(len(lit)) { return tknPos, tkn, lit } @@ -1595,6 +1621,14 @@ func (c *completer) lexical(ctx context.Context) error { continue // Name was declared in some enclosing scope, or not at all. } + // Provide better completion suggestions when the source code contains syntax + // errors and hasPeriod is true. This helps to offer relevant completions despite + // the presence of syntax errors in the code. + if c.completionContext.syntaxError.hasPeriod { + c.syntaxErrorCompletion(obj) + continue + } + // If obj's type is invalid, find the AST node that defines the lexical block // containing the declaration of obj. Don't resolve types for packages. if !isPkgName(obj) && !typeIsValid(obj.Type()) { diff --git a/gopls/internal/golang/completion/syntax_error_completion.go b/gopls/internal/golang/completion/syntax_error_completion.go new file mode 100644 index 00000000000..73312ff7f72 --- /dev/null +++ b/gopls/internal/golang/completion/syntax_error_completion.go @@ -0,0 +1,59 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package completion + +import ( + "go/types" +) + +// syntaxErrorContext represents the context of the scenario when +// the source code contains syntax errors during code completion. +type syntaxErrorContext struct { + // hasPeriod is true if we are handling scenarios where the source + // contains syntax errors and the candidate includes the period. + hasPeriod bool + // lit is the literal value of the token that appeared before the period. + lit string +} + +// syntaxErrorCompletion provides better code completion when the source contains +// syntax errors and the candidate has periods. Only triggered if hasPeriod is true. +func (c *completer) syntaxErrorCompletion(obj types.Object) { + // Check if the object is equal to the literal before the period. + // If not, check for nested types (e.g., "foo.bar.baz<>"). + if obj.Name() != c.completionContext.syntaxError.lit { + c.nestedSynaxErrorCompletion(obj.Type()) + return + } + + switch obj := obj.(type) { + case *types.PkgName: + c.packageMembers(obj.Imported(), stdScore, nil, c.deepState.enqueue) + default: + c.methodsAndFields(obj.Type(), isVar(obj), nil, c.deepState.enqueue) + } +} + +// nestedSynaxErrorCompletion attempts to resolve code completion within nested types +// when the source contains syntax errors. It visits the types to find a match for the literal. +func (c *completer) nestedSynaxErrorCompletion(T types.Type) { + var visit func(T types.Type) + visit = func(T types.Type) { + switch t := T.Underlying().(type) { + case *types.Struct: + for i := 0; i < t.NumFields(); i++ { + field := t.Field(i) + if field.Name() == c.completionContext.syntaxError.lit { + c.methodsAndFields(field.Type(), isVar(field), nil, c.deepState.enqueue) + return + } + if t, ok := field.Type().Underlying().(*types.Struct); ok { + visit(t) + } + } + } + } + visit(T) +} diff --git a/gopls/internal/test/integration/completion/completion_test.go b/gopls/internal/test/integration/completion/completion_test.go index c96e569f1ad..9b1d054dc7d 100644 --- a/gopls/internal/test/integration/completion/completion_test.go +++ b/gopls/internal/test/integration/completion/completion_test.go @@ -694,6 +694,7 @@ func F3[K comparable, V any](map[K]V, chan V) {} } func TestPackageMemberCompletionAfterSyntaxError(t *testing.T) { + // Update: not broken anymore, fixed. // This test documents the current broken behavior due to golang/go#58833. const src = ` -- go.mod -- @@ -727,7 +728,7 @@ func main() { // (In VSCode, "Abs" wrongly appears in the completion menu.) // This is a consequence of poor error recovery in the parser // causing "math.Ldex" to become a BadExpr. - want := "package main\n\nimport \"math\"\n\nfunc main() {\n\tmath.Sqrt(,0)\n\tmath.Ldexmath.Abs(${1:})\n}\n" + want := "package main\n\nimport \"math\"\n\nfunc main() {\n\tmath.Sqrt(,0)\n\tmath.Ldexp(${1:})\n}\n" if diff := cmp.Diff(want, got); diff != "" { t.Errorf("unimported completion (-want +got):\n%s", diff) } @@ -1164,3 +1165,210 @@ func main() { } }) } + +func TestCompletionAfterSyntaxError(t *testing.T) { + const files = ` +-- go.mod -- +module mod.com + +go 1.14 + +-- test1.go -- +package main + +func test1() { + minimum := 0 + maximum := 0 + minimum, max +} + +-- test2.go -- +package main + +import "math" + +func test2() { + math.Sqrt(0), abs +} + +-- test3.go -- +package main + +import "math" + +func test3() { + math.Sqrt(0), math.ab +} + +-- test4.go -- +package main + +type person struct { + name string + age int +} + +func test4() { + p := person{} + p.name, age +} + +-- test5.go -- +package main + +type person struct { + name string + age int +} + +func test5() { + p := person{} + p.name, p.ag +} + +-- test6.go -- +package main + +import "math" + +func test6() { + math.Sqrt(,0) + abs +} + +-- test7.go -- +package main + +import "math" + +func test7() { + math.Sqrt(,0) + math.ab +} + +-- test8.go -- +package main + +func test8() { + minimum := 0 + fmt.Println("minimum:" min) +} + +-- test9.go -- +package main + +type person struct { + name string + age int +} + +func test9() { + p := person{} + fmt.Println("name:" p.na) +} + +-- test10.go -- +package main + +type Foo struct { + bar Bar + name string +} + +type Bar struct { + baz string +} + +func test10() { + f := Foo{} + f.name, f.bar.b +} +` + tests := []struct { + name string + file string + re string + want string + }{ + { + name: "test 1 variable completion after comma", + file: "test1.go", + re: ", max()", + want: "package main\n\nfunc test1() {\n\tminimum := 0\n\tmaximum := 0\n\tminimum, maximum\n}\n\n", + }, + { + name: "test 2 package member completion after comma", + file: "test2.go", + re: "abs()", + want: "package main\n\nimport \"math\"\n\nfunc test2() {\n\tmath.Sqrt(0), math.Abs(${1:})\n}\n\n", + }, + { + name: "test 3 package member completion after comma with period", + file: "test3.go", + re: "math.ab()", + want: "package main\n\nimport \"math\"\n\nfunc test3() {\n\tmath.Sqrt(0), math.Abs(${1:})\n}\n\n", + }, + { + name: "test 4 struct field completion after comma", + file: "test4.go", + re: ", age()", + want: "package main\n\ntype person struct {\n\tname string\n\tage int\n}\n\nfunc test4() {\n\tp := person{}\n\tp.name, p.age\n}\n\n", + }, + { + name: "test 5 struct field completion after comma with period", + file: "test5.go", + re: "p.ag()", + want: "package main\n\ntype person struct {\n\tname string\n\tage int\n}\n\nfunc test5() {\n\tp := person{}\n\tp.name, p.age\n}\n\n", + }, + { + name: "test 6 package member completion after BadExpr", + file: "test6.go", + re: "abs()", + want: "package main\n\nimport \"math\"\n\nfunc test6() {\n\tmath.Sqrt(,0)\n\tmath.Abs(${1:})\n}\n\n", + }, + { + name: "test 7 package member completion after BadExpr with period", + file: "test7.go", + re: "math.ab()", + want: "package main\n\nimport \"math\"\n\nfunc test7() {\n\tmath.Sqrt(,0)\n\tmath.Abs(${1:})\n}\n\n", + }, + { + name: "test 8 variable completion after missing comma", + file: "test8.go", + re: ":\" min()", + want: "package main\n\nfunc test8() {\n\tminimum := 0\n\tfmt.Println(\"minimum:\" minimum)\n}\n\n", + }, + { + name: "test 9 struct field completion after missing comma with period", + file: "test9.go", + re: "p.na()", + want: "package main\n\ntype person struct {\n\tname string\n\tage int\n}\n\nfunc test9() {\n\tp := person{}\n\tfmt.Println(\"name:\" p.name)\n}\n\n", + }, + { + name: "test 10 complex struct field completion after comma with period", + file: "test10.go", + re: "f.bar.b()", + want: "package main\n\ntype Foo struct {\n\tbar Bar\n\tname string\n}\n\ntype Bar struct {\n\tbaz string\n}\n\nfunc test10() {\n\tf := Foo{}\n\tf.name, f.bar.baz\n}\n", + }, + } + + Run(t, files, func(t *testing.T, env *Env) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + env.OpenFile(test.file) + env.Await(env.DoneWithOpen()) + loc := env.RegexpSearch(test.file, test.re) + completions := env.Completion(loc) + if len(completions.Items) == 0 { + t.Fatalf("no completion items") + } + env.AcceptCompletion(loc, completions.Items[0]) + env.Await(env.DoneWithChange()) + got := env.BufferText(test.file) + if diff := cmp.Diff(test.want, got); diff != "" { + t.Errorf("incorrect completion (-want +got):\n%s", diff) + } + }) + } + }) +}