From f002efaded87e176593ab646208355df653a2d4b Mon Sep 17 00:00:00 2001 From: xieyuschen Date: Wed, 11 Sep 2024 19:03:35 +0800 Subject: [PATCH] go/ssa: remove outdated doc in CreatePackage and add respective test cases It also removes the deprecated loader package and uses packages and txtar to load the package for testing. Now all packages are created under a go module. --- go/ssa/builder_generic_test.go | 91 ++--- go/ssa/builder_test.go | 371 +++++------------- go/ssa/create.go | 5 +- go/ssa/instantiate_test.go | 108 ++--- go/ssa/sanity.go | 14 + go/ssa/source_test.go | 126 +++--- go/ssa/ssa.go | 4 +- .../{objlookup.go => objlookup.txtar} | 5 +- .../{structconv.go => structconv.txtar} | 5 +- .../{valueforexpr.go => valueforexpr.txtar} | 6 +- go/ssa/testhelper_test.go | 146 +++++++ 11 files changed, 435 insertions(+), 446 deletions(-) rename go/ssa/testdata/{objlookup.go => objlookup.txtar} (98%) rename go/ssa/testdata/{structconv.go => structconv.txtar} (90%) rename go/ssa/testdata/{valueforexpr.go => valueforexpr.txtar} (98%) diff --git a/go/ssa/builder_generic_test.go b/go/ssa/builder_generic_test.go index 55dc79fe464..8f30ffbab7d 100644 --- a/go/ssa/builder_generic_test.go +++ b/go/ssa/builder_generic_test.go @@ -14,7 +14,6 @@ import ( "testing" "golang.org/x/tools/go/expect" - "golang.org/x/tools/go/loader" "golang.org/x/tools/go/ssa" ) @@ -93,22 +92,22 @@ func TestGenericBodies(t *testing.T) { func From() { type A [4]byte - print(a[A]) //@ types("func(x p03.A)") + print(a[A]) //@ types("func(x example.com.A)") type B *[4]byte - print(b[B]) //@ types("func(x p03.B)") + print(b[B]) //@ types("func(x example.com.B)") type C []byte - print(c[C]) //@ types("func(x p03.C)") + print(c[C]) //@ types("func(x example.com.C)") type D string - print(d[D]) //@ types("func(x p03.D)") + print(d[D]) //@ types("func(x example.com.D)") type E map[int]string - print(e[E]) //@ types("func(x p03.E)") + print(e[E]) //@ types("func(x example.com.E)") type F chan string - print(f[F]) //@ types("func(x p03.F)") + print(f[F]) //@ types("func(x example.com.F)") } `, ` @@ -122,7 +121,7 @@ func TestGenericBodies(t *testing.T) { func From() { type F chan string - print(f[string, F]) //@ types("func(x p05.F)") + print(f[string, F]) //@ types("func(x example.com.F)") } `, ` @@ -152,8 +151,8 @@ func TestGenericBodies(t *testing.T) { type F chan int c := make(F) quit := make(F) - print(start[F], c, quit) //@ types("func(c p06.F, quit p06.F)", "p06.F", "p06.F") - print(fibonacci[F], c, quit) //@ types("func(c p06.F, quit p06.F)", "p06.F", "p06.F") + print(start[F], c, quit) //@ types("func(c example.com.F, quit example.com.F)", "example.com.F", "example.com.F") + print(fibonacci[F], c, quit) //@ types("func(c example.com.F, quit example.com.F)", "example.com.F", "example.com.F") } `, ` @@ -165,7 +164,7 @@ func TestGenericBodies(t *testing.T) { } func From() { type S struct{ x int; y string } - print(f[S]) //@ types("func(i int) p07.S") + print(f[S]) //@ types("func(i int) example.com.S") } `, ` @@ -186,7 +185,7 @@ func TestGenericBodies(t *testing.T) { type H []int32 print(f[F](F{}, 0, 0)) //@ types("[]int8") print(g[G](nil, 0, 0)) //@ types("[]int16") - print(h[H](nil, 0, 0)) //@ types("p08.H") + print(h[H](nil, 0, 0)) //@ types("example.com.H") } `, ` @@ -322,18 +321,18 @@ func TestGenericBodies(t *testing.T) { type MyInterface interface{ foo() } // ChangeType tests - func ct0(x int) { v := MyInt(x); print(x, v) /*@ types(int, "p15.MyInt")*/ } + func ct0(x int) { v := MyInt(x); print(x, v) /*@ types(int, "example.com.MyInt")*/ } func ct1[T MyInt | Other, S int ](x S) { v := T(x); print(x, v) /*@ types(S, T)*/ } func ct2[T int, S MyInt | int ](x S) { v := T(x); print(x, v) /*@ types(S, T)*/ } func ct3[T MyInt | Other, S MyInt | int ](x S) { v := T(x) ; print(x, v) /*@ types(S, T)*/ } // Convert tests - func co0[T int | int8](x MyInt) { v := T(x); print(x, v) /*@ types("p15.MyInt", T)*/} - func co1[T int | int8](x T) { v := MyInt(x); print(x, v) /*@ types(T, "p15.MyInt")*/ } + func co0[T int | int8](x MyInt) { v := T(x); print(x, v) /*@ types("example.com.MyInt", T)*/} + func co1[T int | int8](x T) { v := MyInt(x); print(x, v) /*@ types(T, "example.com.MyInt")*/ } func co2[S, T int | int8](x T) { v := S(x); print(x, v) /*@ types(T, S)*/ } // MakeInterface tests - func mi0[T MyInterface](x T) { v := MyInterface(x); print(x, v) /*@ types(T, "p15.MyInterface")*/ } + func mi0[T MyInterface](x T) { v := MyInterface(x); print(x, v) /*@ types(T, "example.com.MyInterface")*/ } // NewConst tests func nc0[T any]() { v := (*T)(nil); print(v) /*@ types("*T")*/} @@ -427,9 +426,9 @@ func TestGenericBodies(t *testing.T) { Marker() }](v T) { v.Marker() - a := *(any(v).(*A)); print(a) /*@ types("p23.A")*/ - b := *(any(v).(*B)); print(b) /*@ types("p23.B")*/ - c := *(any(v).(*C)); print(c) /*@ types("p23.C")*/ + a := *(any(v).(*A)); print(a) /*@ types("example.com.A")*/ + b := *(any(v).(*B)); print(b) /*@ types("example.com.B")*/ + c := *(any(v).(*C)); print(c) /*@ types("example.com.C")*/ } `, ` @@ -519,32 +518,14 @@ func TestGenericBodies(t *testing.T) { contents := contents pkgname := packageName(t, contents) t.Run(pkgname, func(t *testing.T) { - // Parse - conf := loader.Config{ParserMode: parser.ParseComments} - f, err := conf.ParseFile("file.go", contents) - if err != nil { - t.Fatalf("parse: %v", err) - } - conf.CreateFromFiles(pkgname, f) - - // Load - lprog, err := conf.Load() - if err != nil { - t.Fatalf("Load: %v", err) - } + info := ssa.LoadPackageFromSingleFile(t, contents, ssa.SanityCheckFunctions) + p := info.SPkg + prog := p.Prog - // Create and build SSA - prog := ssa.NewProgram(lprog.Fset, ssa.SanityCheckFunctions) - for _, info := range lprog.AllPackages { - if info.TransitivelyErrorFree { - prog.CreatePackage(info.Pkg, info.Files, &info.Info, info.Importable) - } - } - p := prog.Package(lprog.Package(pkgname).Pkg) p.Build() // Collect all notes in f, i.e. comments starting with "//@ types". - notes, err := expect.ExtractGo(prog.Fset, f) + notes, err := expect.ExtractGo(prog.Fset, info.File) if err != nil { t.Errorf("expect.ExtractGo: %v", err) } @@ -754,33 +735,13 @@ func TestInstructionString(t *testing.T) { } ` - // Parse - conf := loader.Config{ParserMode: parser.ParseComments} - const fname = "p.go" - f, err := conf.ParseFile(fname, contents) - if err != nil { - t.Fatalf("parse: %v", err) - } - conf.CreateFromFiles("p", f) - - // Load - lprog, err := conf.Load() - if err != nil { - t.Fatalf("Load: %v", err) - } - - // Create and build SSA - prog := ssa.NewProgram(lprog.Fset, ssa.SanityCheckFunctions) - for _, info := range lprog.AllPackages { - if info.TransitivelyErrorFree { - prog.CreatePackage(info.Pkg, info.Files, &info.Info, info.Importable) - } - } - p := prog.Package(lprog.Package("p").Pkg) + info := ssa.LoadPackageFromSingleFile(t, contents, ssa.SanityCheckFunctions) + p := info.SPkg + prog := p.Prog p.Build() // Collect all notes in f, i.e. comments starting with "//@ instr". - notes, err := expect.ExtractGo(prog.Fset, f) + notes, err := expect.ExtractGo(prog.Fset, info.File) if err != nil { t.Errorf("expect.ExtractGo: %v", err) } diff --git a/go/ssa/builder_test.go b/go/ssa/builder_test.go index 6ef8a86d728..5782ba9cd4e 100644 --- a/go/ssa/builder_test.go +++ b/go/ssa/builder_test.go @@ -24,8 +24,6 @@ import ( "golang.org/x/sync/errgroup" "golang.org/x/tools/go/analysis/analysistest" - "golang.org/x/tools/go/buildutil" - "golang.org/x/tools/go/loader" "golang.org/x/tools/go/packages" "golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa/ssautil" @@ -176,6 +174,7 @@ func main() { func TestNoIndirectCreatePackage(t *testing.T) { testenv.NeedsGoBuild(t) // for go/packages + // ssa.LoadPackageFromSingleFile(t) dir := testfiles.ExtractTxtarFileToTmp(t, filepath.Join(analysistest.TestData(), "indirect.txtar")) pkgs, err := loadPackages(dir, "testdata/a") if err != nil { @@ -346,8 +345,8 @@ func TestInit(t *testing.T) { input, want string }{ {0, `package A; import _ "errors"; var i int = 42`, - `# Name: A.init -# Package: A + `# Name: example.com.init +# Package: example.com # Synthetic: package initializer func init(): 0: entry P:0 S:2 @@ -363,8 +362,8 @@ func init(): `}, {ssa.BareInits, `package B; import _ "errors"; var i int = 42`, - `# Name: B.init -# Package: B + `# Name: example.com.init +# Package: example.com # Synthetic: package initializer func init(): 0: entry P:0 S:0 @@ -374,23 +373,13 @@ func init(): `}, } for _, test := range tests { - // Create a single-file main package. - var conf loader.Config - f, err := conf.ParseFile("", test.input) - if err != nil { - t.Errorf("test %q: %s", test.input[:15], err) - continue - } - conf.CreateFromFiles(f.Name.Name, f) - lprog, err := conf.Load() - if err != nil { - t.Errorf("test 'package %s': Load: %s", f.Name.Name, err) - continue - } - prog := ssautil.CreateProgram(lprog, test.mode) - mainPkg := prog.Package(lprog.Created[0].Pkg) + info := ssa.LoadPackageFromSingleFile(t, test.input, test.mode) + mainPkg := info.SPkg + prog := mainPkg.Prog + f := info.File prog.Build() + initFunc := mainPkg.Func("init") if initFunc == nil { t.Errorf("test 'package %s': no init function", f.Name.Name) @@ -398,7 +387,7 @@ func init(): } var initbuf bytes.Buffer - _, err = initFunc.WriteTo(&initbuf) + _, err := initFunc.WriteTo(&initbuf) if err != nil { t.Errorf("test 'package %s': WriteTo: %s", f.Name.Name, err) continue @@ -445,44 +434,33 @@ var ( t interface{} = new(struct{*T}) ) ` - // Parse - var conf loader.Config - f, err := conf.ParseFile("", input) - if err != nil { - t.Fatalf("parse: %v", err) - } - conf.CreateFromFiles(f.Name.Name, f) - // Load - lprog, err := conf.Load() - if err != nil { - t.Fatalf("Load: %v", err) - } + info := ssa.LoadPackageFromSingleFile(t, input, ssa.SanityCheckFunctions) - // Create and build SSA - prog := ssautil.CreateProgram(lprog, ssa.BuilderMode(0)) - prog.Build() + p := info.SPkg + prog := p.Prog + p.Build() // Enumerate reachable synthetic functions want := map[string]string{ - "(*P.T).g$bound": "bound method wrapper for func (*P.T).g() int", - "(P.T).f$bound": "bound method wrapper for func (P.T).f() int", - - "(*P.T).g$thunk": "thunk for func (*P.T).g() int", - "(P.T).f$thunk": "thunk for func (P.T).f() int", - "(struct{*P.T}).g$thunk": "thunk for func (*P.T).g() int", - "(struct{P.T}).f$thunk": "thunk for func (P.T).f() int", - - "(*P.T).f": "wrapper for func (P.T).f() int", - "(*struct{*P.T}).f": "wrapper for func (P.T).f() int", - "(*struct{*P.T}).g": "wrapper for func (*P.T).g() int", - "(*struct{P.T}).f": "wrapper for func (P.T).f() int", - "(*struct{P.T}).g": "wrapper for func (*P.T).g() int", - "(struct{*P.T}).f": "wrapper for func (P.T).f() int", - "(struct{*P.T}).g": "wrapper for func (*P.T).g() int", - "(struct{P.T}).f": "wrapper for func (P.T).f() int", - - "P.init": "package initializer", + "(*example.com.T).g$bound": "bound method wrapper for func (*example.com.T).g() int", + "(example.com.T).f$bound": "bound method wrapper for func (example.com.T).f() int", + + "(*example.com.T).g$thunk": "thunk for func (*example.com.T).g() int", + "(example.com.T).f$thunk": "thunk for func (example.com.T).f() int", + "(struct{*example.com.T}).g$thunk": "thunk for func (*example.com.T).g() int", + "(struct{example.com.T}).f$thunk": "thunk for func (example.com.T).f() int", + + "(*example.com.T).f": "wrapper for func (example.com.T).f() int", + "(*struct{*example.com.T}).f": "wrapper for func (example.com.T).f() int", + "(*struct{*example.com.T}).g": "wrapper for func (*example.com.T).g() int", + "(*struct{example.com.T}).f": "wrapper for func (example.com.T).f() int", + "(*struct{example.com.T}).g": "wrapper for func (*example.com.T).g() int", + "(struct{*example.com.T}).f": "wrapper for func (example.com.T).f() int", + "(struct{*example.com.T}).g": "wrapper for func (*example.com.T).g() int", + "(struct{example.com.T}).f": "wrapper for func (example.com.T).f() int", + + "example.com.init": "package initializer", } var seen []string // may contain dups for fn := range ssautil.AllFunctions(prog) { @@ -556,23 +534,7 @@ func h(error) // t8 = phi [1: t7, 3: t4] #e // ... - // Parse - var conf loader.Config - f, err := conf.ParseFile("", input) - if err != nil { - t.Fatalf("parse: %v", err) - } - conf.CreateFromFiles("p", f) - - // Load - lprog, err := conf.Load() - if err != nil { - t.Fatalf("Load: %v", err) - } - - // Create and build SSA - prog := ssautil.CreateProgram(lprog, ssa.BuilderMode(0)) - p := prog.Package(lprog.Package("p").Pkg) + p := ssa.LoadPackageFromSingleFile(t, input, ssa.BuilderMode(0)).SPkg p.Build() g := p.Func("g") @@ -622,23 +584,7 @@ func LoadPointer(addr *unsafe.Pointer) (val unsafe.Pointer) // func init func() // var init$guard bool - // Parse - var conf loader.Config - f, err := conf.ParseFile("", input) - if err != nil { - t.Fatalf("parse: %v", err) - } - conf.CreateFromFiles("p", f) - - // Load - lprog, err := conf.Load() - if err != nil { - t.Fatalf("Load: %v", err) - } - - // Create and build SSA - prog := ssautil.CreateProgram(lprog, ssa.BuilderMode(0)) - p := prog.Package(lprog.Package("p").Pkg) + p := ssa.LoadPackageFromSingleFile(t, input, ssa.BuilderMode(0)).SPkg p.Build() if load := p.Func("Load"); load.Signature.TypeParams().Len() != 1 { @@ -675,24 +621,9 @@ var indirect = R[int].M // var thunk func(S[int]) int // var wrapper func(R[int]) int - // Parse - var conf loader.Config - f, err := conf.ParseFile("", input) - if err != nil { - t.Fatalf("parse: %v", err) - } - conf.CreateFromFiles("p", f) - - // Load - lprog, err := conf.Load() - if err != nil { - t.Fatalf("Load: %v", err) - } - for _, mode := range []ssa.BuilderMode{ssa.BuilderMode(0), ssa.InstantiateGenerics} { // Create and build SSA - prog := ssautil.CreateProgram(lprog, mode) - p := prog.Package(lprog.Package("p").Pkg) + p := ssa.LoadPackageFromSingleFile(t, input, mode).SPkg p.Build() for _, entry := range []struct { @@ -704,20 +635,20 @@ var indirect = R[int].M { "bound", "*func() int", - "(p.S[int]).M$bound", - "(p.S[int]).M[int]", + "(example.com.S[int]).M$bound", + "(example.com.S[int]).M[int]", }, { "thunk", - "*func(p.S[int]) int", - "(p.S[int]).M$thunk", - "(p.S[int]).M[int]", + "*func(example.com.S[int]) int", + "(example.com.S[int]).M$thunk", + "(example.com.S[int]).M[int]", }, { "indirect", - "*func(p.R[int]) int", - "(p.R[int]).M$thunk", - "(p.S[int]).M[int]", + "*func(example.com.R[int]) int", + "(example.com.R[int]).M$thunk", + "(example.com.S[int]).M[int]", }, } { entry := entry @@ -809,31 +740,23 @@ func TestTypeparamTest(t *testing.T) { t.Logf("Input: %s\n", input) - ctx := build.Default // copy - ctx.GOROOT = "testdata" // fake goroot. Makes tests ~1s. tests take ~80s. + pkgs, err := packages.Load(&packages.Config{ + Mode: packages.NeedSyntax | + packages.NeedTypesInfo | + packages.NeedDeps | + packages.NeedName | + packages.NeedFiles | + packages.NeedImports | + packages.NeedCompiledGoFiles | + packages.NeedTypes, + }, input) - reportErr := func(err error) { - t.Error(err) - } - conf := loader.Config{Build: &ctx, TypeChecker: types.Config{Error: reportErr}} - if _, err := conf.FromArgs([]string{input}, true); err != nil { - t.Fatalf("FromArgs(%s) failed: %s", input, err) - } - - iprog, err := conf.Load() - if iprog != nil { - for _, pkg := range iprog.Created { - for i, e := range pkg.Errors { - t.Errorf("Loading pkg %s error[%d]=%s", pkg, i, e) - } - } - } if err != nil { - t.Fatalf("conf.Load(%s) failed: %s", input, err) + t.Fatalf("fail to load pkgs from file %s", input) } mode := ssa.SanityCheckFunctions | ssa.InstantiateGenerics - prog := ssautil.CreateProgram(iprog, mode) + prog := ssa.CreateProgram(t, pkgs, mode) prog.Build() }) } @@ -856,23 +779,8 @@ func sliceMax(s []int) []int { return s[a():b():c()] } ` - // Parse - var conf loader.Config - f, err := conf.ParseFile("", input) - if err != nil { - t.Fatalf("parse: %v", err) - } - conf.CreateFromFiles("p", f) - - // Load - lprog, err := conf.Load() - if err != nil { - t.Fatalf("Load: %v", err) - } - // Create and build SSA - prog := ssautil.CreateProgram(lprog, ssa.BuilderMode(0)) - p := prog.Package(lprog.Package("p").Pkg) + p := ssa.LoadPackageFromSingleFile(t, input, ssa.BuilderMode(0)).SPkg p.Build() for _, item := range []struct { @@ -905,30 +813,39 @@ func sliceMax(s []int) []int { return s[a():b():c()] } // TestGenericFunctionSelector ensures generic functions from other packages can be selected. func TestGenericFunctionSelector(t *testing.T) { - pkgs := map[string]map[string]string{ - "main": {"m.go": `package main; import "a"; func main() { a.F[int](); a.G[int,string](); a.H(0) }`}, - "a": {"a.go": `package a; func F[T any](){}; func G[S, T any](){}; func H[T any](a T){} `}, - } + ar := ` +-- go.mod -- +module example.com +go 1.18 + +-- m.go -- +package main; import "example.com/a"; func main() { a.F[int](); a.G[int,string](); a.H(0) } + +-- a/a.go -- +package a; func F[T any](){}; func G[S, T any](){}; func H[T any](a T){} +` + + pkgs := ssa.PackagesFromArchive(t, ar) for _, mode := range []ssa.BuilderMode{ ssa.SanityCheckFunctions, ssa.SanityCheckFunctions | ssa.InstantiateGenerics, } { - conf := loader.Config{ - Build: buildutil.FakeContext(pkgs), - } - conf.Import("main") - lprog, err := conf.Load() - if err != nil { - t.Errorf("Load failed: %s", err) + // Create and build SSA + // todo: consider to refine it + prog := ssa.CreateProgram(t, pkgs, mode) + var tp *types.Package + for _, pkg := range pkgs { + if pkg.Name == "main" { + tp = pkg.Types + break + } } - if lprog == nil { - t.Fatalf("Load returned nil *Program") + if tp == nil { + t.Fatal("fail to get package main from loaded packages") } - // Create and build SSA - prog := ssautil.CreateProgram(lprog, mode) - p := prog.Package(lprog.Package("main").Pkg) + p := prog.Package(tp) p.Build() var callees []string // callees of the CallInstruction.String() in main(). @@ -945,7 +862,7 @@ func TestGenericFunctionSelector(t *testing.T) { } sort.Strings(callees) // ignore the order in the code. - want := "[a.F[int] a.G[int string] a.H[int]]" + want := "[example.com/a.F[int] example.com/a.G[int string] example.com/a.H[int]]" if got := fmt.Sprint(callees); got != want { t.Errorf("Expected main() to contain calls %v. got %v", want, got) } @@ -975,27 +892,16 @@ func TestIssue58491(t *testing.T) { } var Inst = foo[int] ` - fset := token.NewFileSet() - f, err := parser.ParseFile(fset, "p.go", src, 0) - if err != nil { - t.Fatal(err) - } - files := []*ast.File{f} - - pkg := types.NewPackage("p", "") - conf := &types.Config{} - p, _, err := ssautil.BuildPackage(conf, fset, pkg, files, ssa.SanityCheckFunctions|ssa.InstantiateGenerics) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + p := ssa.LoadPackageFromSingleFile(t, src, ssa.SanityCheckFunctions|ssa.InstantiateGenerics).SPkg + p.Build() // Find the local type result instantiated with int. var found bool for _, rt := range p.Prog.RuntimeTypes() { if n, ok := rt.(*types.Named); ok { if u, ok := n.Underlying().(*types.Struct); ok { found = true - if got, want := n.String(), "p.result"; got != want { + if got, want := n.String(), "example.com.result"; got != want { t.Errorf("Expected the name %s got: %s", want, got) } if got, want := u.String(), "struct{res int; error}"; got != want { @@ -1027,19 +933,8 @@ func TestIssue58491Rec(t *testing.T) { } var Inst = foo[int] ` - fset := token.NewFileSet() - f, err := parser.ParseFile(fset, "p.go", src, 0) - if err != nil { - t.Fatal(err) - } - files := []*ast.File{f} - - pkg := types.NewPackage("p", "") - conf := &types.Config{} - p, _, err := ssautil.BuildPackage(conf, fset, pkg, files, ssa.SanityCheckFunctions|ssa.InstantiateGenerics) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + p := ssa.LoadPackageFromSingleFile(t, src, ssa.SanityCheckFunctions|ssa.InstantiateGenerics).SPkg + p.Build() // Find the local type result instantiated with int. var found bool @@ -1047,10 +942,10 @@ func TestIssue58491Rec(t *testing.T) { if n, ok := aliases.Unalias(rt).(*types.Named); ok { if u, ok := n.Underlying().(*types.Struct); ok { found = true - if got, want := n.String(), "p.result"; got != want { + if got, want := n.String(), "example.com.result"; got != want { t.Errorf("Expected the name %s got: %s", want, got) } - if got, want := u.String(), "struct{res int; next *p.result; error}"; got != want { + if got, want := u.String(), "struct{res int; next *example.com.result; error}"; got != want { t.Errorf("Expected the underlying type of %s to be %s. got %s", n, want, got) } } @@ -1087,23 +982,9 @@ func TestSyntax(t *testing.T) { var _ = F[P] // unreferenced => not instantiated ` - // Parse - var conf loader.Config - f, err := conf.ParseFile("", input) - if err != nil { - t.Fatalf("parse: %v", err) - } - conf.CreateFromFiles("p", f) - - // Load - lprog, err := conf.Load() - if err != nil { - t.Fatalf("Load: %v", err) - } - - // Create and build SSA - prog := ssautil.CreateProgram(lprog, ssa.InstantiateGenerics) - prog.Build() + p := ssa.LoadPackageFromSingleFile(t, input, ssa.InstantiateGenerics).SPkg + prog := p.Prog + p.Build() // Collect syntax information for all of the functions. got := make(map[string]string) @@ -1119,15 +1000,15 @@ func TestSyntax(t *testing.T) { } want := map[string]string{ - "g": "*ast.FuncDecl : func() *p.P @ 4", + "g": "*ast.FuncDecl : func() *example.com.P @ 4", "F": "*ast.FuncDecl : func[T ~int]() *T @ 6", - "F$1": "*ast.FuncLit : func() p.S1 @ 10", - "F$1$1": "*ast.FuncLit : func() p.S2 @ 11", - "F$2": "*ast.FuncLit : func() p.S3 @ 16", + "F$1": "*ast.FuncLit : func() example.com.S1 @ 10", + "F$1$1": "*ast.FuncLit : func() example.com.S2 @ 11", + "F$2": "*ast.FuncLit : func() example.com.S3 @ 16", "F[int]": "*ast.FuncDecl : func() *int @ 6", - "F[int]$1": "*ast.FuncLit : func() p.S1 @ 10", - "F[int]$1$1": "*ast.FuncLit : func() p.S2 @ 11", - "F[int]$2": "*ast.FuncLit : func() p.S3 @ 16", + "F[int]$1": "*ast.FuncLit : func() example.com.S1 @ 10", + "F[int]$1$1": "*ast.FuncLit : func() example.com.S2 @ 11", + "F[int]$2": "*ast.FuncLit : func() example.com.S3 @ 16", // ...but no F[P] etc as they are unreferenced. // (NB: GlobalDebug mode would cause them to be referenced.) } @@ -1176,21 +1057,8 @@ func TestLabels(t *testing.T) { func main() { _:println(1); _:println(2)}`, } for _, test := range tests { - conf := loader.Config{Fset: token.NewFileSet()} - f, err := parser.ParseFile(conf.Fset, "", test, 0) - if err != nil { - t.Errorf("parse error: %s", err) - return - } - conf.CreateFromFiles("main", f) - iprog, err := conf.Load() - if err != nil { - t.Error(err) - continue - } - prog := ssautil.CreateProgram(iprog, ssa.BuilderMode(0)) - pkg := prog.Package(iprog.Created[0].Pkg) - pkg.Build() + p := ssa.LoadPackageFromSingleFile(t, test, 0).SPkg + p.Build() } } @@ -1222,22 +1090,10 @@ func TestFixedBugs(t *testing.T) { func TestIssue67079(t *testing.T) { // This test reproduced a race in the SSA builder nearly 100% of the time. - // Load the package. const src = `package p; type T int; func (T) f() {}; var _ = (*T).f` - conf := loader.Config{Fset: token.NewFileSet()} - f, err := parser.ParseFile(conf.Fset, "p.go", src, 0) - if err != nil { - t.Fatal(err) - } - conf.CreateFromFiles("p", f) - iprog, err := conf.Load() - if err != nil { - t.Fatal(err) - } - pkg := iprog.Created[0].Pkg - - // Create and build SSA program. - prog := ssautil.CreateProgram(iprog, ssa.BuilderMode(0)) + p := ssa.LoadPackageFromSingleFile(t, src, 0).SPkg + pkg := p.Pkg + prog := p.Prog prog.Build() var g errgroup.Group @@ -1278,7 +1134,7 @@ func TestGenericAliases(t *testing.T) { testenv.NeedsExec(t) testenv.NeedsTool(t, "go") - cmd := exec.Command(os.Args[0], "-test.run=TestGenericAliases") + cmd := exec.Command(os.Args[0], "-test.run=TestGenericAliases -v") cmd.Env = append(os.Environ(), "GENERICALIASTEST_CHILD=1", "GODEBUG=gotypesalias=1", @@ -1334,22 +1190,13 @@ func f[S any]() { } ` - conf := loader.Config{Fset: token.NewFileSet()} - f, err := parser.ParseFile(conf.Fset, "p.go", source, 0) - if err != nil { - t.Fatal(err) - } - conf.CreateFromFiles("p", f) - iprog, err := conf.Load() - if err != nil { - t.Fatal(err) - } - // Create and build SSA program. - prog := ssautil.CreateProgram(iprog, ssa.InstantiateGenerics) + p := ssa.LoadPackageFromSingleFile(t, source, 0).SPkg + prog := p.Prog prog.Build() probes := callsTo(ssautil.AllFunctions(prog), "print") + t.Log(probes) if got, want := len(probes), 3*4*2; got != want { t.Errorf("Found %v probes, expected %v", got, want) } diff --git a/go/ssa/create.go b/go/ssa/create.go index 423bce87182..eee3af2684c 100644 --- a/go/ssa/create.go +++ b/go/ssa/create.go @@ -193,14 +193,11 @@ func membersFromDecl(pkg *Package, decl ast.Decl, goversion string) { // // The real work of building SSA form for each function is not done // until a subsequent call to Package.Build. -// -// CreatePackage should not be called after building any package in -// the program. func (prog *Program) CreatePackage(pkg *types.Package, files []*ast.File, info *types.Info, importable bool) *Package { - // TODO(adonovan): assert that no package has yet been built. if pkg == nil { panic("nil pkg") // otherwise pkg.Scope below returns types.Universe! } + p := &Package{ Prog: prog, Members: make(map[string]Member), diff --git a/go/ssa/instantiate_test.go b/go/ssa/instantiate_test.go index fcf682c88a7..a29a79b4445 100644 --- a/go/ssa/instantiate_test.go +++ b/go/ssa/instantiate_test.go @@ -15,38 +15,68 @@ import ( "strings" "testing" - "golang.org/x/tools/go/loader" + "golang.org/x/tools/go/packages" ) -// loadProgram creates loader.Program out of p. -func loadProgram(p string) (*loader.Program, error) { - // Parse - var conf loader.Config - f, err := conf.ParseFile("", p) - if err != nil { - return nil, fmt.Errorf("parse: %v", err) - } - conf.CreateFromFiles("p", f) +func TestCreateNewPkgAfterBuild(t *testing.T) { - // Load - lprog, err := conf.Load() - if err != nil { - return nil, fmt.Errorf("Load: %v", err) - } - return lprog, nil + ar := ` +-- go.mod -- +module example.com +go 1.18 + +-- main.go -- +package p + +import "slices" + +func main(){ + ints := []int{1, 2, 3, 4, 5} + slices.Contains(ints, 1) } -// buildPackage builds and returns ssa representation of package pkg of lprog. -func buildPackage(lprog *loader.Program, pkg string, mode BuilderMode) *Package { - prog := NewProgram(lprog.Fset, mode) +-- sub/p2.go -- +package p2 + +import "slices" - for _, info := range lprog.AllPackages { - prog.CreatePackage(info.Pkg, info.Files, &info.Info, info.Importable) +func Entry(){ + numbers := []float32{1, 2, 3, 4, 5} + slices.Contains(numbers, 1) +} +` + pkgs := PackagesFromArchive(t, ar) + + var anotherPkg *packages.Package + for i, p := range pkgs { + if p.Name == "p2" { + anotherPkg = p + pkgs = append(pkgs[:i], pkgs[i+1:]...) + } + } + if anotherPkg == nil { + t.Fatal("cannot find package p2 in the loaded packages") } - p := prog.Package(lprog.Package(pkg).Pkg) - p.Build() - return p + mode := InstantiateGenerics + prog := CreateProgram(t, pkgs, mode) + prog.Build() + + npkg := prog.CreatePackage(anotherPkg.Types, anotherPkg.Syntax, anotherPkg.TypesInfo, true) + npkg.Build() + + var pkgSlices *Package + for _, pkg := range prog.AllPackages() { + if pkg.Pkg.Name() == "slices" { + pkgSlices = pkg + break + } + } + + instanceNum := len(allInstances(pkgSlices.Func("Contains"))) + if instanceNum != 2 { + t.Errorf("slices.Contains should have 2 instances but got %d", instanceNum) + } } // TestNeedsInstance ensures that new method instances can be created via needsInstance, @@ -74,14 +104,9 @@ func LoadPointer(addr *unsafe.Pointer) (val unsafe.Pointer) // func init func() // var init$guard bool - lprog, err := loadProgram(input) - if err != nil { - t.Fatal(err) - } - for _, mode := range []BuilderMode{BuilderMode(0), InstantiateGenerics} { // Create and build SSA - p := buildPackage(lprog, "p", mode) + p := LoadPackageFromSingleFile(t, input, mode).SPkg prog := p.Prog ptr := p.Type("Pointer").Type().(*types.Named) @@ -193,14 +218,10 @@ func entry(i int, a A) int { return Id[int](i) } ` - lprog, err := loadProgram(input) - if err != nil { - t.Fatal(err) - } - p := buildPackage(lprog, "p", SanityCheckFunctions) + p := LoadPackageFromSingleFile(t, input, SanityCheckFunctions).SPkg prog := p.Prog - + prog.Build() for _, ti := range []struct { orig string instance string @@ -209,9 +230,9 @@ func entry(i int, a A) int { chTypeInstrs int // number of ChangeType instructions in f's body }{ {"Id", "Id[int]", "[T]", "[int]", 2}, - {"Lambda", "Lambda[p.A]", "[T]", "[p.A]", 1}, + {"Lambda", "Lambda[example.com.A]", "[T]", "[example.com.A]", 1}, {"Make", "Make[int]", "[T]", "[int]", 0}, - {"NoOp", "NoOp[p.K[T]]", "[T]", "[p.K[T]]", 0}, + {"NoOp", "NoOp[example.com.K[T]]", "[T]", "[example.com.K[T]]", 0}, } { test := ti t.Run(test.instance, func(t *testing.T) { @@ -309,19 +330,16 @@ func Foo[T any, S any](t T, s S) { Foo[T, S](t, s) } ` - lprog, err := loadProgram(input) - if err != nil { - t.Fatal(err) - } - p := buildPackage(lprog, "p", SanityCheckFunctions) + p := LoadPackageFromSingleFile(t, input, SanityCheckFunctions).SPkg + p.Build() for _, test := range []struct { orig string instances string }{ - {"H", "[p.H[T] p.H[T]]"}, - {"Foo", "[p.Foo[S T] p.Foo[T S]]"}, + {"H", "[example.com.H[T] example.com.H[T]]"}, + {"Foo", "[example.com.Foo[S T] example.com.Foo[T S]]"}, } { t.Run(test.orig, func(t *testing.T) { f := p.Members[test.orig].(*Function) diff --git a/go/ssa/sanity.go b/go/ssa/sanity.go index 3d82e936518..2e8a952f381 100644 --- a/go/ssa/sanity.go +++ b/go/ssa/sanity.go @@ -609,6 +609,20 @@ func sanityCheckPackage(pkg *Package) { if pkg.Pkg == nil { panic(fmt.Sprintf("Package %s has no Object", pkg)) } + + if pkg.info != nil { + panic(fmt.Sprintf("package %s field 'info' is not cleared", pkg)) + } + if pkg.files != nil { + panic(fmt.Sprintf("package %s field 'files' is not cleared", pkg)) + } + if pkg.created != nil { + panic(fmt.Sprintf("package %s field 'created' is not cleared", pkg)) + } + if pkg.initVersion != nil { + panic(fmt.Sprintf("package %s field 'initVersion' is not cleared", pkg)) + } + _ = pkg.String() // must not crash for name, mem := range pkg.Members { diff --git a/go/ssa/source_test.go b/go/ssa/source_test.go index 112581bb55b..73311885cb1 100644 --- a/go/ssa/source_test.go +++ b/go/ssa/source_test.go @@ -20,9 +20,7 @@ import ( "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/expect" - "golang.org/x/tools/go/loader" "golang.org/x/tools/go/ssa" - "golang.org/x/tools/go/ssa/ssautil" ) func TestObjValueLookup(t *testing.T) { @@ -30,17 +28,33 @@ func TestObjValueLookup(t *testing.T) { t.Skipf("no testdata directory on %s", runtime.GOOS) } - conf := loader.Config{ParserMode: parser.ParseComments} - src, err := os.ReadFile("testdata/objlookup.go") + src, err := os.ReadFile("testdata/objlookup.txtar") if err != nil { t.Fatal(err) } - readFile := func(filename string) ([]byte, error) { return src, nil } - f, err := conf.ParseFile("testdata/objlookup.go", src) - if err != nil { - t.Fatal(err) + + pkgs := ssa.PackagesFromArchive(t, string(src)) + prog := ssa.CreateProgram(t, pkgs, ssa.BuilderMode(0)) + + info := ssa.GetPkgInfo(prog, pkgs, "main") + + if info == nil { + t.Fatalf("fail to get package main from loaded packages") + } + + ppkg := info.PPkg + f := info.File + mainPkg := info.SPkg + + readFile := func(_ string) ([]byte, error) { + // split the file content to get the exact file content, + // instead of using go/printer which re-formats the file and the positions are no longer exact + strs := strings.SplitAfter(string(src), "-- objlookup.go --\n") + if len(strs) != 2 { + t.Fatalf("expect to get 2 parts after splitting but got %d", len(strs)) + } + return []byte(strs[1]), nil } - conf.CreateFromFiles("main", f) // Maps each var Ident (represented "name:linenum") to the // kind of ssa.Value we expect (represented "Constant", "&Alloc"). @@ -49,54 +63,45 @@ func TestObjValueLookup(t *testing.T) { // Each note of the form @ssa(x, "BinOp") in testdata/objlookup.go // specifies an expectation that an object named x declared on the // same line is associated with an ssa.Value of type *ssa.BinOp. - notes, err := expect.ExtractGo(conf.Fset, f) + notes, err := expect.ExtractGo(ppkg.Fset, f) if err != nil { t.Fatal(err) } for _, n := range notes { if n.Name != "ssa" { - t.Errorf("%v: unexpected note type %q, want \"ssa\"", conf.Fset.Position(n.Pos), n.Name) + t.Errorf("%v: unexpected note type %q, want \"ssa\"", ppkg.Fset.Position(n.Pos), n.Name) continue } if len(n.Args) != 2 { - t.Errorf("%v: ssa has %d args, want 2", conf.Fset.Position(n.Pos), len(n.Args)) + t.Errorf("%v: ssa has %d args, want 2", ppkg.Fset.Position(n.Pos), len(n.Args)) continue } ident, ok := n.Args[0].(expect.Identifier) if !ok { - t.Errorf("%v: got %v for arg 1, want identifier", conf.Fset.Position(n.Pos), n.Args[0]) + t.Errorf("%v: got %v for arg 1, want identifier", ppkg.Fset.Position(n.Pos), n.Args[0]) continue } exp, ok := n.Args[1].(string) if !ok { - t.Errorf("%v: got %v for arg 2, want string", conf.Fset.Position(n.Pos), n.Args[1]) + t.Errorf("%v: got %v for arg 2, want string", ppkg.Fset.Position(n.Pos), n.Args[1]) continue } - p, _, err := expect.MatchBefore(conf.Fset, readFile, n.Pos, string(ident)) + p, _, err := expect.MatchBefore(ppkg.Fset, readFile, n.Pos, string(ident)) if err != nil { t.Error(err) continue } - pos := conf.Fset.Position(p) + pos := ppkg.Fset.Position(p) key := fmt.Sprintf("%s:%d", ident, pos.Line) expectations[key] = exp } - iprog, err := conf.Load() - if err != nil { - t.Error(err) - return - } - - prog := ssautil.CreateProgram(iprog, ssa.BuilderMode(0) /*|ssa.PrintFunctions*/) - mainInfo := iprog.Created[0] - mainPkg := prog.Package(mainInfo.Pkg) mainPkg.SetDebugMode(true) mainPkg.Build() var varIds []*ast.Ident var varObjs []*types.Var - for id, obj := range mainInfo.Defs { + for id, obj := range ppkg.TypesInfo.Defs { // Check invariants for func and const objects. switch obj := obj.(type) { case *types.Func: @@ -113,7 +118,7 @@ func TestObjValueLookup(t *testing.T) { varObjs = append(varObjs, obj) } } - for id, obj := range mainInfo.Uses { + for id, obj := range ppkg.TypesInfo.Uses { if obj, ok := obj.(*types.Var); ok { varIds = append(varIds, id) varObjs = append(varObjs, obj) @@ -125,7 +130,7 @@ func TestObjValueLookup(t *testing.T) { for i, id := range varIds { obj := varObjs[i] ref, _ := astutil.PathEnclosingInterval(f, id.Pos(), id.Pos()) - pos := prog.Fset.Position(id.Pos()) + pos := ppkg.Fset.Position(id.Pos()) exp := expectations[fmt.Sprintf("%s:%d", id.Name, pos.Line)] if exp == "" { t.Errorf("%s: no expectation for var ident %s ", pos, id.Name) @@ -222,11 +227,11 @@ func checkVarValue(t *testing.T, prog *ssa.Program, pkg *ssa.Package, ref []ast. // Ensure that, in debug mode, we can determine the ssa.Value // corresponding to every ast.Expr. func TestValueForExpr(t *testing.T) { - testValueForExpr(t, "testdata/valueforexpr.go") + testValueForExpr(t, "testdata/valueforexpr.txtar") } func TestValueForExprStructConv(t *testing.T) { - testValueForExpr(t, "testdata/structconv.go") + testValueForExpr(t, "testdata/structconv.txtar") } func testValueForExpr(t *testing.T, testfile string) { @@ -234,24 +239,24 @@ func testValueForExpr(t *testing.T, testfile string) { t.Skipf("no testdata dir on %s", runtime.GOOS) } - conf := loader.Config{ParserMode: parser.ParseComments} - f, err := conf.ParseFile(testfile, nil) + src, err := os.ReadFile(testfile) if err != nil { - t.Error(err) - return + t.Fatal(err) } - conf.CreateFromFiles("main", f) - iprog, err := conf.Load() - if err != nil { - t.Error(err) - return + pkgs := ssa.PackagesFromArchive(t, string(src)) + prog := ssa.CreateProgram(t, pkgs, ssa.BuilderMode(0)) + + info := ssa.GetPkgInfo(prog, pkgs, "main") + + if info == nil { + t.Fatalf("fail to get package main from loaded packages") } - mainInfo := iprog.Created[0] + ppkg := info.PPkg + f := info.File + mainPkg := info.SPkg - prog := ssautil.CreateProgram(iprog, ssa.BuilderMode(0)) - mainPkg := prog.Package(mainInfo.Pkg) mainPkg.SetDebugMode(true) mainPkg.Build() @@ -318,8 +323,8 @@ func testValueForExpr(t *testing.T, testfile string) { if gotAddr { T = T.Underlying().(*types.Pointer).Elem() // deref } - if !types.Identical(T, mainInfo.TypeOf(e)) { - t.Errorf("%s: got type %s, want %s", position, mainInfo.TypeOf(e), T) + if !types.Identical(T, ppkg.TypesInfo.TypeOf(e)) { + t.Errorf("%s: got type %s, want %s", position, ppkg.TypesInfo.TypeOf(e), T) } } } @@ -357,46 +362,46 @@ func TestEnclosingFunction(t *testing.T) { // Ordinary function: {`package main func f() { println(1003) }`, - "100", "main.f"}, + "100", "example.com.f"}, // Methods: {`package main type T int func (t T) f() { println(200) }`, - "200", "(main.T).f"}, + "200", "(example.com.T).f"}, // Function literal: {`package main func f() { println(func() { print(300) }) }`, - "300", "main.f$1"}, + "300", "example.com.f$1"}, // Doubly nested {`package main func f() { println(func() { print(func() { print(350) })})}`, - "350", "main.f$1$1"}, + "350", "example.com.f$1$1"}, // Implicit init for package-level var initializer. - {"package main; var a = 400", "400", "main.init"}, + {"package main; var a = 400", "400", "example.com.init"}, // No code for constants: {"package main; const a = 500", "500", "(none)"}, // Explicit init() - {"package main; func init() { println(600) }", "600", "main.init#1"}, + {"package main; func init() { println(600) }", "600", "example.com.init#1"}, // Multiple explicit init functions: {`package main func init() { println("foo") } func init() { println(800) }`, - "800", "main.init#2"}, + "800", "example.com.init#2"}, // init() containing FuncLit. {`package main func init() { println(func(){print(900)}) }`, - "900", "main.init#1$1"}, + "900", "example.com.init#1$1"}, // generics {`package main type S[T any] struct{} func (*S[T]) Foo() { println(1000) } type P[T any] struct{ *S[T] }`, - "1000", "(*main.S[T]).Foo", + "1000", "(*example.com.S[T]).Foo", }, } for _, test := range tests { - conf := loader.Config{Fset: token.NewFileSet()} - f, start, end := findInterval(t, conf.Fset, test.input, test.substr) + fset := token.NewFileSet() + f, start, end := findInterval(t, fset, test.input, test.substr) if f == nil { continue } @@ -406,15 +411,8 @@ func TestEnclosingFunction(t *testing.T) { continue } - conf.CreateFromFiles("main", f) - - iprog, err := conf.Load() - if err != nil { - t.Error(err) - continue - } - prog := ssautil.CreateProgram(iprog, ssa.BuilderMode(0)) - pkg := prog.Package(iprog.Created[0].Pkg) + info := ssa.LoadPackageFromSingleFile(t, test.input, ssa.BuilderMode(0)) + pkg := info.SPkg pkg.Build() name := "(none)" diff --git a/go/ssa/ssa.go b/go/ssa/ssa.go index df673e2fc99..e0fbbe6638f 100644 --- a/go/ssa/ssa.go +++ b/go/ssa/ssa.go @@ -66,7 +66,7 @@ type Package struct { // The following fields are set transiently, then cleared // after building. - buildOnce sync.Once // ensures package building occurs once + buildOnce sync.Once // ensures package building occurs once so it won't be cleared after building ninit int32 // number of init functions info *types.Info // package type information files []*ast.File // package ASTs @@ -342,7 +342,7 @@ type Function struct { // source information Synthetic string // provenance of synthetic function; "" for true source functions syntax ast.Node // *ast.Func{Decl,Lit}, if from syntax (incl. generic instances) or (*ast.RangeStmt if a yield function) - info *types.Info // type annotations (iff syntax != nil) + info *types.Info // type annotations (if syntax != nil) goversion string // Go version of syntax (NB: init is special) parent *Function // enclosing function if anon; nil if global diff --git a/go/ssa/testdata/objlookup.go b/go/ssa/testdata/objlookup.txtar similarity index 98% rename from go/ssa/testdata/objlookup.go rename to go/ssa/testdata/objlookup.txtar index b040d747333..9aee4e38001 100644 --- a/go/ssa/testdata/objlookup.go +++ b/go/ssa/testdata/objlookup.txtar @@ -1,5 +1,8 @@ -// +build ignore +-- go.mod -- +module example.com +go 1.18 +-- objlookup.go -- package main // This file is the input to TestObjValueLookup in source_test.go, diff --git a/go/ssa/testdata/structconv.go b/go/ssa/testdata/structconv.txtar similarity index 90% rename from go/ssa/testdata/structconv.go rename to go/ssa/testdata/structconv.txtar index c0b4b840ee5..11522769f55 100644 --- a/go/ssa/testdata/structconv.go +++ b/go/ssa/testdata/structconv.txtar @@ -1,5 +1,8 @@ -// +build ignore +-- go.mod -- +module example.com +go 1.18 +-- structconv.go -- // This file is the input to TestValueForExprStructConv in identical_test.go, // which uses the same framework as TestValueForExpr does in source_test.go. // diff --git a/go/ssa/testdata/valueforexpr.go b/go/ssa/testdata/valueforexpr.txtar similarity index 98% rename from go/ssa/testdata/valueforexpr.go rename to go/ssa/testdata/valueforexpr.txtar index 703c316a707..bf8543e0041 100644 --- a/go/ssa/testdata/valueforexpr.go +++ b/go/ssa/testdata/valueforexpr.txtar @@ -1,6 +1,8 @@ -//go:build ignore -// +build ignore +-- go.mod -- +module example.com +go 1.18 +-- valueforexpr.go -- package main // This file is the input to TestValueForExpr in source_test.go, which diff --git a/go/ssa/testhelper_test.go b/go/ssa/testhelper_test.go index 8d08bbb757c..5d251489504 100644 --- a/go/ssa/testhelper_test.go +++ b/go/ssa/testhelper_test.go @@ -4,7 +4,153 @@ package ssa +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "go/types" + "testing" + + "golang.org/x/tools/go/packages" + "golang.org/x/tools/internal/testfiles" + "golang.org/x/tools/txtar" +) + // SetNormalizeAnyForTesting is exported here for external tests. func SetNormalizeAnyForTesting(normalize bool) { normalizeAnyForTesting = normalize } + +// ArchiveFromSingleFileContent creates a go module named example.com +// in txtar format and put the given content in main.go file under the module. +// The package is decided by the package clause in the content. +// The content should contain no error as a typical go file. +// +// It's useful when we want to define a package in a string variable instead of putting it inside a file. +func ArchiveFromSingleFileContent(content string) string { + return fmt.Sprintf(` +-- go.mod -- +module example.com +go 1.18 + +-- main.go -- +%s`, content) +} + +// PackagesFromArchive creates the packages from archive with txtar format. +func PackagesFromArchive(t *testing.T, archive string) []*packages.Package { + ar := txtar.Parse([]byte(archive)) + + fs, err := txtar.FS(ar) + if err != nil { + t.Fatal(err) + } + + dir := testfiles.CopyToTmp(t, fs) + if err != nil { + t.Fatal(err) + } + + var baseConfig = &packages.Config{ + Mode: packages.NeedSyntax | + packages.NeedTypesInfo | + packages.NeedDeps | + packages.NeedName | + packages.NeedFiles | + packages.NeedImports | + packages.NeedCompiledGoFiles | + packages.NeedTypes, + Dir: dir, + } + pkgs, err := packages.Load(baseConfig, "./...") + if err != nil { + t.Fatal(err) + } + if num := packages.PrintErrors(pkgs); num > 0 { + t.Fatalf("packages contained %d errors", num) + } + return pkgs +} + +// CreateProgram creates a program with given initial packages for testing, +// usually the packages are constructed via PackagesFromArchive. +func CreateProgram(t *testing.T, initial []*packages.Package, mode BuilderMode) *Program { + var fset *token.FileSet + if len(initial) > 0 { + fset = initial[0].Fset + } + + prog := NewProgram(fset, mode) + + isInitial := make(map[*packages.Package]bool, len(initial)) + for _, p := range initial { + isInitial[p] = true + } + + packages.Visit(initial, nil, func(p *packages.Package) { + if p.Types != nil && !p.IllTyped { + var files []*ast.File + var info *types.Info + if isInitial[p] { + files = p.Syntax + info = p.TypesInfo + } + prog.CreatePackage(p.Types, files, info, true) + return + } + + t.Fatalf("package %s or its any dependency contains errors", p.Name) + }) + + return prog +} + +// PkgInfo is a ssa package with its packages.Package and ast file. +// We assume the package in test only have one file. +type PkgInfo struct { + SPkg *Package // ssa representation of a package + PPkg *packages.Package // packages representation of a package + File *ast.File // the ast file of the first package file +} + +// GetPkgInfo retrieves the package info from the program with the given name. +// It's useful when you loaded a package from file instead of defining it directly as a string. +func GetPkgInfo(prog *Program, pkgs []*packages.Package, pkgname string) *PkgInfo { + for _, pkg := range pkgs { + if pkg.Name == pkgname { + return &PkgInfo{ + SPkg: prog.Package(pkg.Types), + PPkg: pkg, + File: pkg.Syntax[0], // we assume the test package has one file + } + } + } + return nil +} + +// LoadPackageFromSingleFile is a utility function to creates a package based on the content of a go file, +// and returns the PkgInfo about the input go file. The package name is retrieved from content after parsing. +// It's useful when you want to create a ssa package and its packages.Package and ast.File representation. +func LoadPackageFromSingleFile(t *testing.T, content string, mode BuilderMode) *PkgInfo { + ar := ArchiveFromSingleFileContent(content) + pkgs := PackagesFromArchive(t, ar) + prog := CreateProgram(t, pkgs, mode) + + pkgName := packageName(t, content) + pkgInfo := GetPkgInfo(prog, pkgs, pkgName) + if pkgInfo == nil { + t.Fatalf("fail to get package %s from loaded packages", pkgName) + } + return pkgInfo +} + +// packageName is a test helper to extract the package name from a string +// containing the content of a go file. +func packageName(t testing.TB, content string) string { + f, err := parser.ParseFile(token.NewFileSet(), "", content, parser.PackageClauseOnly) + if err != nil { + t.Fatalf("parsing the file %q failed with error: %s", content, err) + } + return f.Name.Name +}