Skip to content

Commit 5f8e662

Browse files
committed
Added support of custom directives
1 parent 8a96404 commit 5f8e662

File tree

9 files changed

+166
-31
lines changed

9 files changed

+166
-31
lines changed

example/caching/server/server.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/graph-gophers/graphql-go"
1010
"github.com/graph-gophers/graphql-go/example/caching"
1111
"github.com/graph-gophers/graphql-go/example/caching/cache"
12+
"github.com/graph-gophers/graphql-go/types"
1213
)
1314

1415
var schema *graphql.Schema
@@ -40,12 +41,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
4041
var hint *cache.Hint
4142
if cacheable(r) {
4243
ctx, hints, done := cache.Hintable(r.Context())
43-
response = h.Schema.Exec(ctx, p.Query, p.OperationName, p.Variables)
44+
response = h.Schema.Exec(ctx, p.Query, p.OperationName, p.Variables, map[string]types.DirectiveVisitor{})
4445
done()
4546
v := <-hints
4647
hint = &v
4748
} else {
48-
response = h.Schema.Exec(r.Context(), p.Query, p.OperationName, p.Variables)
49+
response = h.Schema.Exec(r.Context(), p.Query, p.OperationName, p.Variables, map[string]types.DirectiveVisitor{})
4950
}
5051
responseJSON, err := json.Marshal(response)
5152
if err != nil {

gqltesting/testing.go

+11-9
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,20 @@ import (
1212

1313
graphql "github.com/graph-gophers/graphql-go"
1414
"github.com/graph-gophers/graphql-go/errors"
15+
"github.com/graph-gophers/graphql-go/types"
1516
)
1617

1718
// Test is a GraphQL test case to be used with RunTest(s).
1819
type Test struct {
19-
Context context.Context
20-
Schema *graphql.Schema
21-
Query string
22-
OperationName string
23-
Variables map[string]interface{}
24-
ExpectedResult string
25-
ExpectedErrors []*errors.QueryError
26-
RawResponse bool
20+
Context context.Context
21+
Schema *graphql.Schema
22+
Query string
23+
OperationName string
24+
Variables map[string]interface{}
25+
ExpectedResult string
26+
ExpectedErrors []*errors.QueryError
27+
RawResponse bool
28+
DirectiveVisitors map[string]types.DirectiveVisitor
2729
}
2830

2931
// RunTests runs the given GraphQL test cases as subtests.
@@ -45,7 +47,7 @@ func RunTest(t *testing.T, test *Test) {
4547
if test.Context == nil {
4648
test.Context = context.Background()
4749
}
48-
result := test.Schema.Exec(test.Context, test.Query, test.OperationName, test.Variables)
50+
result := test.Schema.Exec(test.Context, test.Query, test.OperationName, test.Variables, test.DirectiveVisitors)
4951

5052
checkErrors(t, test.ExpectedErrors, result.Errors)
5153

graphql.go

+7-6
Original file line numberDiff line numberDiff line change
@@ -186,14 +186,14 @@ func (s *Schema) ValidateWithVariables(queryString string, variables map[string]
186186
// Exec executes the given query with the schema's resolver. It panics if the schema was created
187187
// without a resolver. If the context get cancelled, no further resolvers will be called and a
188188
// the context error will be returned as soon as possible (not immediately).
189-
func (s *Schema) Exec(ctx context.Context, queryString string, operationName string, variables map[string]interface{}) *Response {
189+
func (s *Schema) Exec(ctx context.Context, queryString string, operationName string, variables map[string]interface{}, visitors map[string]types.DirectiveVisitor) *Response {
190190
if s.res.Resolver == (reflect.Value{}) {
191191
panic("schema created without resolver, can not exec")
192192
}
193-
return s.exec(ctx, queryString, operationName, variables, s.res)
193+
return s.exec(ctx, queryString, operationName, variables, visitors, s.res)
194194
}
195195

196-
func (s *Schema) exec(ctx context.Context, queryString string, operationName string, variables map[string]interface{}, res *resolvable.Schema) *Response {
196+
func (s *Schema) exec(ctx context.Context, queryString string, operationName string, variables map[string]interface{}, visitors map[string]types.DirectiveVisitor, res *resolvable.Schema) *Response {
197197
doc, qErr := query.Parse(queryString)
198198
if qErr != nil {
199199
return &Response{Errors: []*errors.QueryError{qErr}}
@@ -244,9 +244,10 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str
244244
Schema: s.schema,
245245
DisableIntrospection: s.disableIntrospection,
246246
},
247-
Limiter: make(chan struct{}, s.maxParallelism),
248-
Tracer: s.tracer,
249-
Logger: s.logger,
247+
Limiter: make(chan struct{}, s.maxParallelism),
248+
Tracer: s.tracer,
249+
Logger: s.logger,
250+
Visitors: visitors,
250251
}
251252
varTypes := make(map[string]*introspection.Type)
252253
for _, v := range op.Vars {

graphql_test.go

+84-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
gqlerrors "github.com/graph-gophers/graphql-go/errors"
1212
"github.com/graph-gophers/graphql-go/example/starwars"
1313
"github.com/graph-gophers/graphql-go/gqltesting"
14+
"github.com/graph-gophers/graphql-go/types"
1415
)
1516

1617
type helloWorldResolver1 struct{}
@@ -45,6 +46,27 @@ func (r *helloSnakeResolver2) SayHello(ctx context.Context, args struct{ FullNam
4546
return "Hello " + args.FullName + "!", nil
4647
}
4748

49+
type customDirectiveVisitor struct {
50+
beforeWasCalled bool
51+
}
52+
53+
func (v *customDirectiveVisitor) Before(ctx context.Context, directive *types.Directive, input interface{}) error {
54+
v.beforeWasCalled = true
55+
return nil
56+
}
57+
58+
func (v *customDirectiveVisitor) After(ctx context.Context, directive *types.Directive, output interface{}) (interface{}, error) {
59+
if v.beforeWasCalled == false {
60+
return nil, errors.New("Before directive visitor method wasn't called.")
61+
}
62+
63+
if value, ok := directive.Arguments.Get("customAttribute"); ok {
64+
return fmt.Sprintf("Directive '%s' (with arg '%s') modified result: %s", directive.Name.Name, value.String(), output.(string)), nil
65+
} else {
66+
return fmt.Sprintf("Directive '%s' modified result: %s", directive.Name.Name, output.(string)), nil
67+
}
68+
}
69+
4870
type theNumberResolver struct {
4971
number int32
5072
}
@@ -188,7 +210,6 @@ func TestHelloWorld(t *testing.T) {
188210
}
189211
`,
190212
},
191-
192213
{
193214
Schema: graphql.MustParseSchema(`
194215
schema {
@@ -213,6 +234,67 @@ func TestHelloWorld(t *testing.T) {
213234
})
214235
}
215236

237+
func TestCustomDirective(t *testing.T) {
238+
t.Parallel()
239+
240+
gqltesting.RunTests(t, []*gqltesting.Test{
241+
{
242+
Schema: graphql.MustParseSchema(`
243+
directive @customDirective on FIELD_DEFINITION
244+
245+
schema {
246+
query: Query
247+
}
248+
249+
type Query {
250+
hello_html: String! @customDirective
251+
}
252+
`, &helloSnakeResolver1{}),
253+
Query: `
254+
{
255+
hello_html
256+
}
257+
`,
258+
ExpectedResult: `
259+
{
260+
"hello_html": "Directive 'customDirective' modified result: Hello snake!"
261+
}
262+
`,
263+
DirectiveVisitors: map[string]types.DirectiveVisitor{
264+
"customDirective": &customDirectiveVisitor{},
265+
},
266+
},
267+
{
268+
Schema: graphql.MustParseSchema(`
269+
directive @customDirective(
270+
customAttribute: String!
271+
) on FIELD_DEFINITION
272+
273+
schema {
274+
query: Query
275+
}
276+
277+
type Query {
278+
say_hello(full_name: String!): String! @customDirective(customAttribute: hi)
279+
}
280+
`, &helloSnakeResolver1{}),
281+
Query: `
282+
{
283+
say_hello(full_name: "Johnny")
284+
}
285+
`,
286+
ExpectedResult: `
287+
{
288+
"say_hello": "Directive 'customDirective' (with arg 'hi') modified result: Hello Johnny!"
289+
}
290+
`,
291+
DirectiveVisitors: map[string]types.DirectiveVisitor{
292+
"customDirective": &customDirectiveVisitor{},
293+
},
294+
},
295+
})
296+
}
297+
216298
func TestHelloSnake(t *testing.T) {
217299
t.Parallel()
218300

@@ -3728,7 +3810,7 @@ func TestSchema_Exec_without_resolver(t *testing.T) {
37283810
t.Fail()
37293811
}
37303812
}()
3731-
_ = s.Exec(context.Background(), tt.Args.Query, "", map[string]interface{}{})
3813+
_ = s.Exec(context.Background(), tt.Args.Query, "", map[string]interface{}{}, map[string]types.DirectiveVisitor{})
37323814
})
37333815
}
37343816
}

internal/common/lexer_test.go

+10-10
Original file line numberDiff line numberDiff line change
@@ -94,21 +94,21 @@ func TestConsume(t *testing.T) {
9494
}
9595
}
9696

97-
var multilineStringTests = []consumeTestCase {
97+
var multilineStringTests = []consumeTestCase{
9898
{
99-
description: "Oneline strings are okay",
100-
definition: `"Hello World"`,
101-
expected: "",
102-
failureExpected: false,
103-
useStringDescriptions: true,
99+
description: "Oneline strings are okay",
100+
definition: `"Hello World"`,
101+
expected: "",
102+
failureExpected: false,
103+
useStringDescriptions: true,
104104
},
105105
{
106106
description: "Multiline strings are not allowed",
107107
definition: `"Hello
108108
World"`,
109-
expected: `graphql: syntax error: literal not terminated (line 1, column 1)`,
110-
failureExpected: true,
111-
useStringDescriptions: true,
109+
expected: `graphql: syntax error: literal not terminated (line 1, column 1)`,
110+
failureExpected: true,
111+
useStringDescriptions: true,
112112
},
113113
}
114114

@@ -130,5 +130,5 @@ func TestMultilineString(t *testing.T) {
130130
t.Fatalf("Test '%s' failed with error: '%s'", test.description, err.Error())
131131
}
132132
})
133-
}
133+
}
134134
}

internal/exec/exec.go

+40
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ type Request struct {
2424
Tracer trace.Tracer
2525
Logger log.Logger
2626
SubscribeResolverTimeout time.Duration
27+
Visitors map[string]types.DirectiveVisitor
2728
}
2829

2930
func (r *Request) handlePanic(ctx context.Context) {
@@ -206,8 +207,47 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f
206207
if f.field.ArgsPacker != nil {
207208
in = append(in, f.field.PackedArgs)
208209
}
210+
211+
// Before hook directive visitor
212+
if len(f.field.Directives) > 0 {
213+
for _, directive := range f.field.Directives {
214+
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
215+
var values = make([]interface{}, 0)
216+
for _, inValue := range in {
217+
values = append(values, inValue.Interface())
218+
}
219+
220+
if visitorErr := visitor.Before(ctx, directive, values); err != nil {
221+
err := errors.Errorf("%s", visitorErr)
222+
err.Path = path.toSlice()
223+
err.ResolverError = visitorErr
224+
return err
225+
}
226+
}
227+
}
228+
}
229+
230+
// Call method
209231
callOut := res.Method(f.field.MethodIndex).Call(in)
210232
result = callOut[0]
233+
234+
// After hook directive visitor (when no error is returned from resolver)
235+
if !f.field.HasError && len(f.field.Directives) > 0 {
236+
for _, directive := range f.field.Directives {
237+
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
238+
returned, visitorErr := visitor.After(ctx, directive, result.Interface())
239+
if err != nil {
240+
err := errors.Errorf("%s", visitorErr)
241+
err.Path = path.toSlice()
242+
err.ResolverError = visitorErr
243+
return err
244+
} else {
245+
result = reflect.ValueOf(returned)
246+
}
247+
}
248+
}
249+
}
250+
211251
if f.field.HasError && !callOut[1].IsNil() {
212252
resolverErr := callOut[1].Interface().(error)
213253
err := errors.Errorf("%s", resolverErr)

introspection.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/graph-gophers/graphql-go/internal/exec/resolvable"
88
"github.com/graph-gophers/graphql-go/introspection"
9+
"github.com/graph-gophers/graphql-go/types"
910
)
1011

1112
// Inspect allows inspection of the given schema.
@@ -15,7 +16,7 @@ func (s *Schema) Inspect() *introspection.Schema {
1516

1617
// ToJSON encodes the schema in a JSON format used by tools like Relay.
1718
func (s *Schema) ToJSON() ([]byte, error) {
18-
result := s.exec(context.Background(), introspectionQuery, "", nil, &resolvable.Schema{
19+
result := s.exec(context.Background(), introspectionQuery, "", nil, map[string]types.DirectiveVisitor{}, &resolvable.Schema{
1920
Meta: s.res.Meta,
2021
Query: &resolvable.Object{},
2122
Schema: *s.schema,

relay/relay.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"strings"
1010

1111
graphql "github.com/graph-gophers/graphql-go"
12+
"github.com/graph-gophers/graphql-go/types"
1213
)
1314

1415
func MarshalID(kind string, spec interface{}) graphql.ID {
@@ -58,7 +59,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
5859
return
5960
}
6061

61-
response := h.Schema.Exec(r.Context(), params.Query, params.OperationName, params.Variables)
62+
response := h.Schema.Exec(r.Context(), params.Query, params.OperationName, params.Variables, map[string]types.DirectiveVisitor{})
6263
responseJSON, err := json.Marshal(response)
6364
if err != nil {
6465
http.Error(w, err.Error(), http.StatusInternalServerError)

types/directive.go

+7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package types
22

3+
import "context"
4+
35
// Directive is a representation of the GraphQL Directive.
46
//
57
// http://spec.graphql.org/draft/#sec-Language.Directives
@@ -20,6 +22,11 @@ type DirectiveDefinition struct {
2022

2123
type DirectiveList []*Directive
2224

25+
type DirectiveVisitor interface {
26+
Before(ctx context.Context, directive *Directive, input interface{}) error
27+
After(ctx context.Context, directive *Directive, output interface{}) (interface{}, error)
28+
}
29+
2330
// Returns the Directive in the DirectiveList by name or nil if not found.
2431
func (l DirectiveList) Get(name string) *Directive {
2532
for _, d := range l {

0 commit comments

Comments
 (0)