From 550e232451a84a769bef08101c5f94a4f14564c4 Mon Sep 17 00:00:00 2001 From: Kazuhito NAKAMURA Date: Thu, 18 Feb 2021 11:07:32 +0900 Subject: [PATCH 01/11] support pure websocket subscription --- option.go => client_option.go | 0 option_test.go => client_option_test.go | 0 pure_websocket_subscriber.go | 400 ++++++++++++++++++++++++ pure_websocket_subscriber_option.go | 44 +++ test/EventApp/appsync_eventapp.go | 140 ++++++++- 5 files changed, 567 insertions(+), 17 deletions(-) rename option.go => client_option.go (100%) rename option_test.go => client_option_test.go (100%) create mode 100644 pure_websocket_subscriber.go create mode 100644 pure_websocket_subscriber_option.go diff --git a/option.go b/client_option.go similarity index 100% rename from option.go rename to client_option.go diff --git a/option_test.go b/client_option_test.go similarity index 100% rename from option_test.go rename to client_option_test.go diff --git a/pure_websocket_subscriber.go b/pure_websocket_subscriber.go new file mode 100644 index 0000000..49afce4 --- /dev/null +++ b/pure_websocket_subscriber.go @@ -0,0 +1,400 @@ +package appsync + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "log" + "net/http" + "strings" + "time" + + v4 "github.com/aws/aws-sdk-go/aws/signer/v4" + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/sony/appsync-client-go/graphql" +) + +type message struct { + Type string `json:"type"` +} + +type connectionAckMessage struct { + message + Payload struct { + ConnectionTimeoutMs int64 `json:"connectionTimeoutMs"` + } `json:"payload"` +} + +type startMessage struct { + message + ID string `json:"id"` + Payload subscriptionRegistrationPayload `json:"payload"` +} +type subscriptionRegistrationPayload struct { + Data string `json:"data"` + Extensions subscriptionRegistrationPayloadExtensions `json:"extensions"` +} +type subscriptionRegistrationPayloadExtensions struct { + Authorization map[string]string `json:"authorization"` +} +type startAckMessage struct { + message + ID string `json:"id"` +} + +type processingDataMessage struct { + message + ID string `json:"id"` + Payload processingDataPayload `json:"payload"` +} +type processingDataPayload struct { + Data interface{} `json:"data"` +} + +type stopMessage struct { + message + ID string `json:"id"` +} +type completeMessage struct { + message + ID string `json:"id"` +} + +type errorMessage struct { + message + ID string `json:"id"` + Payload errorPayload `json:"payload"` +} +type errorPayload struct { + Errors []struct { + ErrorType string `json:"errorType"` + Message string `json:"message"` + } `json:"errors"` +} + +var ( + connectionInitMsg = message{Type: "connection_init"} +) + +// PureWebSocketSubscriber has pure WebSocket connections and subscription information. +type PureWebSocketSubscriber struct { + realtimeEndpoint string + request graphql.PostRequest + onReceive func(response *graphql.Response) + onConnectionLost func(err error) + header http.Header + iamAuth *struct { + signer *v4.Signer + region string + host string + } + op *realtimeWebSocketOperation +} + +// NewPureWebSocketSubscriber returns a PureWebSocketSubscriber instance. +func NewPureWebSocketSubscriber(realtimeEndpoint string, request graphql.PostRequest, + onReceive func(response *graphql.Response), + onConnectionLost func(err error), + opts ...PureWebSocketSubscriberOption) *PureWebSocketSubscriber { + p := PureWebSocketSubscriber{ + realtimeEndpoint: realtimeEndpoint, + request: request, + header: http.Header{}, + iamAuth: nil, + op: newRealtimeWebSocketOperation(onReceive, onConnectionLost), + } + for _, opt := range opts { + opt(&p) + } + return &p +} + +// Start starts a new subscription. +func (p *PureWebSocketSubscriber) Start() error { + bpayload := []byte("{}") + header := map[string]string{} + if p.iamAuth != nil { + var err error + header, err = signRequest(p.iamAuth.signer, p.iamAuth.host+"/connect", p.iamAuth.region, bpayload) + if err != nil { + log.Println(err) + return err + } + } else { + for k, v := range p.header { + header[k] = v[0] + } + } + + bheader, err := json.Marshal(header) + if err != nil { + log.Println(err) + return err + } + if err := p.op.connect(p.realtimeEndpoint, bheader, bpayload); err != nil { + return err + } + + if err := p.op.connectionInit(); err != nil { + return err + } + + brequest, err := json.Marshal(p.request) + if err != nil { + log.Println(err) + return err + } + authz := map[string]string{} + if p.iamAuth != nil { + var err error + authz, err = signRequest(p.iamAuth.signer, p.iamAuth.host, p.iamAuth.region, brequest) + if err != nil { + return err + } + } else { + for k, v := range p.header { + authz[k] = v[0] + } + } + if err := p.op.start(brequest, authz); err != nil { + return err + } + + return nil +} + +func signRequest(signer *v4.Signer, url, region string, data []byte) (map[string]string, error) { + req, err := http.NewRequest("POST", url, bytes.NewBuffer(data)) + if err != nil { + log.Println(err) + return nil, err + } + req.Header.Add("accept", "application/json, text/javascript") + req.Header.Add("content-encoding", "amz-1.0") + req.Header.Add("content-type", "application/json; charset=UTF-8") + + _, err = signer.Sign(req, bytes.NewReader(data), "appsync", region, time.Now()) + if err != nil { + log.Println(err) + return nil, err + } + + return map[string]string{ + "accept": req.Header.Get("accept"), + "content-encoding": req.Header.Get("content-encoding"), + "content-type": req.Header.Get("content-type"), + "host": req.Host, + "x-amz-date": req.Header.Get("x-amz-date"), + "X-Amz-Security-Token": req.Header.Get("X-Amz-Security-Token"), + "Authorization": req.Header.Get("Authorization"), + }, nil +} + +// Stop ends the subscription. +func (p *PureWebSocketSubscriber) Stop() { + p.op.stop() + p.op.disconnect() +} + +type realtimeWebSocketOperation struct { + onReceive func(response *graphql.Response) + onConnectionLost func(err error) + + ws *websocket.Conn + connectionTimeoutMs time.Duration + subscriptionID string + connackCh chan connectionAckMessage + startackCh chan startAckMessage + completeCh chan completeMessage +} + +func newRealtimeWebSocketOperation(onReceive func(response *graphql.Response), + onConnectionLost func(err error)) *realtimeWebSocketOperation { + return &realtimeWebSocketOperation{onReceive, onConnectionLost, nil, 0, "", nil, nil, nil} +} + +func (r *realtimeWebSocketOperation) readLoop() { + r.connackCh = make(chan connectionAckMessage, 1) + defer close(r.connackCh) + + r.startackCh = make(chan startAckMessage, 1) + defer close(r.startackCh) + + r.completeCh = make(chan completeMessage, 1) + defer close(r.completeCh) + + timeout := time.Duration(300000) * time.Millisecond + r.ws.SetReadDeadline(time.Now().Add(timeout)) + for { + _, b, err := r.ws.ReadMessage() + if err != nil { + log.Println(err) + if strings.Contains(err.Error(), "i/o timeout") { + r.onConnectionLost(err) + } + return + } + + msg := new(message) + if err := json.Unmarshal(b, msg); err != nil { + log.Println(err) + return + } + switch msg.Type { + case "connection_ack": + connack := new(connectionAckMessage) + if err := json.Unmarshal(b, connack); err != nil { + log.Println(err) + return + } + r.connackCh <- *connack + case "ka": + if r.connectionTimeoutMs != 0 { + timeout = r.connectionTimeoutMs + } + r.ws.SetReadDeadline(time.Now().Add(timeout)) + case "start_ack": + startack := new(startAckMessage) + if err := json.Unmarshal(b, startack); err != nil { + log.Println(err) + return + } + r.startackCh <- *startack + case "data": + data := new(processingDataMessage) + if err := json.Unmarshal(b, data); err != nil { + log.Println(err) + return + } + r.onReceive(&graphql.Response{ + Data: data.Payload.Data, + }) + case "complete": + complete := new(completeMessage) + if err := json.Unmarshal(b, complete); err != nil { + log.Println(err) + return + } + r.completeCh <- *complete + return + case "error": + e := new(errorMessage) + if err := json.Unmarshal(b, e); err != nil { + log.Println(err) + } + default: + log.Println("invalid message received") + } + } +} + +func (r *realtimeWebSocketOperation) connect(realtimeEndpoint string, header, payload []byte) error { + if r.ws != nil { + return errors.New("already connected") + } + + b64h := base64.StdEncoding.EncodeToString(header) + b64p := base64.StdEncoding.EncodeToString(payload) + endpoint := fmt.Sprintf("%s?header=%s&payload=%s", realtimeEndpoint, b64h, b64p) + + ws, _, err := websocket.DefaultDialer.Dial(endpoint, http.Header{"sec-websocket-protocol": []string{"graphql-ws"}}) + if err != nil { + log.Println(err) + return err + } + + r.ws = ws + go r.readLoop() + return nil +} + +func (r *realtimeWebSocketOperation) connectionInit() error { + if r.connectionTimeoutMs != 0 { + return errors.New("already connection initialized") + } + + init, err := json.Marshal(connectionInitMsg) + if err != nil { + log.Println(err) + return err + } + if err := r.ws.WriteMessage(websocket.TextMessage, init); err != nil { + log.Println(err) + return err + } + connack, ok := <-r.connackCh + if !ok { + return errors.New("connection failed") + } + + r.connectionTimeoutMs = time.Duration(connack.Payload.ConnectionTimeoutMs) * time.Millisecond + return nil +} + +func (r *realtimeWebSocketOperation) start(request []byte, authorization map[string]string) error { + if len(r.subscriptionID) != 0 { + return errors.New("already started") + } + + start := startMessage{ + message: message{"start"}, + ID: uuid.New().String(), + Payload: subscriptionRegistrationPayload{ + Data: string(request), + Extensions: subscriptionRegistrationPayloadExtensions{ + Authorization: authorization, + }, + }, + } + + b, err := json.Marshal(start) + if err != nil { + log.Println(err) + return err + } + if err := r.ws.WriteMessage(websocket.TextMessage, b); err != nil { + log.Println(err) + } + startack, ok := <-r.startackCh + if !ok { + return errors.New("subscription registration failed") + } + r.subscriptionID = startack.ID + + return nil +} + +func (r *realtimeWebSocketOperation) stop() { + if len(r.subscriptionID) == 0 { + return + } + + stop := stopMessage{message{"stop"}, r.subscriptionID} + b, err := json.Marshal(stop) + if err != nil { + log.Println(err) + return + } + if err := r.ws.WriteMessage(websocket.TextMessage, b); err != nil { + log.Println(err) + return + } + if _, ok := <-r.completeCh; !ok { + log.Println("unsubscribe failed") + } + r.subscriptionID = "" +} + +func (r *realtimeWebSocketOperation) disconnect() { + if r.ws == nil { + return + } + + if err := r.ws.Close(); err != nil { + log.Println(err) + } + r.ws = nil +} diff --git a/pure_websocket_subscriber_option.go b/pure_websocket_subscriber_option.go new file mode 100644 index 0000000..3423763 --- /dev/null +++ b/pure_websocket_subscriber_option.go @@ -0,0 +1,44 @@ +package appsync + +import ( + "net/url" + + v4 "github.com/aws/aws-sdk-go/aws/signer/v4" +) + +// PureWebSocketSubscriberOption represents options for an PureWebSocketSubscriber. +type PureWebSocketSubscriberOption func(*PureWebSocketSubscriber) + +func sanitize(host string) string { + if u, err := url.Parse(host); err == nil { + return u.Host + } + return host +} + +//WithAPIKey returns a PureWebSocketSubscriberOption configured with the host for the AWS AppSync GraphQL endpoint and API key +func WithAPIKey(host, apiKey string) PureWebSocketSubscriberOption { + return func(p *PureWebSocketSubscriber) { + p.header.Set("host", sanitize(host)) + p.header.Set("X-Api-Key", apiKey) + } +} + +//WithOIDC returns a PureWebSocketSubscriberOption configured with the host for the AWS AppSync GraphQL endpoint and JWT Access Token. +func WithOIDC(host, jwt string) PureWebSocketSubscriberOption { + return func(p *PureWebSocketSubscriber) { + p.header.Set("host", sanitize(host)) + p.header.Set("Authorization", jwt) + } +} + +// WithIAM returns a PureWebSocketSubscriberOption configured with the signature version 4 signer, the region and the host for the AWS AppSync GraphQL endpoint. +func WithIAM(signer *v4.Signer, region, host string) PureWebSocketSubscriberOption { + return func(p *PureWebSocketSubscriber) { + p.iamAuth = &struct { + signer *v4.Signer + region string + host string + }{signer, region, host} + } +} diff --git a/test/EventApp/appsync_eventapp.go b/test/EventApp/appsync_eventapp.go index ef5b357..ccc80e1 100644 --- a/test/EventApp/appsync_eventapp.go +++ b/test/EventApp/appsync_eventapp.go @@ -2,7 +2,10 @@ package main import ( "flag" + "fmt" "log" + "strings" + "time" "github.com/aws/aws-sdk-go/aws/session" v4 "github.com/aws/aws-sdk-go/aws/signer/v4" @@ -24,38 +27,141 @@ type eventConnection struct { NextToken string `json:"nextToken"` } +type subscriber interface { + Start() error + Stop() +} + func main() { log.SetFlags(log.Llongfile) var ( - region = flag.String("region", "", "AppSync API region") - url = flag.String("url", "", "AppSync API URL") + region = flag.String("region", "", "AppSync API region") + url = flag.String("url", "", "AppSync API URL") + protocol = flag.String("subscription protocol", "graphql-ws", "AppSync Subscription protocol(mqtt, graphql-ws)") ) flag.Parse() - query := ` -query ListEvents { - listEvents { - items { - id - name - } - } -} -` sess := session.Must(session.NewSession()) signer := v4.NewSigner(sess.Config.Credentials) client := appsync.NewClient(appsync.NewGraphQLClient(graphql.NewClient(*url)), appsync.WithIAMAuthorization(*signer, *region, *url)) + + log.Println("mutation createEvent()") + mutation := ` +mutation { + createEvent(name: "name", when: "when", where: "where", description: "description") { + id + name + when + where + description + } +}` res, err := client.Post(graphql.PostRequest{ - Query: query, + Query: mutation, + }) + if err != nil { + log.Fatalln(err) + } + pp.Println(res) + + ev := new(event) + if err := res.DataAs(ev); err != nil { + log.Fatalln(err) + } + + log.Println("subscription subscribeToEventComments()") + subscription := fmt.Sprintf(` +subscription { + subscribeToEventComments(eventId: "%s"){ + eventId + commentId + content + createdAt + } +}`, ev.ID) + + var s subscriber + subreq := graphql.PostRequest{ + Query: subscription, + } + + ch := make(chan *graphql.Response) + defer close(ch) + + switch *protocol { + case "mqtt": + res, err = client.Post(subreq) + if err != nil { + log.Fatalln(err) + } + pp.Println(res) + + ext, err := appsync.NewExtensions(res) + if err != nil { + log.Fatalln(err) + } + + s = appsync.NewSubscriber(*ext, + func(r *graphql.Response) { ch <- r }, + func(err error) { log.Println(err) }) + case "graphql-ws": + realtime := strings.Replace(strings.Replace(*url, "https", "wss", 1), "appsync-api", "appsync-realtime-api", 1) + s = appsync.NewPureWebSocketSubscriber(realtime, subreq, + func(r *graphql.Response) { ch <- r }, + func(err error) { log.Println(err) }, + appsync.WithIAM(signer, *region, *url), + ) + default: + log.Fatalln("unsupported protocol: " + *protocol) + } + + if err := s.Start(); err != nil { + log.Fatalln(err) + } + defer s.Stop() + + log.Println("mutation commentOnEvent()") + mutation = fmt.Sprintf(` +mutation { + commentOnEvent(eventId: "%s", content: "content", createdAt: "%s"){ + eventId + commentId + content + createdAt + } +}`, ev.ID, time.Now().String()) + res, err = client.Post(graphql.PostRequest{ + Query: mutation, }) if err != nil { log.Fatalln(err) } pp.Println(res) - data := new(eventConnection) - if err := res.DataAs(data); err != nil { - log.Fatalln(err, res) + + msg, ok := <-ch + if !ok { + log.Fatal("ch has been closed.") + } + log.Println("comment received") + pp.Print(msg) + + log.Println("mutation deleteEvent()") + mutation = fmt.Sprintf(` +mutation { + deleteEvent(id: "%s"){ + id + name + when + where + description + } +}`, ev.ID) + res, err = client.Post(graphql.PostRequest{ + Query: mutation, + }) + if err != nil { + log.Fatalln(err) } - pp.Println(*data) + pp.Println(res) } From 81417947d4bfeeea9076d3bb4c72aa3e90918b80 Mon Sep 17 00:00:00 2001 From: Kazuhito NAKAMURA Date: Fri, 19 Feb 2021 13:47:15 +0900 Subject: [PATCH 02/11] unit testing --- appsync_example_test.go | 46 +++++- internal/appsynctest/server.go | 179 ++++++++++++++++++----- pure_websocket_subscriber.go | 15 +- pure_websocket_subscriber_option.go | 2 +- pure_websocket_subscriber_option_test.go | 146 ++++++++++++++++++ 5 files changed, 340 insertions(+), 48 deletions(-) create mode 100644 pure_websocket_subscriber_option_test.go diff --git a/appsync_example_test.go b/appsync_example_test.go index 77a041a..f221006 100644 --- a/appsync_example_test.go +++ b/appsync_example_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "log" + "strings" "github.com/sony/appsync-client-go/internal/appsynctest" @@ -59,7 +60,7 @@ func ExampleClient_Post_mutation() { // Hi, AppSync! } -func ExampleClient_Post_subscription() { +func ExampleClient_mqtt_subscription() { server := appsynctest.NewAppSyncEchoServer() defer server.Close() @@ -108,3 +109,46 @@ func ExampleClient_Post_subscription() { // Output: // Hi, AppSync! } + +func ExampleClient_graphqlws_subscription() { + server := appsynctest.NewAppSyncEchoServer() + defer server.Close() + + client := appsync.NewClient(appsync.NewGraphQLClient(graphql.NewClient(server.URL))) + subscription := `subscription SubscribeToEcho() { subscribeToEcho }` + + ch := make(chan *graphql.Response) + subscriber := appsync.NewPureWebSocketSubscriber( + strings.Replace(server.URL, "http", "ws", 1), + graphql.PostRequest{ + Query: subscription, + }, + func(r *graphql.Response) { ch <- r }, + func(err error) { log.Println(err) }, + ) + + if err := subscriber.Start(); err != nil { + log.Fatalln(err) + } + defer subscriber.Stop() + + mutation := `mutation Echo($message: String!) { echo(message: $message) }` + variables := json.RawMessage(fmt.Sprintf(`{ "message": "%s" }`, "Hi, AppSync!")) + _, err := client.Post(graphql.PostRequest{ + Query: mutation, + Variables: &variables, + }) + if err != nil { + log.Fatal(err) + } + + response := <-ch + data := new(string) + if err := response.DataAs(data); err != nil { + log.Fatalln(err, response) + } + fmt.Println(*data) + + // Output: + // Hi, AppSync! +} diff --git a/internal/appsynctest/server.go b/internal/appsynctest/server.go index 8edcced..4784bce 100644 --- a/internal/appsynctest/server.go +++ b/internal/appsynctest/server.go @@ -21,29 +21,58 @@ var ( mqttEchoTopic = "echo" schema = ` schema { - query: Query - mutation: Mutation - subscription: Subscription + query: Query + mutation: Mutation + subscription: Subscription } type Query { - message: String! + message: String! } type Mutation { - echo(message: String!): String! + echo(message: String!): String! } type Subscription { - subscribeToEcho: String! + subscribeToEcho: String! } ` initialMessage = "Hello, AppSync!" + + gqlwsconnack = ` +{ + "type": "connection_ack", + "payload" : { + "connectionTimeoutMs": 300000 + } +} +` + gqlwsstartackfmt = ` +{ + "type": "start_ack", + "id" : "%s" +} +` + gqlwsdatafmt = ` +{ + "type": "data", + "id" : "%s", + "payload": %s +} +` + gqlwscompletefmt = ` +{ + "type": "complete", + "id" : "%s" +} +` ) type mqttPublisher struct { - w http.ResponseWriter - sessions mqttSessions + w http.ResponseWriter + mqttSessions mqttSessions + grapqhWsSessions grapqhWsSessions } func (m *mqttPublisher) Header() http.Header { @@ -51,26 +80,39 @@ func (m *mqttPublisher) Header() http.Header { } func (m *mqttPublisher) Write(payload []byte) (int, error) { - go func() { - pub := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket) - pub.TopicName = mqttEchoTopic - pub.Payload = payload - for sub := range m.sessions { - writer, err := sub.NextWriter(websocket.BinaryMessage) - if err != nil { - log.Println(err) - continue - } - if err := pub.Write(writer); err != nil { - log.Println(err) - continue + if len(m.mqttSessions) != 0 { + go func() { + pub := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket) + pub.TopicName = mqttEchoTopic + pub.Payload = payload + for s := range m.mqttSessions { + writer, err := s.NextWriter(websocket.BinaryMessage) + if err != nil { + log.Println(err) + continue + } + if err := pub.Write(writer); err != nil { + log.Println(err) + continue + } + if err := writer.Close(); err != nil { + log.Println(err) + continue + } } - if err := writer.Close(); err != nil { - log.Println(err) - continue + }() + } + if len(m.grapqhWsSessions) != 0 { + go func() { + for s := range m.grapqhWsSessions { + data := json.RawMessage(fmt.Sprintf(gqlwsdatafmt, "id", string(payload))) + if err := s.WriteJSON(data); err != nil { + log.Println(err) + continue + } } - } - }() + }() + } return m.w.Write(payload) } @@ -95,7 +137,7 @@ func (e *echoResolver) SubscribeToEcho() string { return e.message } -func mqttWsSession(ws *websocket.Conn, onConnected func(ws *websocket.Conn, clientId string), onDisconnected func(ws *websocket.Conn)) { +func mqttWsSession(ws *websocket.Conn, onConnected func(ws *websocket.Conn), onDisconnected func(ws *websocket.Conn)) { defer func() { if err := ws.Close(); err != nil { log.Println(err) @@ -103,13 +145,13 @@ func mqttWsSession(ws *websocket.Conn, onConnected func(ws *websocket.Conn, clie }() for { - mt, reader, err := ws.NextReader() + mt, r, err := ws.NextReader() if err != nil { log.Println(err) return } - cp, err := packets.ReadPacket(reader) + cp, err := packets.ReadPacket(r) if err != nil { log.Println(err) return @@ -119,7 +161,7 @@ func mqttWsSession(ws *websocket.Conn, onConnected func(ws *websocket.Conn, clie switch cp.(type) { case *packets.ConnectPacket: ack = packets.NewControlPacket(packets.Connack) - onConnected(ws, cp.(*packets.ConnectPacket).ClientIdentifier) + onConnected(ws) case *packets.SubscribePacket: ack = packets.NewControlPacket(packets.Suback) ack.(*packets.SubackPacket).MessageID = cp.(*packets.SubscribePacket).MessageID @@ -150,15 +192,46 @@ func mqttWsSession(ws *websocket.Conn, onConnected func(ws *websocket.Conn, clie } } +func graphQLWsSession(ws *websocket.Conn, onConnected func(ws *websocket.Conn), onDisconnected func(ws *websocket.Conn)) { + defer func() { + if err := ws.Close(); err != nil { + log.Println(err) + } + }() + + for { + msg := map[string]interface{}{} + if err := ws.ReadJSON(&msg); err != nil { + return + } + + var ack json.RawMessage + switch msg["type"].(string) { + case "connection_init": + ack = json.RawMessage(gqlwsconnack) + onConnected(ws) + case "start": + ack = json.RawMessage(fmt.Sprintf(gqlwsstartackfmt, msg["id"].(string))) + case "stop": + ack = json.RawMessage(fmt.Sprintf(gqlwscompletefmt, msg["id"].(string))) + onDisconnected(ws) + } + if err := ws.WriteJSON(ack); err != nil { + log.Println(err) + return + } + } +} + func newQueryHandlerFunc(h relay.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { h.ServeHTTP(w, r) } } -func newMutationHandlerFunc(h relay.Handler, sessions mqttSessions) http.HandlerFunc { +func newMutationHandlerFunc(h relay.Handler, mqtt mqttSessions, graphqlws grapqhWsSessions) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - h.ServeHTTP(&mqttPublisher{w, sessions}, r) + h.ServeHTTP(&mqttPublisher{w, mqtt, graphqlws}, r) } } @@ -200,10 +273,16 @@ func newSubscriptionHandlerFunc() http.HandlerFunc { } func isMqttWs(r *http.Request) bool { - return r.Method == http.MethodGet && r.Header.Get("Upgrade") == "websocket" + return r.Method == http.MethodGet && r.Header.Get("Upgrade") == "websocket" && + r.Header.Get("Sec-Websocket-Protocol") == "mqtt" } -func newMqttWsHandlerFunc(onConnected func(ws *websocket.Conn, clientId string), onDisconnected func(ws *websocket.Conn)) http.HandlerFunc { +func isGraphQLWs(r *http.Request) bool { + return r.Method == http.MethodGet && r.Header.Get("Upgrade") == "websocket" && + r.Header.Get("Sec-Websocket-Protocol") == "graphql-ws" +} + +func newMqttWsHandlerFunc(onConnected func(ws *websocket.Conn), onDisconnected func(ws *websocket.Conn)) http.HandlerFunc { upgrader := websocket.Upgrader{} return func(w http.ResponseWriter, r *http.Request) { ws, err := upgrader.Upgrade(w, r, nil) @@ -214,18 +293,35 @@ func newMqttWsHandlerFunc(onConnected func(ws *websocket.Conn, clientId string), } } -type mqttSessions map[*websocket.Conn]string +type mqttSessions map[*websocket.Conn]bool +type grapqhWsSessions map[*websocket.Conn]bool + +func newGraphQLWsHandlerFunc(onConnected func(ws *websocket.Conn), onDisconnected func(ws *websocket.Conn)) http.HandlerFunc { + upgrader := websocket.Upgrader{} + return func(w http.ResponseWriter, r *http.Request) { + ws, err := upgrader.Upgrade(w, r, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + } + go graphQLWsSession(ws, onConnected, onDisconnected) + } +} func newAppSyncEchoHandlerFunc(initialMessage string) http.HandlerFunc { s := graphqlgo.MustParseSchema(schema, &echoResolver{initialMessage}) handler := relay.Handler{Schema: s} - sessions := make(mqttSessions) + mqttSessions := make(mqttSessions) + grapqhWsSessions := make(grapqhWsSessions) query := newQueryHandlerFunc(handler) - mutation := newMutationHandlerFunc(handler, sessions) + mutation := newMutationHandlerFunc(handler, mqttSessions, grapqhWsSessions) subscription := newSubscriptionHandlerFunc() mqttws := newMqttWsHandlerFunc( - func(ws *websocket.Conn, clientId string) { sessions[ws] = clientId }, - func(ws *websocket.Conn) { delete(sessions, ws) }, + func(ws *websocket.Conn) { mqttSessions[ws] = true }, + func(ws *websocket.Conn) { delete(mqttSessions, ws) }, + ) + graphqlws := newGraphQLWsHandlerFunc( + func(ws *websocket.Conn) { grapqhWsSessions[ws] = true }, + func(ws *websocket.Conn) { delete(grapqhWsSessions, ws) }, ) return func(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) @@ -242,6 +338,11 @@ func newAppSyncEchoHandlerFunc(initialMessage string) http.HandlerFunc { return } + if isGraphQLWs(r) { + graphqlws.ServeHTTP(w, r) + return + } + req := new(graphql.PostRequest) if err := json.Unmarshal(body, req); err != nil { log.Println(err) diff --git a/pure_websocket_subscriber.go b/pure_websocket_subscriber.go index 49afce4..025e83a 100644 --- a/pure_websocket_subscriber.go +++ b/pure_websocket_subscriber.go @@ -217,17 +217,12 @@ func newRealtimeWebSocketOperation(onReceive func(response *graphql.Response), } func (r *realtimeWebSocketOperation) readLoop() { - r.connackCh = make(chan connectionAckMessage, 1) defer close(r.connackCh) - - r.startackCh = make(chan startAckMessage, 1) defer close(r.startackCh) - - r.completeCh = make(chan completeMessage, 1) defer close(r.completeCh) - timeout := time.Duration(300000) * time.Millisecond - r.ws.SetReadDeadline(time.Now().Add(timeout)) + const defaultTimeout = time.Duration(300000) * time.Millisecond + r.ws.SetReadDeadline(time.Now().Add(defaultTimeout)) for { _, b, err := r.ws.ReadMessage() if err != nil { @@ -252,6 +247,7 @@ func (r *realtimeWebSocketOperation) readLoop() { } r.connackCh <- *connack case "ka": + timeout := defaultTimeout if r.connectionTimeoutMs != 0 { timeout = r.connectionTimeoutMs } @@ -307,6 +303,10 @@ func (r *realtimeWebSocketOperation) connect(realtimeEndpoint string, header, pa } r.ws = ws + r.connackCh = make(chan connectionAckMessage, 1) + r.startackCh = make(chan startAckMessage, 1) + r.completeCh = make(chan completeMessage, 1) + go r.readLoop() return nil } @@ -396,5 +396,6 @@ func (r *realtimeWebSocketOperation) disconnect() { if err := r.ws.Close(); err != nil { log.Println(err) } + r.connectionTimeoutMs = 0 r.ws = nil } diff --git a/pure_websocket_subscriber_option.go b/pure_websocket_subscriber_option.go index 3423763..d7c1854 100644 --- a/pure_websocket_subscriber_option.go +++ b/pure_websocket_subscriber_option.go @@ -10,7 +10,7 @@ import ( type PureWebSocketSubscriberOption func(*PureWebSocketSubscriber) func sanitize(host string) string { - if u, err := url.Parse(host); err == nil { + if u, err := url.ParseRequestURI(host); err == nil { return u.Host } return host diff --git a/pure_websocket_subscriber_option_test.go b/pure_websocket_subscriber_option_test.go new file mode 100644 index 0000000..0eea4b6 --- /dev/null +++ b/pure_websocket_subscriber_option_test.go @@ -0,0 +1,146 @@ +package appsync + +import ( + "testing" + + v4 "github.com/aws/aws-sdk-go/aws/signer/v4" + "github.com/sony/appsync-client-go/graphql" +) + +var ( + realtimeEndpoint = "wss://example1234567890000.appsync-realtime-api.us-east-1.amazonaws.com/graphql" + request = graphql.PostRequest{} + onReceive = func(*graphql.Response) {} + onConnectionLost = func(error) {} +) + +func TestWithAPIKey(t *testing.T) { + type args struct { + host string + apiKey string + } + tests := []struct { + name string + args args + want PureWebSocketSubscriberOption + }{ + { + name: "WithAPIKey Success", + args: args{ + host: "example1234567890000.appsync-api.us-east-1.amazonaws.com", + apiKey: "apikey", + }, + }, + { + name: "WithAPIKey with sanitize Success", + args: args{ + host: "https://example1234567890000.appsync-api.us-east-1.amazonaws.com/graphql", + apiKey: "apikey", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := NewPureWebSocketSubscriber(realtimeEndpoint, request, onReceive, onConnectionLost) + if len(s.header) != 0 { + t.Fatal(s.header) + } + opt := WithAPIKey(tt.args.host, tt.args.apiKey) + opt(s) + if len(s.header) != 2 { + t.Fatal(s.header) + } + if s.header.Get("host") != sanitize(tt.args.host) { + t.Errorf("got: %s, want: %s", s.header.Get("host"), tt.args.host) + } + if s.header.Get("x-api-key") != tt.args.apiKey { + t.Errorf("got: %s, want: %s", s.header.Get("x-api-key"), tt.args.apiKey) + } + }) + } +} + +func TestWithOIDC(t *testing.T) { + type args struct { + host string + jwt string + } + tests := []struct { + name string + args args + }{ + { + name: "WithOIDC Success", + args: args{ + host: "example1234567890000.appsync-api.us-east-1.amazonaws.com", + jwt: "jwt", + }, + }, + { + name: "WithOIDC with sanitize Success", + args: args{ + host: "https://example1234567890000.appsync-api.us-east-1.amazonaws.com/graphql", + jwt: "jwt", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := NewPureWebSocketSubscriber(realtimeEndpoint, request, onReceive, onConnectionLost) + if len(s.header) != 0 { + t.Fatal(s.header) + } + opt := WithOIDC(tt.args.host, tt.args.jwt) + opt(s) + if len(s.header) != 2 { + t.Fatal(s.header) + } + if s.header.Get("host") != sanitize(tt.args.host) { + t.Errorf("got: %s, want: %s", s.header.Get("host"), tt.args.host) + } + if s.header.Get("authorization") != tt.args.jwt { + t.Errorf("got: %s, want: %s", s.header.Get("authorization"), tt.args.jwt) + } + }) + } +} + +func TestWithIAM(t *testing.T) { + type args struct { + signer *v4.Signer + region string + host string + } + tests := []struct { + name string + args args + }{ + { + name: "WithIAM Success", + args: args{ + signer: &v4.Signer{}, + region: "region", + host: "host", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := NewPureWebSocketSubscriber(realtimeEndpoint, request, onReceive, onConnectionLost) + if s.iamAuth != nil { + t.Fatal(s.iamAuth) + } + opt := WithIAM(tt.args.signer, tt.args.region, tt.args.host) + opt(s) + if s.iamAuth.signer != tt.args.signer { + t.Errorf("got: %v, want: %v", s.iamAuth.signer, tt.args.signer) + } + if s.iamAuth.region != tt.args.region { + t.Errorf("got: %s, want: %s", s.iamAuth.region, tt.args.region) + } + if s.iamAuth.host != tt.args.host { + t.Errorf("got: %s, want: %s", s.iamAuth.host, s.iamAuth.host) + } + }) + } +} From 7ac44a8affd33258392194e7a743133dae2f7c1d Mon Sep 17 00:00:00 2001 From: Kazuhito NAKAMURA Date: Mon, 22 Feb 2021 10:31:38 +0900 Subject: [PATCH 03/11] notify error message as graphql.Response --- pure_websocket_subscriber.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pure_websocket_subscriber.go b/pure_websocket_subscriber.go index 025e83a..26ed14c 100644 --- a/pure_websocket_subscriber.go +++ b/pure_websocket_subscriber.go @@ -277,10 +277,18 @@ func (r *realtimeWebSocketOperation) readLoop() { r.completeCh <- *complete return case "error": - e := new(errorMessage) - if err := json.Unmarshal(b, e); err != nil { + em := new(errorMessage) + if err := json.Unmarshal(b, em); err != nil { log.Println(err) } + errors := make([]interface{}, len(em.Payload.Errors)) + for i, e := range em.Payload.Errors { + errors[i] = e + } + r.onReceive(&graphql.Response{ + Errors: &errors, + }) + return default: log.Println("invalid message received") } From 0d23c5ad2e58d79f326dd9929e6946585e7ef0cc Mon Sep 17 00:00:00 2001 From: Kazuhito NAKAMURA Date: Mon, 22 Feb 2021 10:45:22 +0900 Subject: [PATCH 04/11] update README --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 0d03228..5fa02a2 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ appsync-client-go ================= +[![Go Reference](https://pkg.go.dev/badge/github.com/sony/appsync-client-go.svg)](https://pkg.go.dev/github.com/sony/appsync-client-go) [![Job Status](https://inspecode.rocro.com/badges/github.com/sony/appsync-client-go/status?token=VN4s0UD-m44_nY-nP0kSWRE3aVQiTg4UY2oTpm8r_Zc&branch=master)](https://inspecode.rocro.com/jobs/github.com/sony/appsync-client-go/latest?completed=true&branch=master) [![Report](https://inspecode.rocro.com/badges/github.com/sony/appsync-client-go/report?token=VN4s0UD-m44_nY-nP0kSWRE3aVQiTg4UY2oTpm8r_Zc&branch=master)](https://inspecode.rocro.com/reports/github.com/sony/appsync-client-go/branch/master/summary) -[![Job Status](https://docstand.rocro.com/badges/github.com/sony/appsync-client-go/status?token=VN4s0UD-m44_nY-nP0kSWRE3aVQiTg4UY2oTpm8r_Zc&branch=master)](https://docstand.rocro.com/jobs/github.com/sony/appsync-client-go/latest?completed=true&branch=master) -[![godoc](https://docstand.rocro.com/badges/github.com/sony/appsync-client-go/documentation/godoc?token=VN4s0UD-m44_nY-nP0kSWRE3aVQiTg4UY2oTpm8r_Zc&branch=master)](https://docstand.rocro.com/docs/github.com/sony/appsync-client-go/branch/master/godoc/github.com/sony/appsync-client-go/) + [AWS AppSync](https://aws.amazon.com/jp/appsync/) Go client library @@ -13,6 +13,7 @@ Features * GraphQL Query(Queries, Mutations and Subscriptions). * MQTT over Websocket for subscriptions. +* Pure Websockets subscriptions. Getting Started --------------- From 1a2355057daf8d43f5c45a4ff0db9a320f1699d8 Mon Sep 17 00:00:00 2001 From: Kazuhito NAKAMURA Date: Mon, 22 Feb 2021 11:13:35 +0900 Subject: [PATCH 05/11] support exponential backoff --- pure_websocket_subscriber.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/pure_websocket_subscriber.go b/pure_websocket_subscriber.go index 26ed14c..1762cd0 100644 --- a/pure_websocket_subscriber.go +++ b/pure_websocket_subscriber.go @@ -12,6 +12,7 @@ import ( "time" v4 "github.com/aws/aws-sdk-go/aws/signer/v4" + "github.com/cenkalti/backoff" "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/sony/appsync-client-go/graphql" @@ -304,13 +305,18 @@ func (r *realtimeWebSocketOperation) connect(realtimeEndpoint string, header, pa b64p := base64.StdEncoding.EncodeToString(payload) endpoint := fmt.Sprintf("%s?header=%s&payload=%s", realtimeEndpoint, b64h, b64p) - ws, _, err := websocket.DefaultDialer.Dial(endpoint, http.Header{"sec-websocket-protocol": []string{"graphql-ws"}}) - if err != nil { + if err := backoff.Retry(func() error { + ws, _, err := websocket.DefaultDialer.Dial(endpoint, http.Header{"sec-websocket-protocol": []string{"graphql-ws"}}) + if err != nil { + return err + } + r.ws = ws + return nil + }, backoff.NewExponentialBackOff()); err != nil { log.Println(err) return err } - r.ws = ws r.connackCh = make(chan connectionAckMessage, 1) r.startackCh = make(chan startAckMessage, 1) r.completeCh = make(chan completeMessage, 1) From 9a6fded1cc1a4807bbb12dd2196b91a3f019a027 Mon Sep 17 00:00:00 2001 From: Kazuhito NAKAMURA Date: Mon, 22 Feb 2021 11:24:20 +0900 Subject: [PATCH 06/11] update inspecode config --- rocro.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rocro.yml b/rocro.yml index d7737bb..adc1b1c 100644 --- a/rocro.yml +++ b/rocro.yml @@ -14,6 +14,9 @@ inspecode: goimports: default golint: default misspell: default + gosec: default + errcheck: default + staticcheck: default go-test: options: - -cover From 2ba2047a002b6b6778d5e6bea4261847df9ad1b3 Mon Sep 17 00:00:00 2001 From: Kazuhito NAKAMURA Date: Mon, 22 Feb 2021 11:35:26 +0900 Subject: [PATCH 07/11] fix some inspecode issues --- graphql/client.go | 2 -- pure_websocket_subscriber.go | 34 +++++++++++++++++-------------- test/EventApp/appsync_eventapp.go | 15 +++++--------- 3 files changed, 24 insertions(+), 27 deletions(-) diff --git a/graphql/client.go b/graphql/client.go index f5e9e1b..b20832c 100644 --- a/graphql/client.go +++ b/graphql/client.go @@ -9,7 +9,6 @@ import ( "net/url" "time" - v4 "github.com/aws/aws-sdk-go/aws/signer/v4" "github.com/cenkalti/backoff" ) @@ -32,7 +31,6 @@ type Client struct { timeout time.Duration maxElapsedTime time.Duration header http.Header - signer *v4.Signer } // NewClient returns a Client instance. diff --git a/pure_websocket_subscriber.go b/pure_websocket_subscriber.go index 1762cd0..6e88667 100644 --- a/pure_websocket_subscriber.go +++ b/pure_websocket_subscriber.go @@ -84,8 +84,6 @@ var ( type PureWebSocketSubscriber struct { realtimeEndpoint string request graphql.PostRequest - onReceive func(response *graphql.Response) - onConnectionLost func(err error) header http.Header iamAuth *struct { signer *v4.Signer @@ -204,12 +202,12 @@ type realtimeWebSocketOperation struct { onReceive func(response *graphql.Response) onConnectionLost func(err error) - ws *websocket.Conn - connectionTimeoutMs time.Duration - subscriptionID string - connackCh chan connectionAckMessage - startackCh chan startAckMessage - completeCh chan completeMessage + ws *websocket.Conn + connectionTimeout time.Duration + subscriptionID string + connackCh chan connectionAckMessage + startackCh chan startAckMessage + completeCh chan completeMessage } func newRealtimeWebSocketOperation(onReceive func(response *graphql.Response), @@ -223,7 +221,10 @@ func (r *realtimeWebSocketOperation) readLoop() { defer close(r.completeCh) const defaultTimeout = time.Duration(300000) * time.Millisecond - r.ws.SetReadDeadline(time.Now().Add(defaultTimeout)) + if err := r.ws.SetReadDeadline(time.Now().Add(defaultTimeout)); err != nil { + log.Println(err) + return + } for { _, b, err := r.ws.ReadMessage() if err != nil { @@ -249,10 +250,13 @@ func (r *realtimeWebSocketOperation) readLoop() { r.connackCh <- *connack case "ka": timeout := defaultTimeout - if r.connectionTimeoutMs != 0 { - timeout = r.connectionTimeoutMs + if r.connectionTimeout != 0 { + timeout = r.connectionTimeout + } + if err := r.ws.SetReadDeadline(time.Now().Add(timeout)); err != nil { + log.Println(err) + return } - r.ws.SetReadDeadline(time.Now().Add(timeout)) case "start_ack": startack := new(startAckMessage) if err := json.Unmarshal(b, startack); err != nil { @@ -326,7 +330,7 @@ func (r *realtimeWebSocketOperation) connect(realtimeEndpoint string, header, pa } func (r *realtimeWebSocketOperation) connectionInit() error { - if r.connectionTimeoutMs != 0 { + if r.connectionTimeout != 0 { return errors.New("already connection initialized") } @@ -344,7 +348,7 @@ func (r *realtimeWebSocketOperation) connectionInit() error { return errors.New("connection failed") } - r.connectionTimeoutMs = time.Duration(connack.Payload.ConnectionTimeoutMs) * time.Millisecond + r.connectionTimeout = time.Duration(connack.Payload.ConnectionTimeoutMs) * time.Millisecond return nil } @@ -410,6 +414,6 @@ func (r *realtimeWebSocketOperation) disconnect() { if err := r.ws.Close(); err != nil { log.Println(err) } - r.connectionTimeoutMs = 0 + r.connectionTimeout = 0 r.ws = nil } diff --git a/test/EventApp/appsync_eventapp.go b/test/EventApp/appsync_eventapp.go index ccc80e1..0661724 100644 --- a/test/EventApp/appsync_eventapp.go +++ b/test/EventApp/appsync_eventapp.go @@ -22,11 +22,6 @@ type event struct { Description *string `json:"description"` } -type eventConnection struct { - Items []event `json:"items"` - NextToken string `json:"nextToken"` -} - type subscriber interface { Start() error Stop() @@ -63,7 +58,7 @@ mutation { if err != nil { log.Fatalln(err) } - pp.Println(res) + _, _ = pp.Println(res) ev := new(event) if err := res.DataAs(ev); err != nil { @@ -95,7 +90,7 @@ subscription { if err != nil { log.Fatalln(err) } - pp.Println(res) + _, _ = pp.Println(res) ext, err := appsync.NewExtensions(res) if err != nil { @@ -137,14 +132,14 @@ mutation { if err != nil { log.Fatalln(err) } - pp.Println(res) + _, _ = pp.Println(res) msg, ok := <-ch if !ok { log.Fatal("ch has been closed.") } log.Println("comment received") - pp.Print(msg) + _, _ = pp.Println(msg) log.Println("mutation deleteEvent()") mutation = fmt.Sprintf(` @@ -163,5 +158,5 @@ mutation { if err != nil { log.Fatalln(err) } - pp.Println(res) + _, _ = pp.Println(res) } From dcb3a6c97f1eeee5438cca6106622757b9aee884 Mon Sep 17 00:00:00 2001 From: Kazuhito NAKAMURA Date: Mon, 22 Feb 2021 12:59:31 +0900 Subject: [PATCH 08/11] try to reduce cyclomatic complexity --- pure_websocket_subscriber.go | 129 ++++++++++++++++++++--------------- 1 file changed, 74 insertions(+), 55 deletions(-) diff --git a/pure_websocket_subscriber.go b/pure_websocket_subscriber.go index 6e88667..76cee06 100644 --- a/pure_websocket_subscriber.go +++ b/pure_websocket_subscriber.go @@ -226,6 +226,73 @@ func (r *realtimeWebSocketOperation) readLoop() { return } for { + handlers := map[string]func(b []byte) (finish bool){ + "connection_ack": func(b []byte) bool { + connack := new(connectionAckMessage) + if err := json.Unmarshal(b, connack); err != nil { + log.Println(err) + return true + } + r.connackCh <- *connack + return false + }, + "ka": func(b []byte) bool { + timeout := defaultTimeout + if r.connectionTimeout != 0 { + timeout = r.connectionTimeout + } + if err := r.ws.SetReadDeadline(time.Now().Add(timeout)); err != nil { + log.Println(err) + return true + } + return false + }, + "start_ack": func(b []byte) bool { + startack := new(startAckMessage) + if err := json.Unmarshal(b, startack); err != nil { + log.Println(err) + return true + } + r.startackCh <- *startack + return false + }, + "data": func(b []byte) bool { + data := new(processingDataMessage) + if err := json.Unmarshal(b, data); err != nil { + log.Println(err) + return true + } + r.onReceive(&graphql.Response{ + Data: data.Payload.Data, + }) + return false + }, + "complete": func(b []byte) bool { + complete := new(completeMessage) + if err := json.Unmarshal(b, complete); err != nil { + log.Println(err) + return true + } + r.completeCh <- *complete + return true + }, + "error": func(b []byte) bool { + em := new(errorMessage) + if err := json.Unmarshal(b, em); err != nil { + log.Println(err) + return true + } + errors := make([]interface{}, len(em.Payload.Errors)) + for i, e := range em.Payload.Errors { + errors[i] = e + } + r.onReceive(&graphql.Response{ + Errors: &errors, + }) + return true + }, + } + _, b, err := r.ws.ReadMessage() if err != nil { log.Println(err) @@ -240,62 +307,14 @@ func (r *realtimeWebSocketOperation) readLoop() { log.Println(err) return } - switch msg.Type { - case "connection_ack": - connack := new(connectionAckMessage) - if err := json.Unmarshal(b, connack); err != nil { - log.Println(err) - return - } - r.connackCh <- *connack - case "ka": - timeout := defaultTimeout - if r.connectionTimeout != 0 { - timeout = r.connectionTimeout - } - if err := r.ws.SetReadDeadline(time.Now().Add(timeout)); err != nil { - log.Println(err) - return - } - case "start_ack": - startack := new(startAckMessage) - if err := json.Unmarshal(b, startack); err != nil { - log.Println(err) - return - } - r.startackCh <- *startack - case "data": - data := new(processingDataMessage) - if err := json.Unmarshal(b, data); err != nil { - log.Println(err) - return - } - r.onReceive(&graphql.Response{ - Data: data.Payload.Data, - }) - case "complete": - complete := new(completeMessage) - if err := json.Unmarshal(b, complete); err != nil { - log.Println(err) - return - } - r.completeCh <- *complete - return - case "error": - em := new(errorMessage) - if err := json.Unmarshal(b, em); err != nil { - log.Println(err) - } - errors := make([]interface{}, len(em.Payload.Errors)) - for i, e := range em.Payload.Errors { - errors[i] = e - } - r.onReceive(&graphql.Response{ - Errors: &errors, - }) - return - default: + + handler, ok := handlers[msg.Type] + if !ok { log.Println("invalid message received") + continue + } + if handler(b) { + return } } } From e2df2050551c06805fb77890adf83b0ad4576037 Mon Sep 17 00:00:00 2001 From: Kazuhito NAKAMURA Date: Tue, 2 Mar 2021 07:34:54 +0900 Subject: [PATCH 09/11] fix sample app argument definition --- test/EventApp/appsync_eventapp.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/EventApp/appsync_eventapp.go b/test/EventApp/appsync_eventapp.go index 0661724..caf43e0 100644 --- a/test/EventApp/appsync_eventapp.go +++ b/test/EventApp/appsync_eventapp.go @@ -32,7 +32,7 @@ func main() { var ( region = flag.String("region", "", "AppSync API region") url = flag.String("url", "", "AppSync API URL") - protocol = flag.String("subscription protocol", "graphql-ws", "AppSync Subscription protocol(mqtt, graphql-ws)") + protocol = flag.String("protocol", "graphql-ws", "AppSync Subscription protocol(mqtt, graphql-ws)") ) flag.Parse() From e242418da3173c50ad7c8147017b9f5d06973dc1 Mon Sep 17 00:00:00 2001 From: Kazuhito NAKAMURA Date: Mon, 22 Mar 2021 11:46:06 +0900 Subject: [PATCH 10/11] WithOIDC accepts Bearer scheme --- pure_websocket_subscriber_option.go | 3 ++- pure_websocket_subscriber_option_test.go | 12 ++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/pure_websocket_subscriber_option.go b/pure_websocket_subscriber_option.go index d7c1854..35d46fb 100644 --- a/pure_websocket_subscriber_option.go +++ b/pure_websocket_subscriber_option.go @@ -2,6 +2,7 @@ package appsync import ( "net/url" + "strings" v4 "github.com/aws/aws-sdk-go/aws/signer/v4" ) @@ -28,7 +29,7 @@ func WithAPIKey(host, apiKey string) PureWebSocketSubscriberOption { func WithOIDC(host, jwt string) PureWebSocketSubscriberOption { return func(p *PureWebSocketSubscriber) { p.header.Set("host", sanitize(host)) - p.header.Set("Authorization", jwt) + p.header.Set("Authorization", strings.TrimPrefix(jwt, "Bearer ")) } } diff --git a/pure_websocket_subscriber_option_test.go b/pure_websocket_subscriber_option_test.go index 0eea4b6..a43fb54 100644 --- a/pure_websocket_subscriber_option_test.go +++ b/pure_websocket_subscriber_option_test.go @@ -1,6 +1,7 @@ package appsync import ( + "strings" "testing" v4 "github.com/aws/aws-sdk-go/aws/signer/v4" @@ -83,6 +84,13 @@ func TestWithOIDC(t *testing.T) { jwt: "jwt", }, }, + { + name: "WithOIDC with trim Bearer Success", + args: args{ + host: "https://example1234567890000.appsync-api.us-east-1.amazonaws.com/graphql", + jwt: "Bearer jwt", + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -98,8 +106,8 @@ func TestWithOIDC(t *testing.T) { if s.header.Get("host") != sanitize(tt.args.host) { t.Errorf("got: %s, want: %s", s.header.Get("host"), tt.args.host) } - if s.header.Get("authorization") != tt.args.jwt { - t.Errorf("got: %s, want: %s", s.header.Get("authorization"), tt.args.jwt) + if s.header.Get("authorization") != strings.TrimPrefix(tt.args.jwt, "Bearer ") { + t.Errorf("got: %s, want: %s", s.header.Get("authorization"), strings.TrimPrefix(tt.args.jwt, "Bearer ")) } }) } From ce1b3f9dd88335235709c9e39ab24c8a6361c875 Mon Sep 17 00:00:00 2001 From: Kazuhito NAKAMURA Date: Mon, 22 Mar 2021 12:08:35 +0900 Subject: [PATCH 11/11] try to reduce cyclomatic complexity --- pure_websocket_subscriber.go | 149 +++++++++++++++++++---------------- 1 file changed, 81 insertions(+), 68 deletions(-) diff --git a/pure_websocket_subscriber.go b/pure_websocket_subscriber.go index 76cee06..3c94ce5 100644 --- a/pure_websocket_subscriber.go +++ b/pure_websocket_subscriber.go @@ -198,6 +198,8 @@ func (p *PureWebSocketSubscriber) Stop() { p.op.disconnect() } +const defaultTimeout = time.Duration(300000) * time.Millisecond + type realtimeWebSocketOperation struct { onReceive func(response *graphql.Response) onConnectionLost func(err error) @@ -220,80 +222,21 @@ func (r *realtimeWebSocketOperation) readLoop() { defer close(r.startackCh) defer close(r.completeCh) - const defaultTimeout = time.Duration(300000) * time.Millisecond if err := r.ws.SetReadDeadline(time.Now().Add(defaultTimeout)); err != nil { log.Println(err) return } for { handlers := map[string]func(b []byte) (finish bool){ - "connection_ack": func(b []byte) bool { - connack := new(connectionAckMessage) - if err := json.Unmarshal(b, connack); err != nil { - log.Println(err) - return true - } - r.connackCh <- *connack - return false - }, - "ka": func(b []byte) bool { - timeout := defaultTimeout - if r.connectionTimeout != 0 { - timeout = r.connectionTimeout - } - if err := r.ws.SetReadDeadline(time.Now().Add(timeout)); err != nil { - log.Println(err) - return true - } - return false - }, - "start_ack": func(b []byte) bool { - startack := new(startAckMessage) - if err := json.Unmarshal(b, startack); err != nil { - log.Println(err) - return true - } - r.startackCh <- *startack - return false - }, - "data": func(b []byte) bool { - data := new(processingDataMessage) - if err := json.Unmarshal(b, data); err != nil { - log.Println(err) - return true - } - r.onReceive(&graphql.Response{ - Data: data.Payload.Data, - }) - return false - }, - "complete": func(b []byte) bool { - complete := new(completeMessage) - if err := json.Unmarshal(b, complete); err != nil { - log.Println(err) - return true - } - r.completeCh <- *complete - return true - }, - "error": func(b []byte) bool { - em := new(errorMessage) - if err := json.Unmarshal(b, em); err != nil { - log.Println(err) - return true - } - errors := make([]interface{}, len(em.Payload.Errors)) - for i, e := range em.Payload.Errors { - errors[i] = e - } - r.onReceive(&graphql.Response{ - Errors: &errors, - }) - return true - }, + "connection_ack": r.onConnected, + "ka": r.onKeepAlive, + "start_ack": r.onStarted, + "data": r.onData, + "complete": r.onStopped, + "error": r.onError, } - _, b, err := r.ws.ReadMessage() + _, payload, err := r.ws.ReadMessage() if err != nil { log.Println(err) if strings.Contains(err.Error(), "i/o timeout") { @@ -303,7 +246,7 @@ func (r *realtimeWebSocketOperation) readLoop() { } msg := new(message) - if err := json.Unmarshal(b, msg); err != nil { + if err := json.Unmarshal(payload, msg); err != nil { log.Println(err) return } @@ -313,7 +256,7 @@ func (r *realtimeWebSocketOperation) readLoop() { log.Println("invalid message received") continue } - if handler(b) { + if handler(payload) { return } } @@ -348,6 +291,16 @@ func (r *realtimeWebSocketOperation) connect(realtimeEndpoint string, header, pa return nil } +func (r *realtimeWebSocketOperation) onConnected(payload []byte) bool { + connack := new(connectionAckMessage) + if err := json.Unmarshal(payload, connack); err != nil { + log.Println(err) + return true + } + r.connackCh <- *connack + return false +} + func (r *realtimeWebSocketOperation) connectionInit() error { if r.connectionTimeout != 0 { return errors.New("already connection initialized") @@ -371,6 +324,18 @@ func (r *realtimeWebSocketOperation) connectionInit() error { return nil } +func (r *realtimeWebSocketOperation) onKeepAlive([]byte) bool { + timeout := defaultTimeout + if r.connectionTimeout != 0 { + timeout = r.connectionTimeout + } + if err := r.ws.SetReadDeadline(time.Now().Add(timeout)); err != nil { + log.Println(err) + return true + } + return false +} + func (r *realtimeWebSocketOperation) start(request []byte, authorization map[string]string) error { if len(r.subscriptionID) != 0 { return errors.New("already started") @@ -404,6 +369,28 @@ func (r *realtimeWebSocketOperation) start(request []byte, authorization map[str return nil } +func (r *realtimeWebSocketOperation) onStarted(payload []byte) bool { + startack := new(startAckMessage) + if err := json.Unmarshal(payload, startack); err != nil { + log.Println(err) + return true + } + r.startackCh <- *startack + return false +} + +func (r *realtimeWebSocketOperation) onData(payload []byte) bool { + data := new(processingDataMessage) + if err := json.Unmarshal(payload, data); err != nil { + log.Println(err) + return true + } + r.onReceive(&graphql.Response{ + Data: data.Payload.Data, + }) + return false +} + func (r *realtimeWebSocketOperation) stop() { if len(r.subscriptionID) == 0 { return @@ -425,6 +412,16 @@ func (r *realtimeWebSocketOperation) stop() { r.subscriptionID = "" } +func (r *realtimeWebSocketOperation) onStopped(payload []byte) bool { + complete := new(completeMessage) + if err := json.Unmarshal(payload, complete); err != nil { + log.Println(err) + return true + } + r.completeCh <- *complete + return true +} + func (r *realtimeWebSocketOperation) disconnect() { if r.ws == nil { return @@ -436,3 +433,19 @@ func (r *realtimeWebSocketOperation) disconnect() { r.connectionTimeout = 0 r.ws = nil } + +func (r *realtimeWebSocketOperation) onError(payload []byte) bool { + em := new(errorMessage) + if err := json.Unmarshal(payload, em); err != nil { + log.Println(err) + return true + } + errors := make([]interface{}, len(em.Payload.Errors)) + for i, e := range em.Payload.Errors { + errors[i] = e + } + r.onReceive(&graphql.Response{ + Errors: &errors, + }) + return true +}