From bc7cce75dd19e6dadbbd831a6e400f58a7f6e0bc Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Sun, 29 Sep 2024 08:28:32 +0200 Subject: [PATCH] Use generics for subscriptions response --- docs/subscriptions.md | 6 +- generate/operation.go.tmpl | 31 ++++++++-- ...tion.graphql-SimpleSubscription.graphql.go | 34 ++++++++--- graphql/client.go | 50 +++++---------- graphql/subscription.go | 27 ++++---- graphql/websocket.go | 32 +++++----- internal/integration/generated.go | 33 +++++++--- internal/integration/integration_test.go | 2 +- internal/integration/roundtrip.go | 41 +++++++++++++ internal/integration/websocket.go | 61 ------------------- 10 files changed, 166 insertions(+), 151 deletions(-) delete mode 100644 internal/integration/websocket.go diff --git a/docs/subscriptions.md b/docs/subscriptions.md index 95f36ebc..146e9408 100644 --- a/docs/subscriptions.md +++ b/docs/subscriptions.md @@ -6,8 +6,6 @@ This document describes how to use genqlient to make GraphQL subscriptions. It a You will need to use a different client calling `graphql.NewClientUsingWebSocket`, passing as a parameter your own websocket client. -A WebSocket client matching your subcription is generated based on your schema. E.g. for a 'count' subscription, a client named `countClientUsingWebSocket` is generated. This name will be used in the examples below. - Here is how to configure your webSocket client to match the interfaces: ### Example using `github.com/gorilla/websocket` @@ -78,7 +76,7 @@ Once your websocket client matches the interfaces, you can get your `graphql.Web a loop for incoming messages and errors: ```go - graphqlClient := countClientUsingWebSocket( + graphqlClient := graphql.NewClientUsingWebSocket( "ws://localhost:8080/query", &MyDialer{Dialer: dialer}, headers, @@ -118,7 +116,7 @@ a loop for incoming messages and errors: } ``` -To change the websocket protocol from its default value `graphql-transport-ws`, add the following header before calling `countClientUsingWebSocket()`: +To change the websocket protocol from its default value `graphql-transport-ws`, add the following header before calling `graphql.NewClientUsingWebSocket()`: ```go headers.Add("Sec-WebSocket-Protocol", "graphql-ws") ``` diff --git a/generate/operation.go.tmpl b/generate/operation.go.tmpl index 728604c8..5f70c8b4 100644 --- a/generate/operation.go.tmpl +++ b/generate/operation.go.tmpl @@ -7,7 +7,7 @@ func {{.Name}}( ctx_ {{ref .Config.ContextType}}, {{end}} {{- if not .Config.ClientGetter -}} - client_ {{if eq .Type "subscription"}}{{ref "github.com/Khan/genqlient/graphql.WebSocketClient"}}[{{.ResponseName}}]{{else}}{{ref "github.com/Khan/genqlient/graphql.Client"}}{{end}}, + client_ {{if eq .Type "subscription"}}{{ref "github.com/Khan/genqlient/graphql.WebSocketClient"}}{{else}}{{ref "github.com/Khan/genqlient/graphql.Client"}}{{end}}, {{end}} {{- if .Input -}} {{- range .Input.Fields -}} @@ -15,7 +15,7 @@ func {{.Name}}( {{.GraphQLName}} {{.GoType.Reference}}, {{end -}} {{end -}} -) ({{if eq .Type "subscription"}}dataChan_ chan graphql.WsResponse[{{.ResponseName}}], subscriptionID_ string,{{else}}data_ *{{.ResponseName}}, {{if .Config.Extensions -}}ext_ map[string]interface{},{{end}}{{end}} err_ error) { +) ({{if eq .Type "subscription"}}dataChan_ chan graphql.BaseResponse[*{{.ResponseName}}], subscriptionID_ string,{{else}}data_ *{{.ResponseName}}, {{if .Config.Extensions -}}ext_ map[string]interface{},{{end}}{{end}} err_ error) { req_ := &graphql.Request{ OpName: "{{.Name}}", Query: {{.Name}}_Operation, @@ -36,8 +36,8 @@ func {{.Name}}( } {{end}} {{if eq .Type "subscription"}} - dataChan_ = make(chan graphql.WsResponse[{{.ResponseName}}]) - subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, graphql.ForwardData[{{.ResponseName}}]) + dataChan_ = make(chan graphql.BaseResponse[*{{.ResponseName}}]) + subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, {{.Name}}ForwardData) {{else}} data_ = &{{.ResponseName}}{} resp_ := &graphql.Response{Data: data_} @@ -53,7 +53,26 @@ func {{.Name}}( } {{if eq .Type "subscription"}} -func {{.Name}}ClientUsingWebSocket(endpoint string, wsDialer graphql.Dialer, headers http.Header) graphql.WebSocketClient[{{.ResponseName}}] { - return graphql.NewClientUsingWebSocket[{{.ResponseName}}](endpoint, wsDialer, headers) +func {{.Name}}ForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error { + var gqlResp graphql.Response + var wsResp graphql.BaseResponse[*{{.ResponseName}}] + err := json.Unmarshal(jsonRawMsg, &gqlResp) + if err != nil { + return err + } + if len(gqlResp.Errors) == 0 { + err = json.Unmarshal(jsonRawMsg, &wsResp) + if err != nil { + return err + } + } else { + wsResp.Errors = gqlResp.Errors + } + dataChan_, ok := interfaceChan.(chan graphql.BaseResponse[*{{.ResponseName}}]) + if !ok { + return errors.New("failed to cast interface into 'chan graphql.BaseResponse[*{{.ResponseName}}]'") + } + dataChan_ <- wsResp + return nil } {{end}} diff --git a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go index 3053e0cc..fae2767d 100644 --- a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go +++ b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go @@ -3,7 +3,8 @@ package test import ( - "net/http" + "encoding/json" + "errors" "github.com/Khan/genqlient/graphql" ) @@ -25,20 +26,39 @@ subscription SimpleSubscription { // To unsubscribe, use [graphql.WebSocketClient.Unsubscribe] func SimpleSubscription( - client_ graphql.WebSocketClient[SimpleSubscriptionResponse], -) (dataChan_ chan graphql.WsResponse[SimpleSubscriptionResponse], subscriptionID_ string, err_ error) { + client_ graphql.WebSocketClient, +) (dataChan_ chan graphql.BaseResponse[*SimpleSubscriptionResponse], subscriptionID_ string, err_ error) { req_ := &graphql.Request{ OpName: "SimpleSubscription", Query: SimpleSubscription_Operation, } - dataChan_ = make(chan graphql.WsResponse[SimpleSubscriptionResponse]) - subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, graphql.ForwardData[SimpleSubscriptionResponse]) + dataChan_ = make(chan graphql.BaseResponse[*SimpleSubscriptionResponse]) + subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, SimpleSubscriptionForwardData) return dataChan_, subscriptionID_, err_ } -func SimpleSubscriptionClientUsingWebSocket(endpoint string, wsDialer graphql.Dialer, headers http.Header) graphql.WebSocketClient[SimpleSubscriptionResponse] { - return graphql.NewClientUsingWebSocket[SimpleSubscriptionResponse](endpoint, wsDialer, headers) +func SimpleSubscriptionForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error { + var gqlResp graphql.Response + var wsResp graphql.BaseResponse[*SimpleSubscriptionResponse] + err := json.Unmarshal(jsonRawMsg, &gqlResp) + if err != nil { + return err + } + if len(gqlResp.Errors) == 0 { + err = json.Unmarshal(jsonRawMsg, &wsResp) + if err != nil { + return err + } + } else { + wsResp.Errors = gqlResp.Errors + } + dataChan_, ok := interfaceChan.(chan graphql.BaseResponse[*SimpleSubscriptionResponse]) + if !ok { + return errors.New("failed to cast interface into 'chan graphql.BaseResponse[*SimpleSubscriptionResponse]'") + } + dataChan_ <- wsResp + return nil } diff --git a/graphql/client.go b/graphql/client.go index cf2687a0..12389616 100644 --- a/graphql/client.go +++ b/graphql/client.go @@ -40,7 +40,7 @@ type Client interface { ) error } -type WebSocketClient[T any] interface { +type WebSocketClient interface { // Start must open a webSocket connection and subscribe to an endpoint // of the client's GraphQL API. // @@ -59,16 +59,17 @@ type WebSocketClient[T any] interface { // req contains the data to be sent to the GraphQL server. Will be marshalled // into JSON bytes. // - // dataChan is a channel used to send the data that arrives via the + // interfaceChan is a channel used to send the data that arrives via the // webSocket connection (it is the channel that is passed to `forwardDataFunc`). // - // forwardDataFunc is the function that will handle the subscription's response. + // forwardDataFunc is the function that will cast the received interface into + // the valid type for the subscription's response. // // Returns a subscriptionID if successful, an error otherwise. Subscribe( req *Request, - dataChan chan WsResponse[T], - forwardDataFunc ForwardDataFunction[T], + interfaceChan interface{}, + forwardDataFunc ForwardDataFunction, ) (string, error) // Unsubscribe must unsubscribe from an endpoint of the client's GraphQL API. @@ -77,32 +78,7 @@ type WebSocketClient[T any] interface { // ForwardDataFunction is a part of the WebSocketClient interface, see // [WebSocketClient.Subscribe] for details. -type ForwardDataFunction[T any] func(dataChan chan WsResponse[T], jsonRawMsg json.RawMessage) error - -func ForwardData[T any](dataChan_ chan WsResponse[T], jsonRawMsg json.RawMessage) error { - var gqlResp Response - var wsResp WsResponse[T] - err := json.Unmarshal(jsonRawMsg, &gqlResp) - if err != nil { - return err - } - if len(gqlResp.Errors) == 0 { - err = json.Unmarshal(jsonRawMsg, &wsResp) - if err != nil { - return err - } - } else { - wsResp.Errors = gqlResp.Errors - } - dataChan_ <- wsResp - return nil -} - -type WsResponse[T any] struct { - Data *T `json:"data"` - Extensions map[string]interface{} `json:"extensions,omitempty"` - Errors error `json:"errors"` -} +type ForwardDataFunction func(interfaceChan interface{}, jsonRawMsg json.RawMessage) error type client struct { httpClient Doer @@ -155,19 +131,19 @@ func NewClientUsingGet(endpoint string, httpClient Doer) Client { // // The client does not support queries nor mutations, and will return an error // if passed a request that attempts one. -func NewClientUsingWebSocket[T any](endpoint string, wsDialer Dialer, headers http.Header) WebSocketClient[T] { +func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Header) WebSocketClient { if headers == nil { headers = http.Header{} } if headers.Get("Sec-WebSocket-Protocol") == "" { headers.Add("Sec-WebSocket-Protocol", "graphql-transport-ws") } - return &webSocketClient[T]{ + return &webSocketClient{ Dialer: wsDialer, Header: headers, errChan: make(chan error), endpoint: endpoint, - subscriptions: subscriptionMap[T]{map_: make(map[string]subscription[T])}, + subscriptions: subscriptionMap{map_: make(map[string]subscription)}, } } @@ -230,8 +206,10 @@ type Request struct { // It may additionally contain a key named "extensions", that // might hold GraphQL protocol extensions. Extensions and Errors // are optional, depending on the values returned by the server. -type Response struct { - Data interface{} `json:"data"` +type Response BaseResponse[any] + +type BaseResponse[T any] struct { + Data T `json:"data"` Extensions map[string]interface{} `json:"extensions,omitempty"` Errors gqlerror.List `json:"errors,omitempty"` } diff --git a/graphql/subscription.go b/graphql/subscription.go index 5df107a3..f86f75ee 100644 --- a/graphql/subscription.go +++ b/graphql/subscription.go @@ -2,41 +2,42 @@ package graphql import ( "fmt" + "reflect" "sync" ) // map of subscription ID to subscription -type subscriptionMap[T any] struct { - map_ map[string]subscription[T] +type subscriptionMap struct { + map_ map[string]subscription sync.RWMutex } -type subscription[T any] struct { - dataChan chan WsResponse[T] - forwardDataFunc ForwardDataFunction[T] +type subscription struct { + interfaceChan interface{} + forwardDataFunc ForwardDataFunction id string hasBeenUnsubscribed bool } -func (s *subscriptionMap[T]) Create(subscriptionID string, dataChan chan WsResponse[T], forwardDataFunc ForwardDataFunction[T]) { +func (s *subscriptionMap) Create(subscriptionID string, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) { s.Lock() defer s.Unlock() - s.map_[subscriptionID] = subscription[T]{ + s.map_[subscriptionID] = subscription{ id: subscriptionID, - dataChan: dataChan, + interfaceChan: interfaceChan, forwardDataFunc: forwardDataFunc, hasBeenUnsubscribed: false, } } -func (s *subscriptionMap[T]) Read(subscriptionID string) (sub subscription[T], success bool) { +func (s *subscriptionMap) Read(subscriptionID string) (sub subscription, success bool) { s.RLock() defer s.RUnlock() sub, success = s.map_[subscriptionID] return sub, success } -func (s *subscriptionMap[_]) Unsubscribe(subscriptionID string) error { +func (s *subscriptionMap) Unsubscribe(subscriptionID string) error { s.Lock() defer s.Unlock() unsub, success := s.map_[subscriptionID] @@ -45,11 +46,11 @@ func (s *subscriptionMap[_]) Unsubscribe(subscriptionID string) error { } unsub.hasBeenUnsubscribed = true s.map_[subscriptionID] = unsub - close(s.map_[subscriptionID].dataChan) + reflect.ValueOf(s.map_[subscriptionID].interfaceChan).Close() return nil } -func (s *subscriptionMap[_]) GetAllIDs() (subscriptionIDs []string) { +func (s *subscriptionMap) GetAllIDs() (subscriptionIDs []string) { s.RLock() defer s.RUnlock() for subID := range s.map_ { @@ -58,7 +59,7 @@ func (s *subscriptionMap[_]) GetAllIDs() (subscriptionIDs []string) { return subscriptionIDs } -func (s *subscriptionMap[_]) Delete(subscriptionID string) { +func (s *subscriptionMap) Delete(subscriptionID string) { s.Lock() defer s.Unlock() delete(s.map_, subscriptionID) diff --git a/graphql/websocket.go b/graphql/websocket.go index cd222374..bb9a98b6 100644 --- a/graphql/websocket.go +++ b/graphql/websocket.go @@ -42,13 +42,13 @@ const ( closeMessage = 8 ) -type webSocketClient[T any] struct { +type webSocketClient struct { Dialer Dialer Header http.Header endpoint string conn WSConn errChan chan error - subscriptions subscriptionMap[T] + subscriptions subscriptionMap isClosing bool sync.Mutex } @@ -65,14 +65,14 @@ type webSocketReceiveMessage struct { Payload json.RawMessage `json:"payload"` } -func (w *webSocketClient[_]) sendInit() error { +func (w *webSocketClient) sendInit() error { connInitMsg := webSocketSendMessage{ Type: webSocketTypeConnInit, } return w.sendStructAsJSON(connInitMsg) } -func (w *webSocketClient[_]) sendStructAsJSON(object any) error { +func (w *webSocketClient) sendStructAsJSON(object any) error { jsonBytes, err := json.Marshal(object) if err != nil { return err @@ -80,7 +80,7 @@ func (w *webSocketClient[_]) sendStructAsJSON(object any) error { return w.conn.WriteMessage(textMessage, jsonBytes) } -func (w *webSocketClient[_]) waitForConnAck() error { +func (w *webSocketClient) waitForConnAck() error { var connAckReceived bool var err error start := time.Now() @@ -96,7 +96,7 @@ func (w *webSocketClient[_]) waitForConnAck() error { return nil } -func (w *webSocketClient[_]) handleErr(err error) { +func (w *webSocketClient) handleErr(err error) { w.Lock() defer w.Unlock() if !w.isClosing { @@ -104,7 +104,7 @@ func (w *webSocketClient[_]) handleErr(err error) { } } -func (w *webSocketClient[_]) listenWebSocket() { +func (w *webSocketClient) listenWebSocket() { for { if w.isClosing { return @@ -122,7 +122,7 @@ func (w *webSocketClient[_]) listenWebSocket() { } } -func (w *webSocketClient[_]) forwardWebSocketData(message []byte) error { +func (w *webSocketClient) forwardWebSocketData(message []byte) error { var wsMsg webSocketReceiveMessage err := json.Unmarshal(message, &wsMsg) if err != nil { @@ -135,10 +135,10 @@ func (w *webSocketClient[_]) forwardWebSocketData(message []byte) error { if sub.hasBeenUnsubscribed { return nil } - return sub.forwardDataFunc(sub.dataChan, wsMsg.Payload) + return sub.forwardDataFunc(sub.interfaceChan, wsMsg.Payload) } -func (w *webSocketClient[_]) receiveWebSocketConnAck() (bool, error) { +func (w *webSocketClient) receiveWebSocketConnAck() (bool, error) { _, message, err := w.conn.ReadMessage() if err != nil { return false, err @@ -155,7 +155,7 @@ func checkConnectionAckReceived(message []byte) (bool, error) { return wsMessage.Type == webSocketTypeConnAck, nil } -func (w *webSocketClient[_]) Start(ctx context.Context) (errChan chan error, err error) { +func (w *webSocketClient) Start(ctx context.Context) (errChan chan error, err error) { w.conn, err = w.Dialer.DialContext(ctx, w.endpoint, w.Header) if err != nil { return nil, err @@ -174,7 +174,7 @@ func (w *webSocketClient[_]) Start(ctx context.Context) (errChan chan error, err return w.errChan, err } -func (w *webSocketClient[_]) Close() error { +func (w *webSocketClient) Close() error { if w.conn == nil { return nil } @@ -193,7 +193,7 @@ func (w *webSocketClient[_]) Close() error { return w.conn.Close() } -func (w *webSocketClient[T]) Subscribe(req *Request, dataChan chan WsResponse[T], forwardDataFunc ForwardDataFunction[T]) (string, error) { +func (w *webSocketClient) Subscribe(req *Request, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) (string, error) { if req.Query != "" { if strings.HasPrefix(strings.TrimSpace(req.Query), "query") { return "", fmt.Errorf("client does not support queries") @@ -204,7 +204,7 @@ func (w *webSocketClient[T]) Subscribe(req *Request, dataChan chan WsResponse[T] } subscriptionID := uuid.NewString() - w.subscriptions.Create(subscriptionID, dataChan, forwardDataFunc) + w.subscriptions.Create(subscriptionID, interfaceChan, forwardDataFunc) subscriptionMsg := webSocketSendMessage{ Type: webSocketTypeSubscribe, Payload: req, @@ -218,7 +218,7 @@ func (w *webSocketClient[T]) Subscribe(req *Request, dataChan chan WsResponse[T] return subscriptionID, nil } -func (w *webSocketClient[_]) Unsubscribe(subscriptionID string) error { +func (w *webSocketClient) Unsubscribe(subscriptionID string) error { completeMsg := webSocketSendMessage{ Type: webSocketTypeComplete, ID: subscriptionID, @@ -234,7 +234,7 @@ func (w *webSocketClient[_]) Unsubscribe(subscriptionID string) error { return nil } -func (w *webSocketClient[_]) UnsubscribeAll() error { +func (w *webSocketClient) UnsubscribeAll() error { subscriptionIDs := w.subscriptions.GetAllIDs() for _, subscriptionID := range subscriptionIDs { err := w.Unsubscribe(subscriptionID) diff --git a/internal/integration/generated.go b/internal/integration/generated.go index 65347b98..07289893 100644 --- a/internal/integration/generated.go +++ b/internal/integration/generated.go @@ -5,8 +5,8 @@ package integration import ( "context" "encoding/json" + "errors" "fmt" - "net/http" "time" "github.com/Khan/genqlient/graphql" @@ -3106,21 +3106,40 @@ subscription count { // To unsubscribe, use [graphql.WebSocketClient.Unsubscribe] func count( ctx_ context.Context, - client_ graphql.WebSocketClient[countResponse], -) (dataChan_ chan graphql.WsResponse[countResponse], subscriptionID_ string, err_ error) { + client_ graphql.WebSocketClient, +) (dataChan_ chan graphql.BaseResponse[*countResponse], subscriptionID_ string, err_ error) { req_ := &graphql.Request{ OpName: "count", Query: count_Operation, } - dataChan_ = make(chan graphql.WsResponse[countResponse]) - subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, graphql.ForwardData[countResponse]) + dataChan_ = make(chan graphql.BaseResponse[*countResponse]) + subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, countForwardData) return dataChan_, subscriptionID_, err_ } -func countClientUsingWebSocket(endpoint string, wsDialer graphql.Dialer, headers http.Header) graphql.WebSocketClient[countResponse] { - return graphql.NewClientUsingWebSocket[countResponse](endpoint, wsDialer, headers) +func countForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error { + var gqlResp graphql.Response + var wsResp graphql.BaseResponse[*countResponse] + err := json.Unmarshal(jsonRawMsg, &gqlResp) + if err != nil { + return err + } + if len(gqlResp.Errors) == 0 { + err = json.Unmarshal(jsonRawMsg, &wsResp) + if err != nil { + return err + } + } else { + wsResp.Errors = gqlResp.Errors + } + dataChan_, ok := interfaceChan.(chan graphql.BaseResponse[*countResponse]) + if !ok { + return errors.New("failed to cast interface into 'chan graphql.BaseResponse[*countResponse]'") + } + dataChan_ <- wsResp + return nil } // The mutation executed by createUser. diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index 628b357c..301ade2b 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -64,7 +64,7 @@ func TestSubscription(t *testing.T) { ctx := context.Background() server := server.RunServer() defer server.Close() - wsClient := newCountWebSocketClient(t, server.URL) + wsClient := newRoundtripWebScoketClient(t, server.URL) errChan, err := wsClient.Start(ctx) require.NoError(t, err) diff --git a/internal/integration/roundtrip.go b/internal/integration/roundtrip.go index 1be9996a..c9343d54 100644 --- a/internal/integration/roundtrip.go +++ b/internal/integration/roundtrip.go @@ -10,9 +10,11 @@ import ( "fmt" "io" "net/http" + "strings" "testing" "github.com/Khan/genqlient/graphql" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" ) @@ -46,6 +48,7 @@ func (t *lastResponseTransport) RoundTrip(req *http.Request) (*http.Response, er // for each request it processes. type roundtripClient struct { wrapped graphql.Client + wsWrapped graphql.WebSocketClient transport *lastResponseTransport t *testing.T } @@ -103,6 +106,22 @@ func (c *roundtripClient) MakeRequest(ctx context.Context, req *graphql.Request, return nil } +func (c *roundtripClient) Start(ctx context.Context) (errChan chan error, err error) { + return c.wsWrapped.Start(ctx) +} + +func (c *roundtripClient) Close() error { + return c.wsWrapped.Close() +} + +func (c *roundtripClient) Subscribe(req *graphql.Request, interfaceChan interface{}, forwardDataFunc graphql.ForwardDataFunction) (string, error) { + return c.wsWrapped.Subscribe(req, interfaceChan, forwardDataFunc) +} + +func (c *roundtripClient) Unsubscribe(subscriptionID string) error { + return c.wsWrapped.Unsubscribe(subscriptionID) +} + func newRoundtripClients(t *testing.T, endpoint string) []graphql.Client { return []graphql.Client{newRoundtripClient(t, endpoint), newRoundtripGetClient(t, endpoint)} } @@ -126,3 +145,25 @@ func newRoundtripGetClient(t *testing.T, endpoint string) graphql.Client { t: t, } } + +type MyDialer struct { + *websocket.Dialer +} + +func (md *MyDialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (graphql.WSConn, error) { + conn, resp, err := md.Dialer.DialContext(ctx, urlStr, requestHeader) + resp.Body.Close() + return graphql.WSConn(conn), err +} + +func newRoundtripWebScoketClient(t *testing.T, endpoint string) graphql.WebSocketClient { + dialer := websocket.DefaultDialer + if !strings.HasPrefix(endpoint, "ws") { + _, address, _ := strings.Cut(endpoint, "://") + endpoint = "ws://" + address + } + return &roundtripClient{ + wsWrapped: graphql.NewClientUsingWebSocket(endpoint, &MyDialer{Dialer: dialer}, nil), + t: t, + } +} diff --git a/internal/integration/websocket.go b/internal/integration/websocket.go deleted file mode 100644 index 1e8a50d6..00000000 --- a/internal/integration/websocket.go +++ /dev/null @@ -1,61 +0,0 @@ -package integration - -import ( - "context" - "net/http" - "strings" - "testing" - - "github.com/Khan/genqlient/graphql" - "github.com/gorilla/websocket" -) - -type webSocketClient[T any] struct { - wrapped graphql.WebSocketClient[T] - t *testing.T -} - -func (c *webSocketClient[_]) Start(ctx context.Context) (errChan chan error, err error) { - return c.wrapped.Start(ctx) -} - -func (c *webSocketClient[_]) Close() error { - return c.wrapped.Close() -} - -func (c *webSocketClient[T]) Subscribe(req *graphql.Request, dataChan chan graphql.WsResponse[T], forwardDataFunc graphql.ForwardDataFunction[T]) (string, error) { - return c.wrapped.Subscribe(req, dataChan, forwardDataFunc) -} - -func (c *webSocketClient[_]) Unsubscribe(subscriptionID string) error { - return c.wrapped.Unsubscribe(subscriptionID) -} - -type MyDialer struct { - *websocket.Dialer -} - -func (md *MyDialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (graphql.WSConn, error) { - conn, resp, err := md.Dialer.DialContext(ctx, urlStr, requestHeader) - resp.Body.Close() - return graphql.WSConn(conn), err -} - -func wsAdress(endpoint string) string { - if !strings.HasPrefix(endpoint, "ws") { - _, address, _ := strings.Cut(endpoint, "://") - endpoint = "ws://" + address - } - return endpoint -} - -func newCountWebSocketClient(t *testing.T, endpoint string) graphql.WebSocketClient[countResponse] { - return &webSocketClient[countResponse]{ - wrapped: countClientUsingWebSocket( - wsAdress(endpoint), - &MyDialer{Dialer: websocket.DefaultDialer}, - nil, - ), - t: t, - } -}