diff --git a/graphql/websocket.go b/graphql/websocket.go index 589f6a81..bb9a98b6 100644 --- a/graphql/websocket.go +++ b/graphql/websocket.go @@ -45,11 +45,11 @@ const ( type webSocketClient struct { Dialer Dialer Header http.Header + endpoint string conn WSConn - isClosing bool errChan chan error - endpoint string subscriptions subscriptionMap + isClosing bool sync.Mutex } @@ -96,6 +96,14 @@ func (w *webSocketClient) waitForConnAck() error { return nil } +func (w *webSocketClient) handleErr(err error) { + w.Lock() + defer w.Unlock() + if !w.isClosing { + w.errChan <- err + } +} + func (w *webSocketClient) listenWebSocket() { for { if w.isClosing { @@ -103,20 +111,12 @@ func (w *webSocketClient) listenWebSocket() { } _, message, err := w.conn.ReadMessage() if err != nil { - w.Lock() - defer w.Unlock() - if !w.isClosing { - w.errChan <- err - } + w.handleErr(err) return } err = w.forwardWebSocketData(message) if err != nil { - w.Lock() - defer w.Unlock() - if !w.isClosing { - w.errChan <- err - } + w.handleErr(err) return } } @@ -182,7 +182,10 @@ func (w *webSocketClient) Close() error { if err != nil { return fmt.Errorf("failed to send closure message: %w", err) } - w.UnsubscribeAll() + err = w.UnsubscribeAll() + if err != nil { + return fmt.Errorf("failed to unsubscribe: %w", err) + } w.Lock() defer w.Unlock() w.isClosing = true