Skip to content

Commit

Permalink
Fixes rjeczalik#44: Fix problem generating an interface in same packa…
Browse files Browse the repository at this point in the history
…ge as impl

Modify generator to detect if the dependency matches the new interface's
package name and if so, skip adding the dependency (so skip import).

Also, when generating the parameter package names, if the package matches
the new interface's package then make the package blank.
  • Loading branch information
Steve Graham committed Sep 23, 2023
1 parent a16a9ae commit 87fb296
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 18 deletions.
21 changes: 13 additions & 8 deletions cmd/interfacer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,22 @@ func run() error {
if *output == "" {
return errors.New("empty -o flag value; see -help for details")
}
packageName := ""
interfaceName := ""
if i := strings.IndexRune(*as, '.'); i != -1 {
packageName = (*as)[:i]
interfaceName = (*as)[i+1:]
} else {
interfaceName = *as
}
q, err := interfaces.ParseQuery(*query)
if err != nil {
return err
}
opts := &interfaces.Options{
Query: q,
Unexported: *all,
Query: q,
Unexported: *all,
PackageName: packageName,
}
i, err := interfaces.NewWithOptions(opts)
if err != nil {
Expand All @@ -75,12 +84,8 @@ func run() error {
Deps: i.Deps(),
Interface: i,
}
if i := strings.IndexRune(*as, '.'); i != -1 {
v.PackageName = (*as)[:i]
v.InterfaceName = (*as)[i+1:]
} else {
v.InterfaceName = *as
}
v.PackageName = packageName
v.InterfaceName = interfaceName
var buf bytes.Buffer
if err := tmpl.Execute(&buf, v); err != nil {
return err
Expand Down
8 changes: 5 additions & 3 deletions interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ func (i Interface) Deps() []string {
}
deps := make([]string, 0, len(pkgs))
for pkg := range pkgs {
deps = append(deps, pkg)
if pkg != "" {
deps = append(deps, pkg)
}
}
sort.Strings(deps)
return deps
Expand Down Expand Up @@ -133,11 +135,11 @@ func buildInterfaceForPkg(pkg *packages.Package, opts *Options) (Interface, erro
}
for i := range fn.Ins {
fn.Ins[i] = newType(ins.At(i))
fixup(&fn.Ins[i], opts.Query)
fixup(&fn.Ins[i], opts)
}
for i := range fn.Outs {
fn.Outs[i] = newType(outs.At(i))
fixup(&fn.Outs[i], opts.Query)
fixup(&fn.Outs[i], opts)
}
inter = append(inter, fn)
}
Expand Down
7 changes: 4 additions & 3 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ func (q *Query) valid() error {

// Options is used for altering behavior of New() function.
type Options struct {
Query *Query // a named type
Context *build.Context // build context; see go/build godoc for details
Unexported bool // whether to include unexported methods
Query *Query // a named type
PackageName string // name of package to generate interface for
Context *build.Context // build context; see go/build godoc for details
Unexported bool // whether to include unexported methods

CSVHeader []string
CSVRecord []string
Expand Down
16 changes: 12 additions & 4 deletions type.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,10 @@ func (typ *Type) setFromComposite(t compositeType, depth int, orig types.Type) {
typ.setFromType(t.Elem(), depth+1, orig)
}

func fixup(typ *Type, q *Query) {
func fixup(typ *Type, opts *Options) {
query := opts.Query
packageName := opts.PackageName

// Hacky fixup for renaming:
//
// GeoAdd(string, []*github.com/go-redis/redis.GeoLocation) *redis.IntCmd
Expand All @@ -137,20 +140,25 @@ func fixup(typ *Type, q *Query) {

// when include other package struct
if typ.ImportPath != "" && typ.IsComposite {
if typ.ImportPath == q.Package {
if typ.ImportPath == query.Package {
typ.Name = strings.Replace(typ.Name, typ.ImportPath, typ.Package, -1)
}

if typ.ImportPath != q.Package {
if typ.ImportPath != query.Package {
pkgIdx := strings.LastIndex(typ.ImportPath, typ.Package)
if 0 < pkgIdx {
typ.Name = strings.Replace(typ.Name, typ.ImportPath[:pkgIdx], "", -1)
}
}
}

typ.Name = strings.Replace(typ.Name, q.Package, path.Base(q.Package), -1)
typ.Name = strings.Replace(typ.Name, query.Package, path.Base(query.Package), -1)
typ.ImportPath = trimVendorPath(typ.ImportPath)

if typ.Package == packageName {
typ.Package = ""
typ.ImportPath = ""
}
}

// trimVendorPath removes the vendor dir prefix from a package path.
Expand Down

0 comments on commit 87fb296

Please sign in to comment.