Skip to content

Commit a030af0

Browse files
committed
Added support of custom directives
1 parent 9d31459 commit a030af0

File tree

8 files changed

+159
-20
lines changed

8 files changed

+159
-20
lines changed

example/caching/server/server.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/graph-gophers/graphql-go"
1010
"github.com/graph-gophers/graphql-go/example/caching"
1111
"github.com/graph-gophers/graphql-go/example/caching/cache"
12+
"github.com/graph-gophers/graphql-go/types"
1213
)
1314

1415
var schema *graphql.Schema
@@ -40,12 +41,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
4041
var hint *cache.Hint
4142
if cacheable(r) {
4243
ctx, hints, done := cache.Hintable(r.Context())
43-
response = h.Schema.Exec(ctx, p.Query, p.OperationName, p.Variables)
44+
response = h.Schema.Exec(ctx, p.Query, p.OperationName, p.Variables, map[string]types.DirectiveVisitor{})
4445
done()
4546
v := <-hints
4647
hint = &v
4748
} else {
48-
response = h.Schema.Exec(r.Context(), p.Query, p.OperationName, p.Variables)
49+
response = h.Schema.Exec(r.Context(), p.Query, p.OperationName, p.Variables, map[string]types.DirectiveVisitor{})
4950
}
5051
responseJSON, err := json.Marshal(response)
5152
if err != nil {

gqltesting/testing.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,20 @@ import (
1212

1313
graphql "github.com/graph-gophers/graphql-go"
1414
"github.com/graph-gophers/graphql-go/errors"
15+
"github.com/graph-gophers/graphql-go/types"
1516
)
1617

1718
// Test is a GraphQL test case to be used with RunTest(s).
1819
type Test struct {
19-
Context context.Context
20-
Schema *graphql.Schema
21-
Query string
22-
OperationName string
23-
Variables map[string]interface{}
24-
ExpectedResult string
25-
ExpectedErrors []*errors.QueryError
26-
RawResponse bool
20+
Context context.Context
21+
Schema *graphql.Schema
22+
Query string
23+
OperationName string
24+
Variables map[string]interface{}
25+
ExpectedResult string
26+
ExpectedErrors []*errors.QueryError
27+
RawResponse bool
28+
DirectiveVisitors map[string]types.DirectiveVisitor
2729
}
2830

2931
// RunTests runs the given GraphQL test cases as subtests.
@@ -45,7 +47,7 @@ func RunTest(t *testing.T, test *Test) {
4547
if test.Context == nil {
4648
test.Context = context.Background()
4749
}
48-
result := test.Schema.Exec(test.Context, test.Query, test.OperationName, test.Variables)
50+
result := test.Schema.Exec(test.Context, test.Query, test.OperationName, test.Variables, test.DirectiveVisitors)
4951

5052
checkErrors(t, test.ExpectedErrors, result.Errors)
5153

graphql.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,14 @@ func (s *Schema) ValidateWithVariables(queryString string, variables map[string]
195195
// Exec executes the given query with the schema's resolver. It panics if the schema was created
196196
// without a resolver. If the context get cancelled, no further resolvers will be called and a
197197
// the context error will be returned as soon as possible (not immediately).
198-
func (s *Schema) Exec(ctx context.Context, queryString string, operationName string, variables map[string]interface{}) *Response {
198+
func (s *Schema) Exec(ctx context.Context, queryString string, operationName string, variables map[string]interface{}, visitors map[string]types.DirectiveVisitor) *Response {
199199
if !s.res.Resolver.IsValid() {
200200
panic("schema created without resolver, can not exec")
201201
}
202-
return s.exec(ctx, queryString, operationName, variables, s.res)
202+
return s.exec(ctx, queryString, operationName, variables, visitors, s.res)
203203
}
204204

205-
func (s *Schema) exec(ctx context.Context, queryString string, operationName string, variables map[string]interface{}, res *resolvable.Schema) *Response {
205+
func (s *Schema) exec(ctx context.Context, queryString string, operationName string, variables map[string]interface{}, visitors map[string]types.DirectiveVisitor, res *resolvable.Schema) *Response {
206206
doc, qErr := query.Parse(queryString)
207207
if qErr != nil {
208208
return &Response{Errors: []*errors.QueryError{qErr}}
@@ -257,6 +257,7 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str
257257
Tracer: s.tracer,
258258
Logger: s.logger,
259259
PanicHandler: s.panicHandler,
260+
Visitors: visitors,
260261
}
261262
varTypes := make(map[string]*introspection.Type)
262263
for _, v := range op.Vars {

graphql_test.go

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/graph-gophers/graphql-go/gqltesting"
1515
"github.com/graph-gophers/graphql-go/introspection"
1616
"github.com/graph-gophers/graphql-go/trace"
17+
"github.com/graph-gophers/graphql-go/types"
1718
)
1819

1920
type helloWorldResolver1 struct{}
@@ -48,6 +49,27 @@ func (r *helloSnakeResolver2) SayHello(ctx context.Context, args struct{ FullNam
4849
return "Hello " + args.FullName + "!", nil
4950
}
5051

52+
type customDirectiveVisitor struct {
53+
beforeWasCalled bool
54+
}
55+
56+
func (v *customDirectiveVisitor) Before(ctx context.Context, directive *types.Directive, input interface{}) error {
57+
v.beforeWasCalled = true
58+
return nil
59+
}
60+
61+
func (v *customDirectiveVisitor) After(ctx context.Context, directive *types.Directive, output interface{}) (interface{}, error) {
62+
if v.beforeWasCalled == false {
63+
return nil, errors.New("Before directive visitor method wasn't called.")
64+
}
65+
66+
if value, ok := directive.Arguments.Get("customAttribute"); ok {
67+
return fmt.Sprintf("Directive '%s' (with arg '%s') modified result: %s", directive.Name.Name, value.String(), output.(string)), nil
68+
} else {
69+
return fmt.Sprintf("Directive '%s' modified result: %s", directive.Name.Name, output.(string)), nil
70+
}
71+
}
72+
5173
type theNumberResolver struct {
5274
number int32
5375
}
@@ -191,7 +213,6 @@ func TestHelloWorld(t *testing.T) {
191213
}
192214
`,
193215
},
194-
195216
{
196217
Schema: graphql.MustParseSchema(`
197218
schema {
@@ -216,6 +237,67 @@ func TestHelloWorld(t *testing.T) {
216237
})
217238
}
218239

240+
func TestCustomDirective(t *testing.T) {
241+
t.Parallel()
242+
243+
gqltesting.RunTests(t, []*gqltesting.Test{
244+
{
245+
Schema: graphql.MustParseSchema(`
246+
directive @customDirective on FIELD_DEFINITION
247+
248+
schema {
249+
query: Query
250+
}
251+
252+
type Query {
253+
hello_html: String! @customDirective
254+
}
255+
`, &helloSnakeResolver1{}),
256+
Query: `
257+
{
258+
hello_html
259+
}
260+
`,
261+
ExpectedResult: `
262+
{
263+
"hello_html": "Directive 'customDirective' modified result: Hello snake!"
264+
}
265+
`,
266+
DirectiveVisitors: map[string]types.DirectiveVisitor{
267+
"customDirective": &customDirectiveVisitor{},
268+
},
269+
},
270+
{
271+
Schema: graphql.MustParseSchema(`
272+
directive @customDirective(
273+
customAttribute: String!
274+
) on FIELD_DEFINITION
275+
276+
schema {
277+
query: Query
278+
}
279+
280+
type Query {
281+
say_hello(full_name: String!): String! @customDirective(customAttribute: hi)
282+
}
283+
`, &helloSnakeResolver1{}),
284+
Query: `
285+
{
286+
say_hello(full_name: "Johnny")
287+
}
288+
`,
289+
ExpectedResult: `
290+
{
291+
"say_hello": "Directive 'customDirective' (with arg 'hi') modified result: Hello Johnny!"
292+
}
293+
`,
294+
DirectiveVisitors: map[string]types.DirectiveVisitor{
295+
"customDirective": &customDirectiveVisitor{},
296+
},
297+
},
298+
})
299+
}
300+
219301
func TestHelloSnake(t *testing.T) {
220302
t.Parallel()
221303

@@ -3757,7 +3839,7 @@ func TestSchema_Exec_without_resolver(t *testing.T) {
37573839
t.Fail()
37583840
}
37593841
}()
3760-
_ = s.Exec(context.Background(), tt.Args.Query, "", map[string]interface{}{})
3842+
_ = s.Exec(context.Background(), tt.Args.Query, "", map[string]interface{}{}, map[string]types.DirectiveVisitor{})
37613843
})
37623844
}
37633845
}
@@ -4113,7 +4195,9 @@ func TestTracer(t *testing.T) {
41134195
"id": "1002",
41144196
}
41154197

4116-
_ = schema.Exec(ctx, doc, opName, variables)
4198+
visitors := map[string]types.DirectiveVisitor{}
4199+
4200+
_ = schema.Exec(ctx, doc, opName, variables, visitors)
41174201

41184202
tracer.mu.Lock()
41194203
defer tracer.mu.Unlock()

internal/exec/exec.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type Request struct {
2525
Logger log.Logger
2626
PanicHandler errors.PanicHandler
2727
SubscribeResolverTimeout time.Duration
28+
Visitors map[string]types.DirectiveVisitor
2829
}
2930

3031
func (r *Request) handlePanic(ctx context.Context) {
@@ -208,8 +209,47 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f
208209
if f.field.ArgsPacker != nil {
209210
in = append(in, f.field.PackedArgs)
210211
}
212+
213+
// Before hook directive visitor
214+
if len(f.field.Directives) > 0 {
215+
for _, directive := range f.field.Directives {
216+
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
217+
var values = make([]interface{}, 0)
218+
for _, inValue := range in {
219+
values = append(values, inValue.Interface())
220+
}
221+
222+
if visitorErr := visitor.Before(ctx, directive, values); err != nil {
223+
err := errors.Errorf("%s", visitorErr)
224+
err.Path = path.toSlice()
225+
err.ResolverError = visitorErr
226+
return err
227+
}
228+
}
229+
}
230+
}
231+
232+
// Call method
211233
callOut := res.Method(f.field.MethodIndex).Call(in)
212234
result = callOut[0]
235+
236+
// After hook directive visitor (when no error is returned from resolver)
237+
if !f.field.HasError && len(f.field.Directives) > 0 {
238+
for _, directive := range f.field.Directives {
239+
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
240+
returned, visitorErr := visitor.After(ctx, directive, result.Interface())
241+
if err != nil {
242+
err := errors.Errorf("%s", visitorErr)
243+
err.Path = path.toSlice()
244+
err.ResolverError = visitorErr
245+
return err
246+
} else {
247+
result = reflect.ValueOf(returned)
248+
}
249+
}
250+
}
251+
}
252+
213253
if f.field.HasError && !callOut[1].IsNil() {
214254
resolverErr := callOut[1].Interface().(error)
215255
err := errors.Errorf("%s", resolverErr)

introspection.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/graph-gophers/graphql-go/internal/exec/resolvable"
88
"github.com/graph-gophers/graphql-go/introspection"
9+
"github.com/graph-gophers/graphql-go/types"
910
)
1011

1112
// Inspect allows inspection of the given schema.
@@ -15,7 +16,7 @@ func (s *Schema) Inspect() *introspection.Schema {
1516

1617
// ToJSON encodes the schema in a JSON format used by tools like Relay.
1718
func (s *Schema) ToJSON() ([]byte, error) {
18-
result := s.exec(context.Background(), introspectionQuery, "", nil, &resolvable.Schema{
19+
result := s.exec(context.Background(), introspectionQuery, "", nil, map[string]types.DirectiveVisitor{}, &resolvable.Schema{
1920
Meta: s.res.Meta,
2021
Query: &resolvable.Object{},
2122
Schema: *s.schema,

relay/relay.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"strings"
1010

1111
graphql "github.com/graph-gophers/graphql-go"
12+
"github.com/graph-gophers/graphql-go/types"
1213
)
1314

1415
func MarshalID(kind string, spec interface{}) graphql.ID {
@@ -58,7 +59,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
5859
return
5960
}
6061

61-
response := h.Schema.Exec(r.Context(), params.Query, params.OperationName, params.Variables)
62+
response := h.Schema.Exec(r.Context(), params.Query, params.OperationName, params.Variables, map[string]types.DirectiveVisitor{})
6263
responseJSON, err := json.Marshal(response)
6364
if err != nil {
6465
http.Error(w, err.Error(), http.StatusInternalServerError)

types/directive.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package types
22

3-
import "github.com/graph-gophers/graphql-go/errors"
3+
import (
4+
"context"
5+
6+
"github.com/graph-gophers/graphql-go/errors"
7+
)
48

59
// Directive is a representation of the GraphQL Directive.
610
//
@@ -23,6 +27,11 @@ type DirectiveDefinition struct {
2327

2428
type DirectiveList []*Directive
2529

30+
type DirectiveVisitor interface {
31+
Before(ctx context.Context, directive *Directive, input interface{}) error
32+
After(ctx context.Context, directive *Directive, output interface{}) (interface{}, error)
33+
}
34+
2635
// Returns the Directive in the DirectiveList by name or nil if not found.
2736
func (l DirectiveList) Get(name string) *Directive {
2837
for _, d := range l {

0 commit comments

Comments
 (0)