From 6246e7ef589b2fa21f3b28ae1ccae08fda05f063 Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Fri, 27 Sep 2024 20:19:10 +0200 Subject: [PATCH] Simplify --- generate/operation.go.tmpl | 8 --- ...tion.graphql-SimpleSubscription.graphql.go | 8 --- graphql/websocket.go | 6 ++ internal/integration/generated.go | 8 --- internal/integration/integration_test.go | 72 ++++++++++--------- 5 files changed, 45 insertions(+), 57 deletions(-) diff --git a/generate/operation.go.tmpl b/generate/operation.go.tmpl index da0f0592..1b70890e 100644 --- a/generate/operation.go.tmpl +++ b/generate/operation.go.tmpl @@ -62,14 +62,6 @@ type {{.Name}}WsResponse struct { func {{.Name}}ForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error { var gqlResp graphql.Response var wsResp {{.Name}}WsResponse - if len(jsonRawMsg) == 0 { - dataChan_, ok := interfaceChan.(chan {{.Name}}WsResponse) - if !ok { - return errors.New("failed to cast interface into 'chan {{.Name}}WsResponse'") - } - close(dataChan_) - return nil - } err := json.Unmarshal(jsonRawMsg, &gqlResp) if err != nil { return err diff --git a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go index 0d4a7f87..4f1ebca6 100644 --- a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go +++ b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go @@ -48,14 +48,6 @@ type SimpleSubscriptionWsResponse struct { func SimpleSubscriptionForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error { var gqlResp graphql.Response var wsResp SimpleSubscriptionWsResponse - if len(jsonRawMsg) == 0 { - dataChan_, ok := interfaceChan.(chan SimpleSubscriptionWsResponse) - if !ok { - return errors.New("failed to cast interface into 'chan SimpleSubscriptionWsResponse'") - } - close(dataChan_) - return nil - } err := json.Unmarshal(jsonRawMsg, &gqlResp) if err != nil { return err diff --git a/graphql/websocket.go b/graphql/websocket.go index bb9a98b6..10b77daa 100644 --- a/graphql/websocket.go +++ b/graphql/websocket.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "net/http" + "reflect" "strings" "sync" "time" @@ -135,6 +136,11 @@ func (w *webSocketClient) forwardWebSocketData(message []byte) error { if sub.hasBeenUnsubscribed { return nil } + if wsMsg.Type == webSocketTypeComplete { + reflect.ValueOf(sub.interfaceChan).Close() + return nil + } + return sub.forwardDataFunc(sub.interfaceChan, wsMsg.Payload) } diff --git a/internal/integration/generated.go b/internal/integration/generated.go index 15062db1..8a8f70cb 100644 --- a/internal/integration/generated.go +++ b/internal/integration/generated.go @@ -3128,14 +3128,6 @@ type countWsResponse struct { func countForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error { var gqlResp graphql.Response var wsResp countWsResponse - if len(jsonRawMsg) == 0 { - dataChan_, ok := interfaceChan.(chan countWsResponse) - if !ok { - return errors.New("failed to cast interface into 'chan countWsResponse'") - } - close(dataChan_) - return nil - } err := json.Unmarshal(jsonRawMsg, &gqlResp) if err != nil { return err diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index b783684e..c30f05f9 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -57,18 +57,19 @@ func TestMutation(t *testing.T) { require.Errorf(t, err, "client does not support mutations") } -type subscriptionCountLoopResult struct { - loop bool - clientUnsubscribed bool - serverClosed bool +type subscriptionCountResult struct { + clientUnsubscribed bool + serverChannelClosed bool } func subscriptionCountLoop( ctx context.Context, t *testing.T, - wsClient graphql.WebSocketClient, - clientUnsubscribeDuration time.Duration, -) subscriptionCountLoopResult { + serverURL string, + unsubThreshold time.Duration, +) *subscriptionCountResult { + + wsClient := newRoundtripWebScoketClient(t, serverURL) errChan, err := wsClient.Start(ctx) require.NoError(t, err) @@ -77,75 +78,80 @@ func subscriptionCountLoop( defer wsClient.Close() var ( - counter = 0 - start = time.Now() - loop = true - clientUnsubscribed = false - serverClosed = false + counter = 0 + start = time.Now() + result = &subscriptionCountResult{} ) - for loop { + for loop := true; loop; { select { case resp, more := <-dataChan: if !more { + result.serverChannelClosed = true loop = false - serverClosed = true break } + require.NotNil(t, resp.Data) assert.Equal(t, counter, resp.Data.Count) require.Nil(t, resp.Errors) - if time.Since(start) > clientUnsubscribeDuration { + + if time.Since(start) > unsubThreshold { err := wsClient.Unsubscribe(subscriptionID) require.NoError(t, err) + result.clientUnsubscribed = true loop = false - clientUnsubscribed = true } + counter++ + case err := <-errChan: require.NoError(t, err) + return nil + case <-time.After(5 * time.Second): require.NoError(t, fmt.Errorf("subscription timed out")) + return nil } } - return subscriptionCountLoopResult{ - loop: loop, - clientUnsubscribed: clientUnsubscribed, - serverClosed: serverClosed, - } + return result } -func TestSubscriptionServerClose(t *testing.T) { +func TestSubscriptionServerClosedChannel(t *testing.T) { _ = `# @genqlient subscription count { count }` ctx := context.Background() server := server.RunServer() defer server.Close() - wsClient := newRoundtripWebScoketClient(t, server.URL) - result := subscriptionCountLoop(ctx, t, wsClient, 5*time.Second) + actual := subscriptionCountLoop(ctx, t, server.URL, 5*time.Second) + require.NotNil(t, actual) - assert.False(t, result.loop) - assert.False(t, result.clientUnsubscribed) - assert.True(t, result.serverClosed) + expected := &subscriptionCountResult{ + clientUnsubscribed: false, + serverChannelClosed: true, + } + assert.Equal(t, expected, actual) } -func TestSubscriptionClientClose(t *testing.T) { +func TestSubscriptionClientUnsubscribed(t *testing.T) { _ = `# @genqlient subscription count { count }` ctx := context.Background() server := server.RunServer() defer server.Close() - wsClient := newRoundtripWebScoketClient(t, server.URL) - result := subscriptionCountLoop(ctx, t, wsClient, 300*time.Millisecond) + actual := subscriptionCountLoop(ctx, t, server.URL, 300*time.Millisecond) + require.NotNil(t, actual) - assert.False(t, result.loop) - assert.True(t, result.clientUnsubscribed) - assert.False(t, result.serverClosed) + expected := &subscriptionCountResult{ + clientUnsubscribed: true, + serverChannelClosed: false, + } + assert.Equal(t, expected, actual) } func TestServerError(t *testing.T) {