From 707b9140e5210ed012b3ffb679790ca439e2e7b3 Mon Sep 17 00:00:00 2001 From: Steve Graham Date: Fri, 22 Sep 2023 18:11:45 -0700 Subject: [PATCH] Fixes #44: Fix problem generating an interface in same package as impl Modify generator to detect if the dependency matches the new interface's package name and if so, skip adding the dependency (so skip import). Also, when generating the parameter package names, if the package matches the new interface's package then make the package blank. --- cmd/interfacer/main.go | 21 +++++++++++++-------- interface.go | 8 +++++--- options.go | 7 ++++--- type.go | 16 ++++++++++++---- 4 files changed, 34 insertions(+), 18 deletions(-) diff --git a/cmd/interfacer/main.go b/cmd/interfacer/main.go index e8e714d..7c9acac 100644 --- a/cmd/interfacer/main.go +++ b/cmd/interfacer/main.go @@ -58,13 +58,22 @@ func run() error { if *output == "" { return errors.New("empty -o flag value; see -help for details") } + packageName := "" + interfaceName := "" + if i := strings.IndexRune(*as, '.'); i != -1 { + packageName = (*as)[:i] + interfaceName = (*as)[i+1:] + } else { + interfaceName = *as + } q, err := interfaces.ParseQuery(*query) if err != nil { return err } opts := &interfaces.Options{ - Query: q, - Unexported: *all, + Query: q, + Unexported: *all, + PackageName: packageName, } i, err := interfaces.NewWithOptions(opts) if err != nil { @@ -75,12 +84,8 @@ func run() error { Deps: i.Deps(), Interface: i, } - if i := strings.IndexRune(*as, '.'); i != -1 { - v.PackageName = (*as)[:i] - v.InterfaceName = (*as)[i+1:] - } else { - v.InterfaceName = *as - } + v.PackageName = packageName + v.InterfaceName = interfaceName var buf bytes.Buffer if err := tmpl.Execute(&buf, v); err != nil { return err diff --git a/interface.go b/interface.go index deaccbd..f0ec055 100644 --- a/interface.go +++ b/interface.go @@ -58,7 +58,9 @@ func (i Interface) Deps() []string { } deps := make([]string, 0, len(pkgs)) for pkg := range pkgs { - deps = append(deps, pkg) + if pkg != "" { + deps = append(deps, pkg) + } } sort.Strings(deps) return deps @@ -142,11 +144,11 @@ func buildInterfaceForPkg(pkg *loader.PackageInfo, opts *Options) (Interface, er } for i := range fn.Ins { fn.Ins[i] = newType(ins.At(i)) - fixup(&fn.Ins[i], opts.Query) + fixup(&fn.Ins[i], opts) } for i := range fn.Outs { fn.Outs[i] = newType(outs.At(i)) - fixup(&fn.Outs[i], opts.Query) + fixup(&fn.Outs[i], opts) } inter = append(inter, fn) } diff --git a/options.go b/options.go index 8e8a166..e6dbc9b 100644 --- a/options.go +++ b/options.go @@ -42,9 +42,10 @@ func (q *Query) valid() error { // Options is used for altering behavior of New() function. type Options struct { - Query *Query // a named type - Context *build.Context // build context; see go/build godoc for details - Unexported bool // whether to include unexported methods + Query *Query // a named type + PackageName string // name of package to generate interface for + Context *build.Context // build context; see go/build godoc for details + Unexported bool // whether to include unexported methods CSVHeader []string CSVRecord []string diff --git a/type.go b/type.go index 8473635..3be4560 100644 --- a/type.go +++ b/type.go @@ -117,7 +117,10 @@ func (typ *Type) setFromComposite(t compositeType, depth int, orig types.Type) { typ.setFromType(t.Elem(), depth+1, orig) } -func fixup(typ *Type, q *Query) { +func fixup(typ *Type, opts *Options) { + query := opts.Query + packageName := opts.PackageName + // Hacky fixup for renaming: // // GeoAdd(string, []*github.com/go-redis/redis.GeoLocation) *redis.IntCmd @@ -130,11 +133,11 @@ func fixup(typ *Type, q *Query) { // when include other package struct if typ.ImportPath != "" && typ.IsComposite { - if typ.ImportPath == q.Package { + if typ.ImportPath == query.Package { typ.Name = strings.Replace(typ.Name, typ.ImportPath, typ.Package, -1) } - if typ.ImportPath != q.Package { + if typ.ImportPath != query.Package { pkgIdx := strings.LastIndex(typ.ImportPath, typ.Package) if 0 < pkgIdx { typ.Name = strings.Replace(typ.Name, typ.ImportPath[:pkgIdx], "", -1) @@ -142,8 +145,13 @@ func fixup(typ *Type, q *Query) { } } - typ.Name = strings.Replace(typ.Name, q.Package, path.Base(q.Package), -1) + typ.Name = strings.Replace(typ.Name, query.Package, path.Base(query.Package), -1) typ.ImportPath = trimVendorPath(typ.ImportPath) + + if typ.Package == packageName { + typ.Package = "" + typ.ImportPath = "" + } } // trimVendorPath removes the vendor dir prefix from a package path.