diff --git a/example/social/social.go b/example/social/social.go index 67774207..781b2344 100644 --- a/example/social/social.go +++ b/example/social/social.go @@ -73,15 +73,19 @@ func (r *searchResult) ToUser() (*user, bool) { return res, ok } +type contact struct { + Email string + Phone string +} + type user struct { IDField string NameField string RoleField string - Email string - Phone string Address *[]string Friends *[]*user CreatedAt graphql.Time + contact } func (u user) ID() graphql.ID { @@ -126,37 +130,45 @@ var users = []*user{ IDField: "0x01", NameField: "Albus Dumbledore", RoleField: "ADMIN", - Email: "Albus@hogwarts.com", - Phone: "000-000-0000", Address: &[]string{"Office @ Hogwarts", "where Horcruxes are"}, CreatedAt: graphql.Time{Time: time.Now()}, + contact: contact{ + Email: "Albus@hogwarts.com", + Phone: "000-000-0000", + }, }, { IDField: "0x02", NameField: "Harry Potter", RoleField: "USER", - Email: "harry@hogwarts.com", - Phone: "000-000-0001", Address: &[]string{"123 dorm room @ Hogwarts", "456 random place"}, CreatedAt: graphql.Time{Time: time.Now()}, + contact: contact{ + Email: "harry@hogwarts.com", + Phone: "000-000-0001", + }, }, { IDField: "0x03", NameField: "Hermione Granger", RoleField: "USER", - Email: "hermione@hogwarts.com", - Phone: "000-000-0011", Address: &[]string{"233 dorm room @ Hogwarts", "786 @ random place"}, CreatedAt: graphql.Time{Time: time.Now()}, + contact: contact{ + Email: "hermione@hogwarts.com", + Phone: "000-000-0011", + }, }, { IDField: "0x04", NameField: "Ronald Weasley", RoleField: "USER", - Email: "ronald@hogwarts.com", - Phone: "000-000-0111", Address: &[]string{"411 dorm room @ Hogwarts", "981 @ random place"}, CreatedAt: graphql.Time{Time: time.Now()}, + contact: contact{ + Email: "ronald@hogwarts.com", + Phone: "000-000-0111", + }, }, } diff --git a/go.mod b/go.mod index 088e9931..2c814b0b 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/graph-gophers/graphql-go require github.com/opentracing/opentracing-go v1.1.0 + +go 1.13 diff --git a/graphql_test.go b/graphql_test.go index 9f2f3e63..1c7aeb63 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -351,6 +351,88 @@ func TestBasic(t *testing.T) { }) } +type testEmbeddedStructResolver struct{} + +func (_ *testEmbeddedStructResolver) Course() courseResolver { + return courseResolver{ + CourseMeta: CourseMeta{ + Name: "Biology", + Timestamps: Timestamps{CreatedAt: "yesterday", UpdatedAt: "today"}, + }, + Instructor: Instructor{Name: "Socrates"}, + } +} + +type courseResolver struct { + CourseMeta + Instructor Instructor +} + +type CourseMeta struct { + Name string + Timestamps +} + +type Instructor struct { + Name string +} + +type Timestamps struct { + CreatedAt string + UpdatedAt string +} + +func TestEmbeddedStruct(t *testing.T) { + gqltesting.RunTests(t, []*gqltesting.Test{ + { + Schema: graphql.MustParseSchema(` + schema { + query: Query + } + + type Query { + course: Course! + } + + type Course { + name: String! + createdAt: String! + updatedAt: String! + instructor: Instructor! + } + + type Instructor { + name: String! + } + `, &testEmbeddedStructResolver{}, graphql.UseFieldResolvers()), + Query: ` + { + course{ + name + createdAt + updatedAt + instructor { + name + } + } + } + `, + ExpectedResult: ` + { + "course": { + "name": "Biology", + "createdAt": "yesterday", + "updatedAt": "today", + "instructor": { + "name":"Socrates" + } + } + } + `, + }, + }) +} + type testNilInterfaceResolver struct{} func (r *testNilInterfaceResolver) A() interface{ Z() int32 } { @@ -1179,13 +1261,13 @@ func TestDeprecatedDirective(t *testing.T) { }) } -type testBadEnumResolver struct {} +type testBadEnumResolver struct{} func (r *testBadEnumResolver) Hero() *testBadEnumCharacterResolver { return &testBadEnumCharacterResolver{} } -type testBadEnumCharacterResolver struct {} +type testBadEnumCharacterResolver struct{} func (r *testBadEnumCharacterResolver) Name() string { return "Spock" @@ -1227,7 +1309,7 @@ func TestEnums(t *testing.T) { `, ExpectedErrors: []*gqlerrors.QueryError{ { - Message: "Argument \"episode\" has invalid value WRATH_OF_KHAN.\nExpected type \"Episode\", found WRATH_OF_KHAN.", + Message: "Argument \"episode\" has invalid value WRATH_OF_KHAN.\nExpected type \"Episode\", found WRATH_OF_KHAN.", Locations: []gqlerrors.Location{{Column: 20, Line: 3}}, Rule: "ArgumentsOfCorrectType", }, @@ -1265,7 +1347,7 @@ func TestEnums(t *testing.T) { Variables: map[string]interface{}{"episode": "FINAL_FRONTIER"}, ExpectedErrors: []*gqlerrors.QueryError{ { - Message: "Variable \"episode\" has invalid value FINAL_FRONTIER.\nExpected type \"Episode\", found FINAL_FRONTIER.", + Message: "Variable \"episode\" has invalid value FINAL_FRONTIER.\nExpected type \"Episode\", found FINAL_FRONTIER.", Locations: []gqlerrors.Location{{Column: 26, Line: 2}}, Rule: "VariablesOfCorrectType", }, @@ -1327,7 +1409,7 @@ func TestEnums(t *testing.T) { ExpectedErrors: []*gqlerrors.QueryError{ { Message: "Invalid value STAR_TREK.\nExpected type Episode, found STAR_TREK.", - Path: []interface{}{"hero", "appearsIn", 0}, + Path: []interface{}{"hero", "appearsIn", 0}, }, }, }, @@ -3018,23 +3100,65 @@ func TestErrorPropagation(t *testing.T) { }) } +type ambiguousResolver struct { + Name string // ambiguous + University +} + +type University struct { + Name string // ambiguous +} + +func TestPanicAmbiguity(t *testing.T) { + panicMessage := `*graphql_test.ambiguousResolver does not resolve "Query": ambiguous field "name"` + + defer func() { + r := recover() + if r == nil { + t.Fatal("expected schema parse to panic") + } + + if r.(error).Error() != panicMessage { + t.Logf("got: %s", r) + t.Logf("want: %s", panicMessage) + t.Fail() + } + }() + + schema := ` + schema { + query: Query + } + + type Query { + name: String! + university: University! + } + + type University { + name: String! + } + ` + graphql.MustParseSchema(schema, &ambiguousResolver{}, graphql.UseFieldResolvers()) +} + func TestSchema_Exec_without_resolver(t *testing.T) { t.Parallel() type args struct { - Query string + Query string Schema string } type want struct { Panic interface{} } testTable := []struct { - Name string - Args args - Want want + Name string + Args args + Want want }{ { - Name: "schema_without_resolver_errors", + Name: "schema_without_resolver_errors", Args: args{ Query: ` query { diff --git a/internal/exec/exec.go b/internal/exec/exec.go index 46d6510a..dc0aa72f 100644 --- a/internal/exec/exec.go +++ b/internal/exec/exec.go @@ -217,7 +217,7 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f if res.Kind() == reflect.Ptr { res = res.Elem() } - result = res.Field(f.field.FieldIndex) + result = res.FieldByIndex(f.field.FieldIndex) } return nil }() diff --git a/internal/exec/resolvable/resolvable.go b/internal/exec/resolvable/resolvable.go index e82d35e5..1b248e69 100644 --- a/internal/exec/resolvable/resolvable.go +++ b/internal/exec/resolvable/resolvable.go @@ -34,7 +34,7 @@ type Field struct { schema.Field TypeName string MethodIndex int - FieldIndex int + FieldIndex []int HasContext bool HasError bool ArgsPacker *packer.StructPacker @@ -43,7 +43,7 @@ type Field struct { } func (f *Field) UseMethodResolver() bool { - return f.FieldIndex == -1 + return len(f.FieldIndex) == 0 } type TypeAssertion struct { @@ -228,13 +228,17 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p Fields := make(map[string]*Field) rt := unwrapPtr(resolverType) + fieldsCount := fieldCount(rt, map[string]int{}) for _, f := range fields { - fieldIndex := -1 + var fieldIndex []int methodIndex := findMethod(resolverType, f.Name) if b.schema.UseFieldResolvers && methodIndex == -1 { - fieldIndex = findField(rt, f.Name) + if fieldsCount[strings.ToLower(stripUnderscore(f.Name))] > 1 { + return nil, fmt.Errorf("%s does not resolve %q: ambiguous field %q", resolverType, typeName, f.Name) + } + fieldIndex = findField(rt, f.Name, []int{}) } - if methodIndex == -1 && fieldIndex == -1 { + if methodIndex == -1 && len(fieldIndex) == 0 { hint := "" if findMethod(reflect.PtrTo(resolverType), f.Name) != -1 { hint = " (hint: the method exists on the pointer type)" @@ -247,7 +251,7 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p if methodIndex != -1 { m = resolverType.Method(methodIndex) } else { - sf = rt.Field(fieldIndex) + sf = rt.FieldByIndex(fieldIndex) } fe, err := b.makeFieldExec(typeName, f, m, sf, methodIndex, fieldIndex, methodHasReceiver) if err != nil { @@ -290,7 +294,7 @@ var contextType = reflect.TypeOf((*context.Context)(nil)).Elem() var errorType = reflect.TypeOf((*error)(nil)).Elem() func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect.Method, sf reflect.StructField, - methodIndex, fieldIndex int, methodHasReceiver bool) (*Field, error) { + methodIndex int, fieldIndex []int, methodHasReceiver bool) (*Field, error) { var argsPacker *packer.StructPacker var hasError bool @@ -380,13 +384,46 @@ func findMethod(t reflect.Type, name string) int { return -1 } -func findField(t reflect.Type, name string) int { +func findField(t reflect.Type, name string, index []int) []int { for i := 0; i < t.NumField(); i++ { - if strings.EqualFold(stripUnderscore(name), stripUnderscore(t.Field(i).Name)) { - return i + field := t.Field(i) + + if field.Type.Kind() == reflect.Struct && field.Anonymous { + newIndex := findField(field.Type, name, []int{i}) + if len(newIndex) > 1 { + return append(index, newIndex...) + } + } + + if strings.EqualFold(stripUnderscore(name), stripUnderscore(field.Name)) { + return append(index, i) } } - return -1 + + return index +} + +// fieldCount helps resolve ambiguity when more than one embedded struct contains fields with the same name. +func fieldCount(t reflect.Type, count map[string]int) map[string]int { + if t.Kind() != reflect.Struct { + return nil + } + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + fieldName := strings.ToLower(stripUnderscore(field.Name)) + + if field.Type.Kind() == reflect.Struct && field.Anonymous { + count = fieldCount(field.Type, count) + } else { + if _, ok := count[fieldName]; !ok { + count[fieldName] = 0 + } + count[fieldName]++ + } + } + + return count } func unwrapNonNull(t common.Type) (common.Type, bool) {