From ad0fd0e464b8c9edd6e588ce05c054d45097da4b Mon Sep 17 00:00:00 2001
From: ImpSy <3097030+ImpSy@users.noreply.github.com>
Date: Mon, 12 Feb 2024 20:13:32 +0100
Subject: [PATCH] make websocket goroutines tied to grpc.ClientConn lifecycle

---
 client/channel/channel.go  | 119 ++++++++++++++++++++-----------------
 client/sql/sparksession.go |   5 +-
 2 files changed, 69 insertions(+), 55 deletions(-)

diff --git a/client/channel/channel.go b/client/channel/channel.go
index 9133a1b..af6ab0d 100644
--- a/client/channel/channel.go
+++ b/client/channel/channel.go
@@ -17,6 +17,7 @@
 package channel
 
 import (
+	"bufio"
 	"context"
 	"crypto/tls"
 	"crypto/x509"
@@ -28,12 +29,12 @@ import (
 	"net/url"
 	"strconv"
 	"strings"
-	"sync"
 
 	"github.com/gorilla/websocket"
 
 	"golang.org/x/oauth2"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/connectivity"
 	"google.golang.org/grpc/credentials"
 	"google.golang.org/grpc/credentials/insecure"
 	"google.golang.org/grpc/credentials/oauth"
@@ -62,6 +63,8 @@ func (cb *ChannelBuilder) Build() (*grpc.ClientConn, error) {
 
 	remote := fmt.Sprintf("%v:%v", cb.Host, cb.Port)
 	opts = append(opts, grpc.WithAuthority(cb.Host))
+	var companionFuncs = []func(){}
+	onClose := func() {}
 	if cb.Scheme != "sc" {
 		grpcSide, websocketSide := net.Pipe()
 		u := url.URL{Scheme: cb.Scheme, Host: remote, Path: cb.Path, RawQuery: cb.Query}
@@ -69,67 +72,28 @@ func (cb *ChannelBuilder) Build() (*grpc.ClientConn, error) {
 		header := http.Header{}
 		header.Set("Authorization", "Bearer "+cb.Token)
 
-		c, _, err := websocket.DefaultDialer.Dial(u.String(), header)
+		ws, _, err := websocket.DefaultDialer.Dial(u.String(), header)
 		if err != nil {
 			log.Fatal("dial:", err)
 		}
 
 		opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
 		opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
-			return grpcSide, nil
+			return websocketSide, nil
 		}))
 
-		done := make(chan struct{})
-		data := make([]byte, 10*1024*1024)
-		var wg sync.WaitGroup
-
-		wg.Add(1)
-		go func() {
-			defer wg.Done()
-			defer c.Close()
-			defer close(done)
-			for {
-				mt, message, err := c.ReadMessage()
-				if err != nil {
-					log.Println("c.ReadMessage:", err)
-					break
-				}
-
-				if mt != websocket.BinaryMessage {
-					log.Println("mt != websocket.BinaryMessage")
-					break
-				}
-
-				n, err := websocketSide.Write(message)
-				if err != nil {
-					log.Println("pipe.Write:", err)
-					break
-				}
-
-				if len(message) != n {
-					log.Printf("whooot! len(data) != n => %d != %d!\n", len(message), n)
-					break
-				}
-			}
-		}()
-
-		wg.Add(1)
-		go func() {
-			defer wg.Done()
-			for {
-				n, err := websocketSide.Read(data)
-				if err != nil {
-					log.Println("pipe.Read:", err)
-					break
-				}
-
-				err = c.WriteMessage(websocket.BinaryMessage, data[:n])
-				if err != nil {
-					log.Println("c.WriteMessage:", err)
-					break
-				}
-			}
-		}()
+		onClose = func() {
+			websocketSide.Close()
+			grpcSide.Close()
+			ws.Close()
+		}
+
+		// This function is responsible for reading from the websocket and writing to the pipe (grpc).
+		wsToGrpc := func() { ForwardToGrpc(ws, grpcSide) }
+
+		// This function is responsible for reading from the pipe (grpc) and writing to the websocket.
+		fwToWs := func() { ForwardToWebsocket(grpcSide, ws) }
+		companionFuncs = append(companionFuncs, wsToGrpc, fwToWs)
 	} else if cb.Token == "" {
 		opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
 	} else {
@@ -155,6 +119,14 @@ func (cb *ChannelBuilder) Build() (*grpc.ClientConn, error) {
 	if err != nil {
 		return nil, fmt.Errorf("failed to connect to remote %s: %w", remote, err)
 	}
+	for _, f := range companionFuncs {
+		go f()
+	}
+	go func() {
+		for conn.GetState() != connectivity.Shutdown {
+		}
+		onClose()
+	}()
 	return conn, nil
 }
 
@@ -233,3 +205,42 @@ func NewBuilder(connection string) (*ChannelBuilder, error) {
 	}
 	return cb, nil
 }
+
+func ForwardToWebsocket(c net.Conn, ws *websocket.Conn) error {
+	input := bufio.NewScanner(c)
+	for input.Scan() {
+		err := ws.WriteMessage(websocket.BinaryMessage, input.Bytes())
+		if err != nil {
+			log.Println("ws.WriteMessage:", err)
+			return err
+		}
+	}
+	return nil
+}
+
+func ForwardToGrpc(ws *websocket.Conn, c net.Conn) error {
+	for {
+		mt, message, err := ws.ReadMessage()
+		if err != nil {
+			log.Println("ws.ReadMessage:", err)
+			return err
+		}
+
+		if mt != websocket.BinaryMessage {
+			log.Println("mt != websocket.BinaryMessage")
+			continue
+		}
+
+		n, err := c.Write(message)
+		if err != nil {
+			log.Println("pipe.Write:", err)
+			return err
+		}
+
+		if len(message) != n {
+			errMsg := fmt.Sprintf("Write Failure: data length mismatch (expected %d, got %d)", len(message), n)
+			log.Printf("%s\n", errMsg)
+			return errors.New(errMsg)
+		}
+	}
+}
diff --git a/client/sql/sparksession.go b/client/sql/sparksession.go
index 31c368c..38cd788 100644
--- a/client/sql/sparksession.go
+++ b/client/sql/sparksession.go
@@ -26,6 +26,7 @@ import (
 	"github.com/google/uuid"
 	"github.com/sigmarkarl/ocean-spark-connect-go/v34/client/channel"
 	proto "github.com/sigmarkarl/ocean-spark-connect-go/v34/internal/generated"
+	"google.golang.org/grpc"
 	"google.golang.org/grpc/metadata"
 )
 
@@ -72,6 +73,7 @@ func (s SparkSessionBuilder) Build() (sparkSession, error) {
 	client := proto.NewSparkConnectServiceClient(conn)
 	return &sparkSessionImpl{
 		sessionId: uuid.NewString(),
+		conn:      conn,
 		client:    client,
 		metadata:  meta,
 	}, nil
@@ -79,6 +81,7 @@ func (s SparkSessionBuilder) Build() (sparkSession, error) {
 
 type sparkSessionImpl struct {
 	sessionId string
+	conn      *grpc.ClientConn
 	client    proto.SparkConnectServiceClient
 	metadata  metadata.MD
 }
@@ -123,7 +126,7 @@ func (s *sparkSessionImpl) Sql(query string) (DataFrame, error) {
 }
 
 func (s *sparkSessionImpl) Stop() error {
-	return nil
+	return s.conn.Close()
 }
 
 func (s *sparkSessionImpl) executePlan(plan *proto.Plan) (proto.SparkConnectService_ExecutePlanClient, error) {