Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for static type overrides with imports #126

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions cli/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"
"strings"

"github.com/gocomply/xsd2go/pkg/xsd"
"github.com/gocomply/xsd2go/pkg/xsd2go"
"github.com/urfave/cli"
)
Expand Down Expand Up @@ -37,10 +38,28 @@ var convert = cli.Command{
1)
}
}

for _, override := range c.StringSlice("type-override") {
if !strings.Contains(override, "=") {
return cli.NewExitError(
fmt.Sprintf(
"Invalid type-override: '%s', expecting form of TYPE=GOTYPE or TYPE=GOTYPE:GOIMPORT",
override,
),
1,
)
}
}

return nil
},
Action: func(c *cli.Context) error {
xsdFile, goModule, outputDir := c.Args()[0], c.Args()[1], c.Args()[2]

for _, typeOverride := range c.StringSlice("type-override") {
xsd.AddStaticTypeOverride(typeOverride)
}

err := xsd2go.Convert(xsdFile, goModule, outputDir, c.StringSlice("xmlns-override"))
if err != nil {
return cli.NewExitError(err, 1)
Expand All @@ -52,5 +71,9 @@ var convert = cli.Command{
Name: "xmlns-override",
Usage: "Allows to explicitly set gopackage name for given XMLNS. Example: --xmlns-override='http://www.w3.org/2000/09/xmldsig#=xml_signatures'",
},
cli.StringSliceFlag{
Name: "type-override",
Usage: "Allows to explicitly override a static simple type mapping. Example: --type-override='decimal=string' or --type-override='decimal=decimal:github.com/ericlagergren/decimal",
},
},
}
10 changes: 5 additions & 5 deletions pkg/xsd/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ func (sch *Schema) findReferencedElement(ref reference) *Element {
}
if innerSchema != sch {
sch.registerImportedModule(innerSchema)

}
return innerSchema.GetElement(ref.Name())
}
Expand Down Expand Up @@ -257,6 +256,9 @@ func (sch *Schema) GoImportsNeeded() []string {
for _, importedMod := range sch.importedModules {
imports = append(imports, fmt.Sprintf("%s/%s", sch.ModulesPath, importedMod.GoPackageName()))
}
for _, importedMod := range GetStaticTypeImports() {
imports = append(imports, importedMod)
}
sort.Strings(imports)
return imports
}
Expand Down Expand Up @@ -297,8 +299,7 @@ type Import struct {

func (i *Import) load(ws *Workspace, baseDir string) (err error) {
if i.SchemaLocation != "" {
i.ImportedSchema, err =
ws.loadXsd(filepath.Join(baseDir, i.SchemaLocation), true)
i.ImportedSchema, err = ws.loadXsd(filepath.Join(baseDir, i.SchemaLocation), true)
}
return
}
Expand All @@ -312,8 +313,7 @@ type Include struct {

func (i *Include) load(ws *Workspace, baseDir string) (err error) {
if i.SchemaLocation != "" {
i.IncludedSchema, err =
ws.loadXsd(filepath.Join(baseDir, i.SchemaLocation), false)
i.IncludedSchema, err = ws.loadXsd(filepath.Join(baseDir, i.SchemaLocation), false)
}
return
}
36 changes: 36 additions & 0 deletions pkg/xsd/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package xsd

import (
"encoding/xml"
"sort"
"strings"

"github.com/iancoleman/strcase"
)
Expand Down Expand Up @@ -285,9 +287,43 @@ var staticTypes = map[string]staticType{
"byte": "int8",
}

var (
staticTypeImports = map[string]string{}
staticTypeUsed = map[string]struct{}{}
)

func AddStaticTypeOverride(override string) {
parts := strings.SplitN(override, "=", 2)
typeParts := strings.SplitN(parts[1], ":", 2)

typeName := parts[0]

staticTypes[typeName] = staticType(typeParts[0])

if len(typeParts) == 2 {
staticTypeImports[typeName] = typeParts[1]
}
}

func GetStaticTypeImports() []string {
imports := []string{}

for name, mod := range staticTypeImports {
if _, found := staticTypeUsed[name]; found {
imports = append(imports, mod)
}
}

sort.Strings(imports)

return imports
}

func StaticType(name string) staticType {
typ, found := staticTypes[name]
if found {
staticTypeUsed[name] = struct{}{}

return typ
}
panic("Type xsd:" + name + " not implemented")
Expand Down