Skip to content

Commit

Permalink
Use generics for subscriptions response
Browse files Browse the repository at this point in the history
  • Loading branch information
HaraldNordgren committed Sep 29, 2024
1 parent 9dd475e commit bc7cce7
Show file tree
Hide file tree
Showing 10 changed files with 166 additions and 151 deletions.
6 changes: 2 additions & 4 deletions docs/subscriptions.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ This document describes how to use genqlient to make GraphQL subscriptions. It a

You will need to use a different client calling `graphql.NewClientUsingWebSocket`, passing as a parameter your own websocket client.

A WebSocket client matching your subcription is generated based on your schema. E.g. for a 'count' subscription, a client named `countClientUsingWebSocket` is generated. This name will be used in the examples below.

Here is how to configure your webSocket client to match the interfaces:

### Example using `github.com/gorilla/websocket`
Expand Down Expand Up @@ -78,7 +76,7 @@ Once your websocket client matches the interfaces, you can get your `graphql.Web
a loop for incoming messages and errors:

```go
graphqlClient := countClientUsingWebSocket(
graphqlClient := graphql.NewClientUsingWebSocket(
"ws://localhost:8080/query",
&MyDialer{Dialer: dialer},
headers,
Expand Down Expand Up @@ -118,7 +116,7 @@ a loop for incoming messages and errors:
}
```

To change the websocket protocol from its default value `graphql-transport-ws`, add the following header before calling `countClientUsingWebSocket()`:
To change the websocket protocol from its default value `graphql-transport-ws`, add the following header before calling `graphql.NewClientUsingWebSocket()`:
```go
headers.Add("Sec-WebSocket-Protocol", "graphql-ws")
```
31 changes: 25 additions & 6 deletions generate/operation.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ func {{.Name}}(
ctx_ {{ref .Config.ContextType}},
{{end}}
{{- if not .Config.ClientGetter -}}
client_ {{if eq .Type "subscription"}}{{ref "github.com/Khan/genqlient/graphql.WebSocketClient"}}[{{.ResponseName}}]{{else}}{{ref "github.com/Khan/genqlient/graphql.Client"}}{{end}},
client_ {{if eq .Type "subscription"}}{{ref "github.com/Khan/genqlient/graphql.WebSocketClient"}}{{else}}{{ref "github.com/Khan/genqlient/graphql.Client"}}{{end}},
{{end}}
{{- if .Input -}}
{{- range .Input.Fields -}}
{{/* the GraphQL name here is the user-specified variable-name */ -}}
{{.GraphQLName}} {{.GoType.Reference}},
{{end -}}
{{end -}}
) ({{if eq .Type "subscription"}}dataChan_ chan graphql.WsResponse[{{.ResponseName}}], subscriptionID_ string,{{else}}data_ *{{.ResponseName}}, {{if .Config.Extensions -}}ext_ map[string]interface{},{{end}}{{end}} err_ error) {
) ({{if eq .Type "subscription"}}dataChan_ chan graphql.BaseResponse[*{{.ResponseName}}], subscriptionID_ string,{{else}}data_ *{{.ResponseName}}, {{if .Config.Extensions -}}ext_ map[string]interface{},{{end}}{{end}} err_ error) {
req_ := &graphql.Request{
OpName: "{{.Name}}",
Query: {{.Name}}_Operation,
Expand All @@ -36,8 +36,8 @@ func {{.Name}}(
}
{{end}}
{{if eq .Type "subscription"}}
dataChan_ = make(chan graphql.WsResponse[{{.ResponseName}}])
subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, graphql.ForwardData[{{.ResponseName}}])
dataChan_ = make(chan graphql.BaseResponse[*{{.ResponseName}}])
subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, {{.Name}}ForwardData)
{{else}}
data_ = &{{.ResponseName}}{}
resp_ := &graphql.Response{Data: data_}
Expand All @@ -53,7 +53,26 @@ func {{.Name}}(
}

{{if eq .Type "subscription"}}
func {{.Name}}ClientUsingWebSocket(endpoint string, wsDialer graphql.Dialer, headers http.Header) graphql.WebSocketClient[{{.ResponseName}}] {
return graphql.NewClientUsingWebSocket[{{.ResponseName}}](endpoint, wsDialer, headers)
func {{.Name}}ForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error {
var gqlResp graphql.Response
var wsResp graphql.BaseResponse[*{{.ResponseName}}]
err := json.Unmarshal(jsonRawMsg, &gqlResp)
if err != nil {
return err
}
if len(gqlResp.Errors) == 0 {
err = json.Unmarshal(jsonRawMsg, &wsResp)
if err != nil {
return err
}
} else {
wsResp.Errors = gqlResp.Errors
}
dataChan_, ok := interfaceChan.(chan graphql.BaseResponse[*{{.ResponseName}}])
if !ok {
return errors.New("failed to cast interface into 'chan graphql.BaseResponse[*{{.ResponseName}}]'")
}
dataChan_ <- wsResp
return nil
}
{{end}}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 14 additions & 36 deletions graphql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ type Client interface {
) error
}

type WebSocketClient[T any] interface {
type WebSocketClient interface {
// Start must open a webSocket connection and subscribe to an endpoint
// of the client's GraphQL API.
//
Expand All @@ -59,16 +59,17 @@ type WebSocketClient[T any] interface {
// req contains the data to be sent to the GraphQL server. Will be marshalled
// into JSON bytes.
//
// dataChan is a channel used to send the data that arrives via the
// interfaceChan is a channel used to send the data that arrives via the
// webSocket connection (it is the channel that is passed to `forwardDataFunc`).
//
// forwardDataFunc is the function that will handle the subscription's response.
// forwardDataFunc is the function that will cast the received interface into
// the valid type for the subscription's response.
//
// Returns a subscriptionID if successful, an error otherwise.
Subscribe(
req *Request,
dataChan chan WsResponse[T],
forwardDataFunc ForwardDataFunction[T],
interfaceChan interface{},
forwardDataFunc ForwardDataFunction,
) (string, error)

// Unsubscribe must unsubscribe from an endpoint of the client's GraphQL API.
Expand All @@ -77,32 +78,7 @@ type WebSocketClient[T any] interface {

// ForwardDataFunction is a part of the WebSocketClient interface, see
// [WebSocketClient.Subscribe] for details.
type ForwardDataFunction[T any] func(dataChan chan WsResponse[T], jsonRawMsg json.RawMessage) error

func ForwardData[T any](dataChan_ chan WsResponse[T], jsonRawMsg json.RawMessage) error {
var gqlResp Response
var wsResp WsResponse[T]
err := json.Unmarshal(jsonRawMsg, &gqlResp)
if err != nil {
return err
}
if len(gqlResp.Errors) == 0 {
err = json.Unmarshal(jsonRawMsg, &wsResp)
if err != nil {
return err
}
} else {
wsResp.Errors = gqlResp.Errors
}
dataChan_ <- wsResp
return nil
}

type WsResponse[T any] struct {
Data *T `json:"data"`
Extensions map[string]interface{} `json:"extensions,omitempty"`
Errors error `json:"errors"`
}
type ForwardDataFunction func(interfaceChan interface{}, jsonRawMsg json.RawMessage) error

type client struct {
httpClient Doer
Expand Down Expand Up @@ -155,19 +131,19 @@ func NewClientUsingGet(endpoint string, httpClient Doer) Client {
//
// The client does not support queries nor mutations, and will return an error
// if passed a request that attempts one.
func NewClientUsingWebSocket[T any](endpoint string, wsDialer Dialer, headers http.Header) WebSocketClient[T] {
func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Header) WebSocketClient {
if headers == nil {
headers = http.Header{}
}
if headers.Get("Sec-WebSocket-Protocol") == "" {
headers.Add("Sec-WebSocket-Protocol", "graphql-transport-ws")
}
return &webSocketClient[T]{
return &webSocketClient{
Dialer: wsDialer,
Header: headers,
errChan: make(chan error),
endpoint: endpoint,
subscriptions: subscriptionMap[T]{map_: make(map[string]subscription[T])},
subscriptions: subscriptionMap{map_: make(map[string]subscription)},
}
}

Expand Down Expand Up @@ -230,8 +206,10 @@ type Request struct {
// It may additionally contain a key named "extensions", that
// might hold GraphQL protocol extensions. Extensions and Errors
// are optional, depending on the values returned by the server.
type Response struct {
Data interface{} `json:"data"`
type Response BaseResponse[any]

type BaseResponse[T any] struct {
Data T `json:"data"`
Extensions map[string]interface{} `json:"extensions,omitempty"`
Errors gqlerror.List `json:"errors,omitempty"`
}
Expand Down
27 changes: 14 additions & 13 deletions graphql/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,42 @@ package graphql

import (
"fmt"
"reflect"
"sync"
)

// map of subscription ID to subscription
type subscriptionMap[T any] struct {
map_ map[string]subscription[T]
type subscriptionMap struct {
map_ map[string]subscription
sync.RWMutex
}

type subscription[T any] struct {
dataChan chan WsResponse[T]
forwardDataFunc ForwardDataFunction[T]
type subscription struct {
interfaceChan interface{}
forwardDataFunc ForwardDataFunction
id string
hasBeenUnsubscribed bool
}

func (s *subscriptionMap[T]) Create(subscriptionID string, dataChan chan WsResponse[T], forwardDataFunc ForwardDataFunction[T]) {
func (s *subscriptionMap) Create(subscriptionID string, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) {
s.Lock()
defer s.Unlock()
s.map_[subscriptionID] = subscription[T]{
s.map_[subscriptionID] = subscription{
id: subscriptionID,
dataChan: dataChan,
interfaceChan: interfaceChan,
forwardDataFunc: forwardDataFunc,
hasBeenUnsubscribed: false,
}
}

func (s *subscriptionMap[T]) Read(subscriptionID string) (sub subscription[T], success bool) {
func (s *subscriptionMap) Read(subscriptionID string) (sub subscription, success bool) {
s.RLock()
defer s.RUnlock()
sub, success = s.map_[subscriptionID]
return sub, success
}

func (s *subscriptionMap[_]) Unsubscribe(subscriptionID string) error {
func (s *subscriptionMap) Unsubscribe(subscriptionID string) error {
s.Lock()
defer s.Unlock()
unsub, success := s.map_[subscriptionID]
Expand All @@ -45,11 +46,11 @@ func (s *subscriptionMap[_]) Unsubscribe(subscriptionID string) error {
}
unsub.hasBeenUnsubscribed = true
s.map_[subscriptionID] = unsub
close(s.map_[subscriptionID].dataChan)
reflect.ValueOf(s.map_[subscriptionID].interfaceChan).Close()
return nil
}

func (s *subscriptionMap[_]) GetAllIDs() (subscriptionIDs []string) {
func (s *subscriptionMap) GetAllIDs() (subscriptionIDs []string) {
s.RLock()
defer s.RUnlock()
for subID := range s.map_ {
Expand All @@ -58,7 +59,7 @@ func (s *subscriptionMap[_]) GetAllIDs() (subscriptionIDs []string) {
return subscriptionIDs
}

func (s *subscriptionMap[_]) Delete(subscriptionID string) {
func (s *subscriptionMap) Delete(subscriptionID string) {
s.Lock()
defer s.Unlock()
delete(s.map_, subscriptionID)
Expand Down
Loading

0 comments on commit bc7cce7

Please sign in to comment.