Skip to content

Commit

Permalink
removing ResultIterator in favor of a result channel
Browse files Browse the repository at this point in the history
  • Loading branch information
branden-blackline committed May 17, 2020
1 parent d6ffa49 commit f21d0c7
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 188 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
module github.com/graphql-go/graphql

go 1.13
169 changes: 13 additions & 156 deletions subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,148 +3,11 @@ package graphql
import (
"context"
"fmt"
"sync"

"github.com/graphql-go/graphql/gqlerrors"
"github.com/graphql-go/graphql/language/ast"
)

// Subscriber subscriber
type Subscriber struct {
message chan interface{}
done chan interface{}
}

// Message returns the subscriber message channel
func (c *Subscriber) Message() chan interface{} {
return c.message
}

// Done returns the subscriber done channel
func (c *Subscriber) Done() chan interface{} {
return c.done
}

// NewSubscriber creates a new subscriber
func NewSubscriber(message, done chan interface{}) *Subscriber {
return &Subscriber{
message: message,
done: done,
}
}

// ResultIteratorParams parameters passed to the result iterator handler
type ResultIteratorParams struct {
ResultCount int64 // number of results this iterator has processed
Result *Result // the current result
Done func() // Removes the current handler
Cancel func() // Cancels the iterator, same as iterator.Cancel()
}

// ResultIteratorFn a result iterator handler
type ResultIteratorFn func(p ResultIteratorParams)

// holds subscription handler data
type subscriptionHanlderConfig struct {
handler ResultIteratorFn
doneFunc func()
}

// ResultIterator handles processing results from a chan *Result
type ResultIterator struct {
currentHandlerID int64
count int64
mx sync.Mutex
ch chan *Result
iterDone chan interface{}
subDone chan interface{}
cancelled bool
handlers map[int64]*subscriptionHanlderConfig
}

func (c *ResultIterator) incrimentCount() int64 {
c.mx.Lock()
defer c.mx.Unlock()
c.count++
return c.count
}

// NewResultIterator creates a new iterator and starts handling message on the result channel
func NewResultIterator(subDone chan interface{}, ch chan *Result) *ResultIterator {
iterator := &ResultIterator{
currentHandlerID: 0,
count: 0,
iterDone: make(chan interface{}),
subDone: subDone,
ch: ch,
cancelled: false,
handlers: map[int64]*subscriptionHanlderConfig{},
}

go func() {
for {
select {
case <-iterator.iterDone:
subDone <- true
return
case res := <-iterator.ch:
if iterator.cancelled {
return
}

count := iterator.incrimentCount()
for _, h := range iterator.handlers {
h.handler(ResultIteratorParams{
ResultCount: int64(count),
Result: res,
Done: h.doneFunc,
Cancel: iterator.Cancel,
})
}
}
}
}()

return iterator
}

// adds a new handler
func (c *ResultIterator) addHandler(handler ResultIteratorFn) {
c.mx.Lock()
defer c.mx.Unlock()

handlerID := c.currentHandlerID + 1
c.currentHandlerID = handlerID
c.handlers[handlerID] = &subscriptionHanlderConfig{
handler: handler,
doneFunc: func() {
c.removeHandler(handlerID)
},
}
}

// removes a handler and cancels if no more handlers exist
func (c *ResultIterator) removeHandler(handlerID int64) {
c.mx.Lock()
defer c.mx.Unlock()

delete(c.handlers, handlerID)
if len(c.handlers) == 0 {
c.Cancel()
}
}

// ForEach adds a handler and handles each message as they come
func (c *ResultIterator) ForEach(handler ResultIteratorFn) {
c.addHandler(handler)
}

// Cancel cancels the iterator
func (c *ResultIterator) Cancel() {
c.cancelled = true
c.iterDone <- true
}

// SubscribeParams parameters for subscribing
type SubscribeParams struct {
Schema Schema
Expand All @@ -158,14 +21,8 @@ type SubscribeParams struct {
}

// Subscribe performs a subscribe operation
func Subscribe(p SubscribeParams) *ResultIterator {
func Subscribe(ctx context.Context, p SubscribeParams) chan *Result {
resultChannel := make(chan *Result)
doneChannel := make(chan interface{})
// Use background context if no context was provided
ctx := p.ContextValue
if ctx == nil {
ctx = context.Background()
}

var mapSourceToResponse = func(payload interface{}) *Result {
return Execute(ExecuteParams{
Expand All @@ -174,18 +31,18 @@ func Subscribe(p SubscribeParams) *ResultIterator {
AST: p.Document,
OperationName: p.OperationName,
Args: p.VariableValues,
Context: ctx,
Context: p.ContextValue,
})
}

go func() {
result := &Result{}
defer func() {
if err := recover(); err != nil {
fmt.Println("SUBSCRIPTION RECOVERER", err)
result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error)))
resultChannel <- result
}
resultChannel <- result
close(resultChannel)
}()

exeContext, err := buildExecutionContext(buildExecutionCtxParams{
Expand All @@ -195,7 +52,7 @@ func Subscribe(p SubscribeParams) *ResultIterator {
OperationName: p.OperationName,
Args: p.VariableValues,
Result: result,
Context: ctx,
Context: p.ContextValue,
})

if err != nil {
Expand Down Expand Up @@ -263,7 +120,7 @@ func Subscribe(p SubscribeParams) *ResultIterator {
Source: p.RootValue,
Args: args,
Info: info,
Context: ctx,
Context: p.ContextValue,
})
if err != nil {
result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error)))
Expand All @@ -279,14 +136,14 @@ func Subscribe(p SubscribeParams) *ResultIterator {
}

switch fieldResult.(type) {
case *Subscriber:
sub := fieldResult.(*Subscriber)
case chan interface{}:
sub := fieldResult.(chan interface{})
for {
select {
case <-doneChannel:
sub.done <- true
case <-ctx.Done():
return
case res := <-sub.message:

case res := <-sub:
resultChannel <- mapSourceToResponse(res)
}
}
Expand All @@ -296,6 +153,6 @@ func Subscribe(p SubscribeParams) *ResultIterator {
}
}()

// return a result iterator
return NewResultIterator(doneChannel, resultChannel)
// return a result channel
return resultChannel
}
73 changes: 41 additions & 32 deletions subscription_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ func TestSubscription(t *testing.T) {
return fmt.Sprintf("count=%v", p.Source), nil
},
Subscribe: func(p ResolveParams) (interface{}, error) {
sub := NewSubscriber(m, make(chan interface{}))
return sub, nil
return m, nil
},
},
"watch_should_fail": &Field{
Expand All @@ -74,60 +73,70 @@ func TestSubscription(t *testing.T) {
return
}

failIterator := Subscribe(SubscribeParams{
// test a subscribe that should fail due to no return value
fctx, fCancelFunc := context.WithCancel(context.Background())
fail := Subscribe(fctx, SubscribeParams{
Schema: schema,
Document: document2,
})

// test a subscribe that should fail due to no return value
failIterator.ForEach(func(p ResultIteratorParams) {
if !p.Result.HasErrors() {
t.Errorf("subscribe failed to catch nil result from subscribe")
p.Done()
go func() {
for {
result := <-fail
if !result.HasErrors() {
t.Errorf("subscribe failed to catch nil result from subscribe")
}
fCancelFunc()
return
}
p.Done()
return
})
}()

resultIterator := Subscribe(SubscribeParams{
// test subscription data
resultCount := 0
rctx, rCancelFunc := context.WithCancel(context.Background())
results := Subscribe(rctx, SubscribeParams{
Schema: schema,
Document: document1,
ContextValue: context.Background(),
})

resultIterator.ForEach(func(p ResultIteratorParams) {
if p.Result.HasErrors() {
t.Errorf("subscribe error(s): %v", p.Result.Errors)
p.Done()
return
}

if p.Result.Data != nil {
data := p.Result.Data.(map[string]interface{})["watch_count"]
expected := fmt.Sprintf("count=%d", p.ResultCount)
actual := fmt.Sprintf("%v", data)
if actual != expected {
t.Errorf("subscription result error: expected %q, actual %q", expected, actual)
p.Done()
go func() {
for {
result := <-results
if result.HasErrors() {
t.Errorf("subscribe error(s): %v", result.Errors)
rCancelFunc()
return
}

// test the done func by quitting after 3 iterations
// the publisher will publish up to 5
if p.ResultCount >= int64(maxPublish-2) {
p.Done()
return
if result.Data != nil {
resultCount++
data := result.Data.(map[string]interface{})["watch_count"]
expected := fmt.Sprintf("count=%d", resultCount)
actual := fmt.Sprintf("%v", data)
if actual != expected {
t.Errorf("subscription result error: expected %q, actual %q", expected, actual)
rCancelFunc()
return
}

// test the done func by quitting after 3 iterations
// the publisher will publish up to 5
if resultCount >= maxPublish-2 {
rCancelFunc()
return
}
}
}
})
}()

// start publishing
go func() {
for i := 1; i <= maxPublish; i++ {
time.Sleep(200 * time.Millisecond)
m <- i
}
close(m)
}()

// give time for the test to complete
Expand Down

0 comments on commit f21d0c7

Please sign in to comment.