diff --git a/interface.go b/interface.go index deaccbd..b64dd0d 100644 --- a/interface.go +++ b/interface.go @@ -2,12 +2,11 @@ package interfaces import ( "errors" - "fmt" "go/types" "sort" "unicode" - "golang.org/x/tools/go/loader" + "golang.org/x/tools/go/packages" ) // Interface represents a typed interface. @@ -65,43 +64,34 @@ func (i Interface) Deps() []string { } func buildInterface(opts *Options) (Interface, error) { - cfg := &loader.Config{ - AllowErrors: true, - Build: opts.context(), - ImportPkgs: map[string]bool{opts.Query.Package: true}, - TypeCheckFuncBodies: func(string) bool { return false }, - } - cfg.ImportWithTests(opts.Query.Package) - prog, err := cfg.Load() + var err error + + // If a requested type is defined in an external test package try to + // build the interface using it before returning an error. Therefore, + // set the `Tests` flag to true so that any test packages are also + // loaded and checked below. + cfg := &packages.Config{ + Mode: packages.NeedTypes | packages.NeedTypesInfo, + Tests: true, + } + + pkgs, err := packages.Load(cfg, opts.Query.Package) if err != nil { return nil, err } - pkg, ok := prog.Imported[opts.Query.Package] - if !ok { - return nil, fmt.Errorf("parsing successful, but package %q not found", - opts.Query.Package) - } - i, err := buildInterfaceForPkg(pkg, opts) - if err == nil { - return i, nil - } - // If a requested type is defined in an external test package try to - // build the interface using it before returning an error. - queryCopy := *opts.Query - queryCopy.Package += "_test" - optsCopy := *opts - optsCopy.Query = &queryCopy - for _, pkg := range prog.Created { - if pkg.Pkg.Path() == optsCopy.Query.Package { - return buildInterfaceForPkg(pkg, &optsCopy) + + for _, pkg := range pkgs { + i, err := buildInterfaceForPkg(pkg, opts) + if err == nil { + return i, nil } } return nil, err } -func buildInterfaceForPkg(pkg *loader.PackageInfo, opts *Options) (Interface, error) { +func buildInterfaceForPkg(pkg *packages.Package, opts *Options) (Interface, error) { var typ *types.Named - for _, obj := range pkg.Defs { + for _, obj := range pkg.TypesInfo.Defs { if obj == nil { continue }