Skip to content

Commit

Permalink
Handle websocket data channel closing
Browse files Browse the repository at this point in the history
  • Loading branch information
HaraldNordgren committed Sep 28, 2024
1 parent 4466fc1 commit 94da0c7
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 30 deletions.
6 changes: 6 additions & 0 deletions graphql/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"reflect"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -135,6 +136,11 @@ func (w *webSocketClient) forwardWebSocketData(message []byte) error {
if sub.hasBeenUnsubscribed {
return nil
}
if wsMsg.Type == webSocketTypeComplete {
reflect.ValueOf(sub.interfaceChan).Close()
return nil
}

return sub.forwardDataFunc(sub.interfaceChan, wsMsg.Payload)
}

Expand Down
102 changes: 74 additions & 28 deletions internal/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,44 +57,90 @@ func TestMutation(t *testing.T) {
require.Errorf(t, err, "client does not support mutations")
}

type subscriptionResult struct {
clientUnsubscribed bool
serverChannelClosed bool
}

func TestSubscription(t *testing.T) {
_ = `# @genqlient
subscription count { count }`

ctx := context.Background()
server := server.RunServer()
defer server.Close()
wsClient := newRoundtripWebScoketClient(t, server.URL)

errChan, err := wsClient.Start(ctx)
require.NoError(t, err)
cases := []struct {
name string
unsubThreshold time.Duration
expected subscriptionResult
}{
{
name: "server_closed_channel",
unsubThreshold: 5 * time.Second,
expected: subscriptionResult{
clientUnsubscribed: false,
serverChannelClosed: true,
},
},
{
name: "client_unsubscribed",
unsubThreshold: 300 * time.Millisecond,
expected: subscriptionResult{
clientUnsubscribed: true,
serverChannelClosed: false,
},
},
}

dataChan, subscriptionID, err := count(ctx, wsClient)
require.NoError(t, err)
defer wsClient.Close()
counter := 0
start := time.Now()
for loop := true; loop; {
select {
case resp, more := <-dataChan:
if !more {
loop = false
break
}
require.NotNil(t, resp.Data)
assert.Equal(t, counter, resp.Data.Count)
require.Nil(t, resp.Errors)
if time.Since(start) > time.Second*5 {
err = wsClient.Unsubscribe(subscriptionID)
require.NoError(t, err)
loop = false
}
counter++
case err := <-errChan:
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
wsClient := newRoundtripWebSocketClient(t, server.URL)
errChan, err := wsClient.Start(ctx)
require.NoError(t, err)
case <-time.After(time.Second * 10):
require.NoError(t, fmt.Errorf("subscription timed out"))
}

dataChan, subscriptionID, err := count(ctx, wsClient)
require.NoError(t, err)
defer wsClient.Close()

var (
counter = 0
start = time.Now()
result = subscriptionResult{}
)

for loop := true; loop; {
select {
case resp, more := <-dataChan:
if !more {
result.serverChannelClosed = true
loop = false
break
}

require.NotNil(t, resp.Data)
assert.Equal(t, counter, resp.Data.Count)
require.Nil(t, resp.Errors)

if time.Since(start) > tc.unsubThreshold {
err := wsClient.Unsubscribe(subscriptionID)
require.NoError(t, err)
result.clientUnsubscribed = true
loop = false
}

counter++

case err := <-errChan:
require.NoError(t, err)

case <-time.After(10 * time.Second):
require.NoError(t, fmt.Errorf("subscription timed out"))
}
}

assert.Equal(t, tc.expected, result)
})
}
}

Expand Down
2 changes: 1 addition & 1 deletion internal/integration/roundtrip.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func (md *MyDialer) DialContext(ctx context.Context, urlStr string, requestHeade
return graphql.WSConn(conn), err
}

func newRoundtripWebScoketClient(t *testing.T, endpoint string) graphql.WebSocketClient {
func newRoundtripWebSocketClient(t *testing.T, endpoint string) graphql.WebSocketClient {
dialer := websocket.DefaultDialer
if !strings.HasPrefix(endpoint, "ws") {
_, address, _ := strings.Cut(endpoint, "://")
Expand Down
6 changes: 5 additions & 1 deletion internal/integration/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,15 @@ func (m mutationResolver) CreateUser(ctx context.Context, input NewUser) (*User,
func (s *subscriptionResolver) Count(ctx context.Context) (<-chan int, error) {
respChan := make(chan int, 1)
go func(respChan chan int) {
defer close(respChan)
counter := 0
for {
if counter == 10 {
return
}
respChan <- counter
counter++
time.Sleep(time.Second)
time.Sleep(100 * time.Millisecond)
}
}(respChan)
return respChan, nil
Expand Down

0 comments on commit 94da0c7

Please sign in to comment.