diff --git a/maker/maker.go b/maker/maker.go index bc31ffb..c383703 100644 --- a/maker/maker.go +++ b/maker/maker.go @@ -117,6 +117,15 @@ func GetReceiverType(fd *ast.FuncDecl) (ast.Expr, error) { return fd.Recv.List[0].Type, nil } +// reMatchTypename matches any of the following to extract the : +// +// * +// [] +// []* +// map[] +// map[]* +var reMatchTypename = regexp.MustCompile(`^(\[\]|\*|\[\]\*|map\[\w+\]|map\[\w+\]\*)(\w+)$`) + // FormatFieldList takes in the source code // as a []byte and a FuncDecl parameters or // return values as a FieldList. @@ -135,12 +144,28 @@ func FormatFieldList(src []byte, fl *ast.FieldList, pkgName string, declaredType names[i] = n.Name } t := string(src[l.Type.Pos()-1 : l.Type.End()-1]) + t2 := t + // Try to match . If matched variable `match` will look like this for t=="[]Category": + // match[0][0] = "[]Category" + // match[0][1] = "[]" + // match[0][2] = "Category" + match := reMatchTypename.FindAllStringSubmatch(t, -1) + if match != nil { + // Set `t` so it will compare correctly with `dt.Name` below + t2 = match[0][2] + } for _, dt := range declaredTypes { - if t == dt.Name && pkgName != dt.Package { + if t2 == dt.Name && pkgName != dt.Package { // The type of this field is the same as one declared in the source package, // and the source package is not the same as the destination package. - t = dt.Fullname() + if match != nil { + // Add back `*`, `[]`, `[]*`, `map[]` or `map[]*` if there was a + // match. + t = match[0][1] + dt.Fullname() + } else { + t = dt.Fullname() + } } } @@ -341,12 +366,6 @@ func Make(options MakeOptions) ([]byte, error) { return []byte{}, err } types := ParseDeclaredTypes(b) - // validate structs from file against input struct Type - if !validateStructType(types, options.StructType) { - return []byte{}, - fmt.Errorf("%q structtype not found in input files", - options.StructType) - } for _, t := range types { if _, ok := tset[t.Fullname()]; !ok { allDeclaredTypes = append(allDeclaredTypes, t) @@ -354,6 +373,12 @@ func Make(options MakeOptions) ([]byte, error) { } } } + // validate structs from file against input struct Type + if !validateStructType(allDeclaredTypes, options.StructType) { + return []byte{}, + fmt.Errorf("%q structtype not found in input files", + options.StructType) + } excludedMethods := make(map[string]struct{}, len(options.ExcludeMethods)) for _, mName := range options.ExcludeMethods {