Skip to content

Commit

Permalink
Fix object generation from 3rd packages, self-reference & alias selec…
Browse files Browse the repository at this point in the history
…tion (#126)
  • Loading branch information
hgiasac authored Jun 22, 2024
1 parent bb434b7 commit 726b3fb
Show file tree
Hide file tree
Showing 19 changed files with 422 additions and 162 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ clean:
rm -f cmd/hasura-ndc-go/testdata/*/source/schema.generated.json
rm -f cmd/hasura-ndc-go/testdata/*/source/**/schema.generated.json
rm -f cmd/hasura-ndc-go/testdata/*/source/**/types.generated.go
rm -rf cmd/hasura-ndc-go/testdata/*/source/testdata
rm -rf cmd/hasura-ndc-go/testdata/**/testdata

.PHONY: build-codegen
build-codegen:
Expand Down
17 changes: 6 additions & 11 deletions cmd/hasura-ndc-go/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,18 +386,18 @@ func (cg *connectorGenerator) genObjectMethods() error {

for _, objectName := range objectKeys {
object := cg.rawSchema.Objects[objectName]
if object.IsAnonymous {
if object.IsAnonymous || !strings.HasPrefix(object.PackagePath, cg.moduleName) {
continue
}
sb := cg.getOrCreateTypeBuilder(object.PackagePath)
sb.builder.WriteString(fmt.Sprintf(`
// ToMap encodes the struct to a value map
func (j %s) ToMap() (map[string]any, error) {
func (j %s) ToMap() map[string]any {
r := make(map[string]any)
`, objectName))
cg.genObjectToMap(sb, object, "j", "r")
sb.builder.WriteString(`
return r, nil
return r
}`)
}

Expand Down Expand Up @@ -432,7 +432,7 @@ func (cg *connectorGenerator) genToMapProperty(sb *connectorTypeBuilder, field *
if isArrayFragments(fragments) {
varName := formatLocalFieldName(selector)
valueName := fmt.Sprintf("%s_v", varName)
sb.builder.WriteString(fmt.Sprintf(" %s := make([]map[string]any, len(%s))\n", varName, selector))
sb.builder.WriteString(fmt.Sprintf(" %s := make([]any, len(%s))\n", varName, selector))
sb.builder.WriteString(fmt.Sprintf(" for i, %s := range %s {\n", valueName, selector))
cg.genToMapProperty(sb, field, valueName, fmt.Sprintf("%s[i]", varName), ty, fragments[1:])
sb.builder.WriteString(" }\n")
Expand All @@ -442,13 +442,8 @@ func (cg *connectorGenerator) genToMapProperty(sb *connectorTypeBuilder, field *

isAnonymous := strings.HasPrefix(strings.Join(fragments, ""), "struct{")
if !isAnonymous {
sb.SetImport("fmt", "")
sb.builder.WriteString(fmt.Sprintf(` itemResult, err := utils.EncodeObject(%s)
if err != nil {
return nil, fmt.Errorf("failed to encode %s: %%s", err)
}
%s = itemResult
`, selector, field.Name, assigner))
sb.builder.WriteString(fmt.Sprintf(` %s = %s
`, assigner, selector))
return selector
}
innerObject, ok := cg.rawSchema.Objects[ty.Name]
Expand Down
92 changes: 53 additions & 39 deletions cmd/hasura-ndc-go/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ func (rcs RawConnectorSchema) IsCustomType(name string) bool {
type SchemaParser struct {
context context.Context
moduleName string
rawSchema *RawConnectorSchema
packages []*packages.Package
packageIndex int
}
Expand Down Expand Up @@ -436,9 +437,10 @@ func parseRawConnectorSchemaFromGoCode(ctx context.Context, moduleName string, f
moduleName: moduleName,
packages: packageList,
packageIndex: i,
rawSchema: rawSchema,
}

err = sp.parseRawConnectorSchema(rawSchema, packageList[i].Types)
err = sp.parseRawConnectorSchema(packageList[i].Types)
parseSchemaTask.End()
if err != nil {
return nil, err
Expand Down Expand Up @@ -472,11 +474,11 @@ func evalPackageTypesLocation(name string, moduleName string, filePath string, c
}

// parse raw connector schema from Go code
func (sp *SchemaParser) parseRawConnectorSchema(rawSchema *RawConnectorSchema, pkg *types.Package) error {
func (sp *SchemaParser) parseRawConnectorSchema(pkg *types.Package) error {

for _, name := range pkg.Scope().Names() {
_, task := trace.NewTask(sp.context, fmt.Sprintf("parse_%s_schema_%s", sp.GetCurrentPackage().Name, name))
err := sp.parsePackageScope(rawSchema, pkg, name)
err := sp.parsePackageScope(pkg, name)
task.End()
if err != nil {
return err
Expand All @@ -486,7 +488,7 @@ func (sp *SchemaParser) parseRawConnectorSchema(rawSchema *RawConnectorSchema, p
return nil
}

func (sp *SchemaParser) parsePackageScope(rawSchema *RawConnectorSchema, pkg *types.Package, name string) error {
func (sp *SchemaParser) parsePackageScope(pkg *types.Package, name string) error {
switch obj := pkg.Scope().Lookup(name).(type) {
case *types.Func:
// only parse public functions
Expand Down Expand Up @@ -521,35 +523,35 @@ func (sp *SchemaParser) parsePackageScope(rawSchema *RawConnectorSchema, pkg *ty
// ignore 2 first parameters (context and state)
if params.Len() == 3 {
arg := params.At(2)
arguments, argumentType, err := sp.parseArgumentTypes(rawSchema, arg.Type(), []string{})
arguments, argumentType, err := sp.parseArgumentTypes(arg.Type(), []string{})
if err != nil {
return err
}
opInfo.ArgumentsType = argumentType
opInfo.Arguments = arguments
}

resultType, err := sp.parseType(rawSchema, nil, resultTuple.At(0).Type(), []string{}, false)
resultType, err := sp.parseType(nil, resultTuple.At(0).Type(), []string{}, false)
if err != nil {
return err
}
opInfo.ResultType = resultType

switch opInfo.Kind {
case OperationProcedure:
rawSchema.Procedures = append(rawSchema.Procedures, ProcedureInfo(*opInfo))
sp.rawSchema.Procedures = append(sp.rawSchema.Procedures, ProcedureInfo(*opInfo))
case OperationFunction:
rawSchema.Functions = append(rawSchema.Functions, FunctionInfo(*opInfo))
sp.rawSchema.Functions = append(sp.rawSchema.Functions, FunctionInfo(*opInfo))
}
}
return nil
}

func (sp *SchemaParser) parseArgumentTypes(rawSchema *RawConnectorSchema, ty types.Type, fieldPaths []string) (map[string]ArgumentInfo, *TypeInfo, error) {
func (sp *SchemaParser) parseArgumentTypes(ty types.Type, fieldPaths []string) (map[string]ArgumentInfo, *TypeInfo, error) {

switch inferredType := ty.(type) {
case *types.Pointer:
return sp.parseArgumentTypes(rawSchema, inferredType.Elem(), fieldPaths)
return sp.parseArgumentTypes(inferredType.Elem(), fieldPaths)
case *types.Struct:
result := make(map[string]ArgumentInfo)
for i := 0; i < inferredType.NumFields(); i++ {
Expand All @@ -563,7 +565,7 @@ func (sp *SchemaParser) parseArgumentTypes(rawSchema *RawConnectorSchema, ty typ
PackagePath: fieldPackage.Path(),
}
}
fieldType, err := sp.parseType(rawSchema, typeInfo, fieldVar.Type(), append(fieldPaths, fieldVar.Name()), false)
fieldType, err := sp.parseType(typeInfo, fieldVar.Type(), append(fieldPaths, fieldVar.Name()), false)
if err != nil {
return nil, nil, err
}
Expand All @@ -578,7 +580,7 @@ func (sp *SchemaParser) parseArgumentTypes(rawSchema *RawConnectorSchema, ty typ
}
return result, nil, nil
case *types.Named:
arguments, _, err := sp.parseArgumentTypes(rawSchema, inferredType.Obj().Type().Underlying(), append(fieldPaths, inferredType.Obj().Name()))
arguments, _, err := sp.parseArgumentTypes(inferredType.Obj().Type().Underlying(), append(fieldPaths, inferredType.Obj().Name()))
if err != nil {
return nil, nil, err
}
Expand All @@ -599,14 +601,14 @@ func (sp *SchemaParser) parseArgumentTypes(rawSchema *RawConnectorSchema, ty typ
}
}

func (sp *SchemaParser) parseType(rawSchema *RawConnectorSchema, rootType *TypeInfo, ty types.Type, fieldPaths []string, skipNullable bool) (*TypeInfo, error) {
func (sp *SchemaParser) parseType(rootType *TypeInfo, ty types.Type, fieldPaths []string, skipNullable bool) (*TypeInfo, error) {

switch inferredType := ty.(type) {
case *types.Pointer:
if skipNullable {
return sp.parseType(rawSchema, rootType, inferredType.Elem(), fieldPaths, false)
return sp.parseType(rootType, inferredType.Elem(), fieldPaths, false)
}
innerType, err := sp.parseType(rawSchema, rootType, inferredType.Elem(), fieldPaths, false)
innerType, err := sp.parseType(rootType, inferredType.Elem(), fieldPaths, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -653,10 +655,14 @@ func (sp *SchemaParser) parseType(rawSchema *RawConnectorSchema, rootType *TypeI
IsAnonymous: isAnonymous,
Fields: map[string]*ObjectField{},
}
// temporarily add the object type to raw schema to avoid infinite loop
sp.rawSchema.ObjectSchemas[rootType.Name] = objType
sp.rawSchema.Objects[rootType.Name] = objFields

for i := 0; i < inferredType.NumFields(); i++ {
fieldVar := inferredType.Field(i)
fieldTag := inferredType.Tag(i)
fieldType, err := sp.parseType(rawSchema, nil, fieldVar.Type(), append(fieldPaths, fieldVar.Name()), false)
fieldType, err := sp.parseType(nil, fieldVar.Type(), append(fieldPaths, fieldVar.Name()), false)
if err != nil {
return nil, err
}
Expand All @@ -670,8 +676,8 @@ func (sp *SchemaParser) parseType(rawSchema *RawConnectorSchema, rootType *TypeI
Type: fieldType,
}
}
rawSchema.ObjectSchemas[rootType.Name] = objType
rawSchema.Objects[rootType.Name] = objFields
sp.rawSchema.ObjectSchemas[rootType.Name] = objType
sp.rawSchema.Objects[rootType.Name] = objFields

return rootType, nil
case *types.Named:
Expand All @@ -682,19 +688,27 @@ func (sp *SchemaParser) parseType(rawSchema *RawConnectorSchema, rootType *TypeI
}

innerPkg := innerType.Pkg()
var packagePath string
if innerPkg != nil {
packagePath = innerPkg.Path()
if _, ok := sp.rawSchema.Objects[innerType.Name()]; ok {
ty := &TypeInfo{
Name: innerType.Name(),
SchemaName: innerType.Name(),
PackageName: innerPkg.Name(),
PackagePath: innerPkg.Path(),
TypeAST: innerType.Type(),
Schema: schema.NewNamedType(innerType.Name()),
TypeFragments: []string{innerType.Name()},
}
return ty, nil
}

typeInfo, err := sp.parseTypeInfoFromComments(innerType.Name(), packagePath, innerType.Parent())
typeInfo, err := sp.parseTypeInfoFromComments(innerType.Name(), innerPkg.Path(), innerType.Parent())
if err != nil {
return nil, err
}
if innerPkg != nil {
var scalarName ScalarName
typeInfo.PackageName = innerPkg.Name()
typeInfo.PackagePath = packagePath
typeInfo.PackagePath = innerPkg.Path()
scalarSchema := schema.NewScalarType()

switch innerPkg.Path() {
Expand Down Expand Up @@ -734,52 +748,52 @@ func (sp *SchemaParser) parseType(rawSchema *RawConnectorSchema, rootType *TypeI
typeInfo.IsScalar = true
typeInfo.Schema = schema.NewNamedType(string(scalarName))
typeInfo.TypeAST = ty
rawSchema.ScalarSchemas[string(scalarName)] = *scalarSchema
sp.rawSchema.ScalarSchemas[string(scalarName)] = *scalarSchema
return typeInfo, nil
}
}

if typeInfo.IsScalar {
rawSchema.CustomScalars[typeInfo.Name] = typeInfo
sp.rawSchema.CustomScalars[typeInfo.Name] = typeInfo
scalarSchema := schema.NewScalarType()
if typeInfo.ScalarRepresentation != nil {
scalarSchema.Representation = typeInfo.ScalarRepresentation
} else {
// requires representation since NDC spec v0.1.2
scalarSchema.Representation = schema.NewTypeRepresentationJSON().Encode()
}
rawSchema.ScalarSchemas[typeInfo.SchemaName] = *scalarSchema
sp.rawSchema.ScalarSchemas[typeInfo.SchemaName] = *scalarSchema
return typeInfo, nil
}

return sp.parseType(rawSchema, typeInfo, innerType.Type().Underlying(), append(fieldPaths, innerType.Name()), false)
return sp.parseType(typeInfo, innerType.Type().Underlying(), append(fieldPaths, innerType.Name()), false)
case *types.Basic:
var scalarName ScalarName
switch inferredType.Kind() {
case types.Bool:
scalarName = ScalarBoolean
rawSchema.ScalarSchemas[string(scalarName)] = defaultScalarTypes[scalarName]
sp.rawSchema.ScalarSchemas[string(scalarName)] = defaultScalarTypes[scalarName]
case types.Int8, types.Uint8:
scalarName = ScalarInt8
rawSchema.ScalarSchemas[string(scalarName)] = defaultScalarTypes[scalarName]
sp.rawSchema.ScalarSchemas[string(scalarName)] = defaultScalarTypes[scalarName]
case types.Int16, types.Uint16:
scalarName = ScalarInt16
rawSchema.ScalarSchemas[string(scalarName)] = defaultScalarTypes[scalarName]
sp.rawSchema.ScalarSchemas[string(scalarName)] = defaultScalarTypes[scalarName]
case types.Int, types.Int32, types.Uint, types.Uint32:
scalarName = ScalarInt32
rawSchema.ScalarSchemas[string(scalarName)] = defaultScalarTypes[scalarName]
sp.rawSchema.ScalarSchemas[string(scalarName)] = defaultScalarTypes[scalarName]
case types.Int64, types.Uint64:
scalarName = ScalarInt64
rawSchema.ScalarSchemas[string(scalarName)] = defaultScalarTypes[scalarName]
sp.rawSchema.ScalarSchemas[string(scalarName)] = defaultScalarTypes[scalarName]
case types.Float32:
scalarName = ScalarFloat32
rawSchema.ScalarSchemas[string(scalarName)] = defaultScalarTypes[scalarName]
sp.rawSchema.ScalarSchemas[string(scalarName)] = defaultScalarTypes[scalarName]
case types.Float64:
scalarName = ScalarFloat64
rawSchema.ScalarSchemas[string(scalarName)] = defaultScalarTypes[scalarName]
sp.rawSchema.ScalarSchemas[string(scalarName)] = defaultScalarTypes[scalarName]
case types.String:
scalarName = ScalarString
rawSchema.ScalarSchemas[string(scalarName)] = defaultScalarTypes[scalarName]
sp.rawSchema.ScalarSchemas[string(scalarName)] = defaultScalarTypes[scalarName]
default:
return nil, fmt.Errorf("%s: unsupported scalar type <%s>", strings.Join(fieldPaths, "."), inferredType.String())
}
Expand All @@ -797,15 +811,15 @@ func (sp *SchemaParser) parseType(rawSchema *RawConnectorSchema, rootType *TypeI

return rootType, nil
case *types.Array:
innerType, err := sp.parseType(rawSchema, nil, inferredType.Elem(), fieldPaths, false)
innerType, err := sp.parseType(nil, inferredType.Elem(), fieldPaths, false)
if err != nil {
return nil, err
}
innerType.TypeFragments = append([]string{"[]"}, innerType.TypeFragments...)
innerType.Schema = schema.NewArrayType(innerType.Schema)
return innerType, nil
case *types.Slice:
innerType, err := sp.parseType(rawSchema, nil, inferredType.Elem(), fieldPaths, false)
innerType, err := sp.parseType(nil, inferredType.Elem(), fieldPaths, false)
if err != nil {
return nil, err
}
Expand All @@ -825,8 +839,8 @@ func (sp *SchemaParser) parseType(rawSchema *RawConnectorSchema, rootType *TypeI
rootType.PackagePath = ""
}

if _, ok := rawSchema.ScalarSchemas[string(scalarName)]; !ok {
rawSchema.ScalarSchemas[string(scalarName)] = defaultScalarTypes[scalarName]
if _, ok := sp.rawSchema.ScalarSchemas[string(scalarName)]; !ok {
sp.rawSchema.ScalarSchemas[string(scalarName)] = defaultScalarTypes[scalarName]
}
rootType.TypeFragments = append(rootType.TypeFragments, inferredType.String())
rootType.Schema = schema.NewNamedType(string(scalarName))
Expand Down
Loading

0 comments on commit 726b3fb

Please sign in to comment.