Skip to content

Commit bca10e9

Browse files
authored
Allow directives in schema extensions (#592)
1 parent d126bba commit bca10e9

File tree

2 files changed

+54
-7
lines changed

2 files changed

+54
-7
lines changed

graphql_test.go

+45
Original file line numberDiff line numberDiff line change
@@ -5589,3 +5589,48 @@ func TestSeparateResolvers(t *testing.T) {
55895589
})
55905590
}
55915591
}
5592+
5593+
func TestSchemaExtension(t *testing.T) {
5594+
t.Parallel()
5595+
5596+
sdl := `
5597+
directive @awesome on SCHEMA
5598+
5599+
schema {
5600+
query: Query
5601+
}
5602+
5603+
type Query {
5604+
hello: String!
5605+
}
5606+
5607+
extend schema @awesome
5608+
`
5609+
schema := graphql.MustParseSchema(sdl, &helloResolver{})
5610+
5611+
gqltesting.RunTests(t, []*gqltesting.Test{
5612+
{
5613+
Schema: schema,
5614+
Query: `
5615+
{
5616+
hello
5617+
}
5618+
`,
5619+
ExpectedResult: `
5620+
{
5621+
"hello": "Hello world!"
5622+
}
5623+
`,
5624+
},
5625+
})
5626+
5627+
ast := schema.AST()
5628+
dirs := ast.SchemaDefinition.Directives
5629+
if len(dirs) != 1 {
5630+
t.Fatalf("expected 1 schema directive, got %d", len(dirs))
5631+
}
5632+
name := dirs[0].Name.Name
5633+
if name != "awesome" {
5634+
t.Fatalf(`expected an "awesome" schema directive, got %q`, dirs[0].Name.Name)
5635+
}
5636+
}

internal/schema/schema.go

+9-7
Original file line numberDiff line numberDiff line change
@@ -570,16 +570,18 @@ func parseExtension(s *ast.Schema, l *common.Lexer) {
570570
loc := l.Location()
571571
switch x := l.ConsumeIdent(); x {
572572
case "schema":
573-
l.ConsumeToken('{')
574573
s.SchemaDefinition.Present = true
575574
s.SchemaDefinition.Directives = append(s.SchemaDefinition.Directives, common.ParseDirectives(l)...)
576-
for l.Peek() != '}' {
577-
name := l.ConsumeIdent()
578-
l.ConsumeToken(':')
579-
typ := l.ConsumeIdent()
580-
s.EntryPointNames[name] = typ
575+
if l.Peek() == '{' { // in schema extensions the body is optional
576+
l.ConsumeToken('{')
577+
for l.Peek() != '}' {
578+
name := l.ConsumeIdent()
579+
l.ConsumeToken(':')
580+
typ := l.ConsumeIdent()
581+
s.EntryPointNames[name] = typ
582+
}
583+
l.ConsumeToken('}')
581584
}
582-
l.ConsumeToken('}')
583585

584586
case "type":
585587
obj := parseObjectDef(l)

0 commit comments

Comments
 (0)