Skip to content

Commit

Permalink
Use generics to avoid type casting
Browse files Browse the repository at this point in the history
  • Loading branch information
HaraldNordgren committed Sep 28, 2024
1 parent 4466fc1 commit 71a7e54
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 163 deletions.
2 changes: 1 addition & 1 deletion docs/subscriptions.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ Once your websocket client matches the interfaces, you can get your `graphql.Web
a loop for incoming messages and errors:

```go
graphqlClient := graphql.NewClientUsingWebSocket(
graphqlClient := graphql.NewClientUsingWebSocket[countResponse](
"ws://localhost:8080/query",
&MyDialer{Dialer: dialer},
headers,
Expand Down
41 changes: 5 additions & 36 deletions generate/operation.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@
const {{.Name}}_Operation = `{{$.Body}}`

{{.Doc}}
func {{.Name}}(
func {{.Name}}{{if eq .Type "subscription"}}[T any]{{end}}(
{{if ne .Config.ContextType "-" -}}
ctx_ {{ref .Config.ContextType}},
{{end}}
{{- if not .Config.ClientGetter -}}
client_ {{if eq .Type "subscription"}}{{ref "github.com/Khan/genqlient/graphql.WebSocketClient"}}{{else}}{{ref "github.com/Khan/genqlient/graphql.Client"}}{{end}},
client_ {{if eq .Type "subscription"}}{{ref "github.com/Khan/genqlient/graphql.WebSocketClient"}}[T]{{else}}{{ref "github.com/Khan/genqlient/graphql.Client"}}{{end}},
{{end}}
{{- if .Input -}}
{{- range .Input.Fields -}}
{{/* the GraphQL name here is the user-specified variable-name */ -}}
{{.GraphQLName}} {{.GoType.Reference}},
{{end -}}
{{end -}}
) ({{if eq .Type "subscription"}}dataChan_ chan {{.Name}}WsResponse, subscriptionID_ string,{{else}}data_ *{{.ResponseName}}, {{if .Config.Extensions -}}ext_ map[string]interface{},{{end}}{{end}} err_ error) {
) ({{if eq .Type "subscription"}}dataChan_ chan graphql.WsResponse[T], subscriptionID_ string,{{else}}data_ *{{.ResponseName}}, {{if .Config.Extensions -}}ext_ map[string]interface{},{{end}}{{end}} err_ error) {
req_ := &graphql.Request{
OpName: "{{.Name}}",
Query: {{.Name}}_Operation,
Expand All @@ -36,8 +36,8 @@ func {{.Name}}(
}
{{end}}
{{if eq .Type "subscription"}}
dataChan_ = make(chan {{.Name}}WsResponse)
subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, {{.Name}}ForwardData)
dataChan_ = make(chan graphql.WsResponse[T])
subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, graphql.ForwardData[T])
{{else}}
data_ = &{{.ResponseName}}{}
resp_ := &graphql.Response{Data: data_}
Expand All @@ -51,34 +51,3 @@ func {{.Name}}(

return {{if eq .Type "subscription"}}dataChan_, subscriptionID_,{{else}}data_, {{if .Config.Extensions -}}resp_.Extensions,{{end -}}{{end}} err_
}

{{if eq .Type "subscription"}}
type {{.Name}}WsResponse struct {
Data *{{.ResponseName}} `json:"data"`
Extensions map[string]interface{} `json:"extensions,omitempty"`
Errors error `json:"errors"`
}

func {{.Name}}ForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error {
var gqlResp graphql.Response
var wsResp {{.Name}}WsResponse
err := json.Unmarshal(jsonRawMsg, &gqlResp)
if err != nil {
return err
}
if len(gqlResp.Errors) == 0 {
err = json.Unmarshal(jsonRawMsg, &wsResp)
if err != nil {
return err
}
} else {
wsResp.Errors = gqlResp.Errors
}
dataChan_, ok := interfaceChan.(chan {{.Name}}WsResponse)
if !ok {
return errors.New("failed to cast interface into 'chan {{.Name}}WsResponse'")
}
dataChan_ <- wsResp
return nil
}
{{end}}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 33 additions & 8 deletions graphql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ type Client interface {
) error
}

type WebSocketClient interface {
type WebSocketClient[T any] interface {
// Start must open a webSocket connection and subscribe to an endpoint
// of the client's GraphQL API.
//
Expand All @@ -59,7 +59,7 @@ type WebSocketClient interface {
// req contains the data to be sent to the GraphQL server. Will be marshalled
// into JSON bytes.
//
// interfaceChan is a channel used to send the data that arrives via the
// dataChan is a channel used to send the data that arrives via the
// webSocket connection (it is the channel that is passed to `forwardDataFunc`).
//
// forwardDataFunc is the function that will cast the received interface into
Expand All @@ -68,8 +68,8 @@ type WebSocketClient interface {
// Returns a subscriptionID if successful, an error otherwise.
Subscribe(
req *Request,
interfaceChan interface{},
forwardDataFunc ForwardDataFunction,
dataChan chan WsResponse[T],
forwardDataFunc ForwardDataFunction[T],
) (string, error)

// Unsubscribe must unsubscribe from an endpoint of the client's GraphQL API.
Expand All @@ -78,7 +78,32 @@ type WebSocketClient interface {

// ForwardDataFunction is a part of the WebSocketClient interface, see
// [WebSocketClient.Subscribe] for details.
type ForwardDataFunction func(interfaceChan interface{}, jsonRawMsg json.RawMessage) error
type ForwardDataFunction[T any] func(dataChan chan WsResponse[T], jsonRawMsg json.RawMessage) error

func ForwardData[T any](dataChan_ chan WsResponse[T], jsonRawMsg json.RawMessage) error {
var gqlResp Response
var wsResp WsResponse[T]
err := json.Unmarshal(jsonRawMsg, &gqlResp)
if err != nil {
return err
}
if len(gqlResp.Errors) == 0 {
err = json.Unmarshal(jsonRawMsg, &wsResp)
if err != nil {
return err
}
} else {
wsResp.Errors = gqlResp.Errors
}
dataChan_ <- wsResp
return nil
}

type WsResponse[T any] struct {
Data *T `json:"data"`
Extensions map[string]interface{} `json:"extensions,omitempty"`
Errors error `json:"errors"`
}

type client struct {
httpClient Doer
Expand Down Expand Up @@ -131,19 +156,19 @@ func NewClientUsingGet(endpoint string, httpClient Doer) Client {
//
// The client does not support queries nor mutations, and will return an error
// if passed a request that attempts one.
func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Header) WebSocketClient {
func NewClientUsingWebSocket[T any](endpoint string, wsDialer Dialer, headers http.Header) WebSocketClient[T] {
if headers == nil {
headers = http.Header{}
}
if headers.Get("Sec-WebSocket-Protocol") == "" {
headers.Add("Sec-WebSocket-Protocol", "graphql-transport-ws")
}
return &webSocketClient{
return &webSocketClient[T]{
Dialer: wsDialer,
Header: headers,
errChan: make(chan error),
endpoint: endpoint,
subscriptions: subscriptionMap{map_: make(map[string]subscription)},
subscriptions: subscriptionMap[T]{map_: make(map[string]subscription[T])},
}
}

Expand Down
27 changes: 13 additions & 14 deletions graphql/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,41 @@ package graphql

import (
"fmt"
"reflect"
"sync"
)

// map of subscription ID to subscription
type subscriptionMap struct {
map_ map[string]subscription
type subscriptionMap[T any] struct {
map_ map[string]subscription[T]
sync.RWMutex
}

type subscription struct {
interfaceChan interface{}
forwardDataFunc ForwardDataFunction
type subscription[T any] struct {
dataChan chan WsResponse[T]
forwardDataFunc ForwardDataFunction[T]
id string
hasBeenUnsubscribed bool
}

func (s *subscriptionMap) Create(subscriptionID string, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) {
func (s *subscriptionMap[T]) Create(subscriptionID string, dataChan chan WsResponse[T], forwardDataFunc ForwardDataFunction[T]) {
s.Lock()
defer s.Unlock()
s.map_[subscriptionID] = subscription{
s.map_[subscriptionID] = subscription[T]{
id: subscriptionID,
interfaceChan: interfaceChan,
dataChan: dataChan,
forwardDataFunc: forwardDataFunc,
hasBeenUnsubscribed: false,
}
}

func (s *subscriptionMap) Read(subscriptionID string) (sub subscription, success bool) {
func (s *subscriptionMap[T]) Read(subscriptionID string) (sub subscription[T], success bool) {
s.RLock()
defer s.RUnlock()
sub, success = s.map_[subscriptionID]
return sub, success
}

func (s *subscriptionMap) Unsubscribe(subscriptionID string) error {
func (s *subscriptionMap[_]) Unsubscribe(subscriptionID string) error {
s.Lock()
defer s.Unlock()
unsub, success := s.map_[subscriptionID]
Expand All @@ -46,11 +45,11 @@ func (s *subscriptionMap) Unsubscribe(subscriptionID string) error {
}
unsub.hasBeenUnsubscribed = true
s.map_[subscriptionID] = unsub
reflect.ValueOf(s.map_[subscriptionID].interfaceChan).Close()
close(s.map_[subscriptionID].dataChan)
return nil
}

func (s *subscriptionMap) GetAllIDs() (subscriptionIDs []string) {
func (s *subscriptionMap[_]) GetAllIDs() (subscriptionIDs []string) {
s.RLock()
defer s.RUnlock()
for subID := range s.map_ {
Expand All @@ -59,7 +58,7 @@ func (s *subscriptionMap) GetAllIDs() (subscriptionIDs []string) {
return subscriptionIDs
}

func (s *subscriptionMap) Delete(subscriptionID string) {
func (s *subscriptionMap[_]) Delete(subscriptionID string) {
s.Lock()
defer s.Unlock()
delete(s.map_, subscriptionID)
Expand Down
Loading

0 comments on commit 71a7e54

Please sign in to comment.