diff --git a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go index 4f1ebca6..0d4a7f87 100644 --- a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go +++ b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go @@ -48,6 +48,14 @@ type SimpleSubscriptionWsResponse struct { func SimpleSubscriptionForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error { var gqlResp graphql.Response var wsResp SimpleSubscriptionWsResponse + if len(jsonRawMsg) == 0 { + dataChan_, ok := interfaceChan.(chan SimpleSubscriptionWsResponse) + if !ok { + return errors.New("failed to cast interface into 'chan SimpleSubscriptionWsResponse'") + } + close(dataChan_) + return nil + } err := json.Unmarshal(jsonRawMsg, &gqlResp) if err != nil { return err diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index 301ade2b..48dba42c 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -57,45 +57,95 @@ func TestMutation(t *testing.T) { require.Errorf(t, err, "client does not support mutations") } -func TestSubscription(t *testing.T) { - _ = `# @genqlient - subscription count { count }` - - ctx := context.Background() - server := server.RunServer() - defer server.Close() - wsClient := newRoundtripWebScoketClient(t, server.URL) +type subscriptionCountLoopResult struct { + loop bool + clientTriggerUnsubscribe bool + serverClosedChannel bool +} +func subscriptionCountLoop( + ctx context.Context, + t *testing.T, + wsClient graphql.WebSocketClient, + clientUnsubscribeDuration time.Duration, +) subscriptionCountLoopResult { errChan, err := wsClient.Start(ctx) require.NoError(t, err) dataChan, subscriptionID, err := count(ctx, wsClient) require.NoError(t, err) defer wsClient.Close() - counter := 0 - start := time.Now() - for loop := true; loop; { + + var ( + counter = 0 + start = time.Now() + loop = true + clientTriggerUnsubscribe = false + serverClosedChannel = false + ) + + for loop { select { case resp, more := <-dataChan: if !more { loop = false + serverClosedChannel = true break } require.NotNil(t, resp.Data) assert.Equal(t, counter, resp.Data.Count) require.Nil(t, resp.Errors) - if time.Since(start) > time.Second*5 { - err = wsClient.Unsubscribe(subscriptionID) + if time.Since(start) > clientUnsubscribeDuration { + err := wsClient.Unsubscribe(subscriptionID) require.NoError(t, err) loop = false + clientTriggerUnsubscribe = true } counter++ case err := <-errChan: require.NoError(t, err) - case <-time.After(time.Second * 10): + case <-time.After(5 * time.Second): require.NoError(t, fmt.Errorf("subscription timed out")) } } + + return subscriptionCountLoopResult{ + loop: loop, + clientTriggerUnsubscribe: clientTriggerUnsubscribe, + serverClosedChannel: serverClosedChannel, + } +} + +func TestSubscriptionServerClose(t *testing.T) { + _ = `# @genqlient + subscription count { count }` + + ctx := context.Background() + server := server.RunServer() + defer server.Close() + wsClient := newRoundtripWebScoketClient(t, server.URL) + + result := subscriptionCountLoop(ctx, t, wsClient, 5*time.Second) + + assert.False(t, result.loop) + assert.False(t, result.clientTriggerUnsubscribe) + assert.True(t, result.serverClosedChannel) +} + +func TestSubscriptionClientClose(t *testing.T) { + _ = `# @genqlient + subscription count { count }` + + ctx := context.Background() + server := server.RunServer() + defer server.Close() + wsClient := newRoundtripWebScoketClient(t, server.URL) + + result := subscriptionCountLoop(ctx, t, wsClient, 300*time.Millisecond) + + assert.False(t, result.loop) + assert.True(t, result.clientTriggerUnsubscribe) + assert.False(t, result.serverClosedChannel) } func TestServerError(t *testing.T) { diff --git a/internal/integration/server/server.go b/internal/integration/server/server.go index eb7ba266..2c79909f 100644 --- a/internal/integration/server/server.go +++ b/internal/integration/server/server.go @@ -162,12 +162,12 @@ func (s *subscriptionResolver) Count(ctx context.Context) (<-chan int, error) { defer close(respChan) counter := 0 for { - if counter == 3 { + if counter == 10 { return } respChan <- counter counter++ - time.Sleep(10 * time.Millisecond) + time.Sleep(100 * time.Millisecond) } }(respChan) return respChan, nil