Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use generics to avoid type casting subscription channels #355

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/subscriptions.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ Once your websocket client matches the interfaces, you can get your `graphql.Web
a loop for incoming messages and errors:

```go
graphqlClient := graphql.NewClientUsingWebSocket(
graphqlClient := graphql.NewClientUsingWebSocket[countResponse](
"ws://localhost:8080/query",
&MyDialer{Dialer: dialer},
headers,
Expand Down
41 changes: 5 additions & 36 deletions generate/operation.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@
const {{.Name}}_Operation = `{{$.Body}}`

{{.Doc}}
func {{.Name}}(
func {{.Name}}{{if eq .Type "subscription"}}[T any]{{end}}(
{{if ne .Config.ContextType "-" -}}
ctx_ {{ref .Config.ContextType}},
{{end}}
{{- if not .Config.ClientGetter -}}
client_ {{if eq .Type "subscription"}}{{ref "github.com/Khan/genqlient/graphql.WebSocketClient"}}{{else}}{{ref "github.com/Khan/genqlient/graphql.Client"}}{{end}},
client_ {{if eq .Type "subscription"}}{{ref "github.com/Khan/genqlient/graphql.WebSocketClient"}}[T]{{else}}{{ref "github.com/Khan/genqlient/graphql.Client"}}{{end}},
{{end}}
{{- if .Input -}}
{{- range .Input.Fields -}}
{{/* the GraphQL name here is the user-specified variable-name */ -}}
{{.GraphQLName}} {{.GoType.Reference}},
{{end -}}
{{end -}}
) ({{if eq .Type "subscription"}}dataChan_ chan {{.Name}}WsResponse, subscriptionID_ string,{{else}}data_ *{{.ResponseName}}, {{if .Config.Extensions -}}ext_ map[string]interface{},{{end}}{{end}} err_ error) {
) ({{if eq .Type "subscription"}}dataChan_ chan graphql.WsResponse[T], subscriptionID_ string,{{else}}data_ *{{.ResponseName}}, {{if .Config.Extensions -}}ext_ map[string]interface{},{{end}}{{end}} err_ error) {
req_ := &graphql.Request{
OpName: "{{.Name}}",
Query: {{.Name}}_Operation,
Expand All @@ -36,8 +36,8 @@ func {{.Name}}(
}
{{end}}
{{if eq .Type "subscription"}}
dataChan_ = make(chan {{.Name}}WsResponse)
subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, {{.Name}}ForwardData)
dataChan_ = make(chan graphql.WsResponse[T])
subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, graphql.ForwardData[T])
{{else}}
data_ = &{{.ResponseName}}{}
resp_ := &graphql.Response{Data: data_}
Expand All @@ -51,34 +51,3 @@ func {{.Name}}(

return {{if eq .Type "subscription"}}dataChan_, subscriptionID_,{{else}}data_, {{if .Config.Extensions -}}resp_.Extensions,{{end -}}{{end}} err_
}

{{if eq .Type "subscription"}}
type {{.Name}}WsResponse struct {
Data *{{.ResponseName}} `json:"data"`
Extensions map[string]interface{} `json:"extensions,omitempty"`
Errors error `json:"errors"`
}

func {{.Name}}ForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error {
var gqlResp graphql.Response
var wsResp {{.Name}}WsResponse
err := json.Unmarshal(jsonRawMsg, &gqlResp)
if err != nil {
return err
}
if len(gqlResp.Errors) == 0 {
err = json.Unmarshal(jsonRawMsg, &wsResp)
if err != nil {
return err
}
} else {
wsResp.Errors = gqlResp.Errors
}
dataChan_, ok := interfaceChan.(chan {{.Name}}WsResponse)
if !ok {
return errors.New("failed to cast interface into 'chan {{.Name}}WsResponse'")
}
dataChan_ <- wsResp
return nil
}
{{end}}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 33 additions & 8 deletions graphql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ type Client interface {
) error
}

type WebSocketClient interface {
type WebSocketClient[T any] interface {
// Start must open a webSocket connection and subscribe to an endpoint
// of the client's GraphQL API.
//
Expand All @@ -59,7 +59,7 @@ type WebSocketClient interface {
// req contains the data to be sent to the GraphQL server. Will be marshalled
// into JSON bytes.
//
// interfaceChan is a channel used to send the data that arrives via the
// dataChan is a channel used to send the data that arrives via the
// webSocket connection (it is the channel that is passed to `forwardDataFunc`).
//
// forwardDataFunc is the function that will cast the received interface into
Expand All @@ -68,8 +68,8 @@ type WebSocketClient interface {
// Returns a subscriptionID if successful, an error otherwise.
Subscribe(
req *Request,
interfaceChan interface{},
forwardDataFunc ForwardDataFunction,
dataChan chan WsResponse[T],
forwardDataFunc ForwardDataFunction[T],
) (string, error)

// Unsubscribe must unsubscribe from an endpoint of the client's GraphQL API.
Expand All @@ -78,7 +78,32 @@ type WebSocketClient interface {

// ForwardDataFunction is a part of the WebSocketClient interface, see
// [WebSocketClient.Subscribe] for details.
type ForwardDataFunction func(interfaceChan interface{}, jsonRawMsg json.RawMessage) error
type ForwardDataFunction[T any] func(dataChan chan WsResponse[T], jsonRawMsg json.RawMessage) error

func ForwardData[T any](dataChan_ chan WsResponse[T], jsonRawMsg json.RawMessage) error {
var gqlResp Response
var wsResp WsResponse[T]
err := json.Unmarshal(jsonRawMsg, &gqlResp)
if err != nil {
return err
}
if len(gqlResp.Errors) == 0 {
err = json.Unmarshal(jsonRawMsg, &wsResp)
if err != nil {
return err
}
} else {
wsResp.Errors = gqlResp.Errors
}
dataChan_ <- wsResp
return nil
}

type WsResponse[T any] struct {
Data *T `json:"data"`
Extensions map[string]interface{} `json:"extensions,omitempty"`
Errors error `json:"errors"`
}

type client struct {
httpClient Doer
Expand Down Expand Up @@ -131,19 +156,19 @@ func NewClientUsingGet(endpoint string, httpClient Doer) Client {
//
// The client does not support queries nor mutations, and will return an error
// if passed a request that attempts one.
func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Header) WebSocketClient {
func NewClientUsingWebSocket[T any](endpoint string, wsDialer Dialer, headers http.Header) WebSocketClient[T] {
if headers == nil {
headers = http.Header{}
}
if headers.Get("Sec-WebSocket-Protocol") == "" {
headers.Add("Sec-WebSocket-Protocol", "graphql-transport-ws")
}
return &webSocketClient{
return &webSocketClient[T]{
Dialer: wsDialer,
Header: headers,
errChan: make(chan error),
endpoint: endpoint,
subscriptions: subscriptionMap{map_: make(map[string]subscription)},
subscriptions: subscriptionMap[T]{map_: make(map[string]subscription[T])},
}
}

Expand Down
27 changes: 13 additions & 14 deletions graphql/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,41 @@ package graphql

import (
"fmt"
"reflect"
"sync"
)

// map of subscription ID to subscription
type subscriptionMap struct {
map_ map[string]subscription
type subscriptionMap[T any] struct {
map_ map[string]subscription[T]
sync.RWMutex
}

type subscription struct {
interfaceChan interface{}
forwardDataFunc ForwardDataFunction
type subscription[T any] struct {
dataChan chan WsResponse[T]
forwardDataFunc ForwardDataFunction[T]
id string
hasBeenUnsubscribed bool
}

func (s *subscriptionMap) Create(subscriptionID string, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) {
func (s *subscriptionMap[T]) Create(subscriptionID string, dataChan chan WsResponse[T], forwardDataFunc ForwardDataFunction[T]) {
s.Lock()
defer s.Unlock()
s.map_[subscriptionID] = subscription{
s.map_[subscriptionID] = subscription[T]{
id: subscriptionID,
interfaceChan: interfaceChan,
dataChan: dataChan,
forwardDataFunc: forwardDataFunc,
hasBeenUnsubscribed: false,
}
}

func (s *subscriptionMap) Read(subscriptionID string) (sub subscription, success bool) {
func (s *subscriptionMap[T]) Read(subscriptionID string) (sub subscription[T], success bool) {
s.RLock()
defer s.RUnlock()
sub, success = s.map_[subscriptionID]
return sub, success
}

func (s *subscriptionMap) Unsubscribe(subscriptionID string) error {
func (s *subscriptionMap[_]) Unsubscribe(subscriptionID string) error {
s.Lock()
defer s.Unlock()
unsub, success := s.map_[subscriptionID]
Expand All @@ -46,11 +45,11 @@ func (s *subscriptionMap) Unsubscribe(subscriptionID string) error {
}
unsub.hasBeenUnsubscribed = true
s.map_[subscriptionID] = unsub
reflect.ValueOf(s.map_[subscriptionID].interfaceChan).Close()
close(s.map_[subscriptionID].dataChan)
return nil
}

func (s *subscriptionMap) GetAllIDs() (subscriptionIDs []string) {
func (s *subscriptionMap[_]) GetAllIDs() (subscriptionIDs []string) {
s.RLock()
defer s.RUnlock()
for subID := range s.map_ {
Expand All @@ -59,7 +58,7 @@ func (s *subscriptionMap) GetAllIDs() (subscriptionIDs []string) {
return subscriptionIDs
}

func (s *subscriptionMap) Delete(subscriptionID string) {
func (s *subscriptionMap[_]) Delete(subscriptionID string) {
s.Lock()
defer s.Unlock()
delete(s.map_, subscriptionID)
Expand Down
Loading