diff --git a/client/channel/channel.go b/client/channel/channel.go index 9133a1b..af6ab0d 100644 --- a/client/channel/channel.go +++ b/client/channel/channel.go @@ -17,6 +17,7 @@ package channel import ( + "bufio" "context" "crypto/tls" "crypto/x509" @@ -28,12 +29,12 @@ import ( "net/url" "strconv" "strings" - "sync" "github.com/gorilla/websocket" "golang.org/x/oauth2" "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/oauth" @@ -62,6 +63,8 @@ func (cb *ChannelBuilder) Build() (*grpc.ClientConn, error) { remote := fmt.Sprintf("%v:%v", cb.Host, cb.Port) opts = append(opts, grpc.WithAuthority(cb.Host)) + var companionFuncs = []func(){} + onClose := func() {} if cb.Scheme != "sc" { grpcSide, websocketSide := net.Pipe() u := url.URL{Scheme: cb.Scheme, Host: remote, Path: cb.Path, RawQuery: cb.Query} @@ -69,67 +72,28 @@ func (cb *ChannelBuilder) Build() (*grpc.ClientConn, error) { header := http.Header{} header.Set("Authorization", "Bearer "+cb.Token) - c, _, err := websocket.DefaultDialer.Dial(u.String(), header) + ws, _, err := websocket.DefaultDialer.Dial(u.String(), header) if err != nil { log.Fatal("dial:", err) } opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { - return grpcSide, nil + return websocketSide, nil })) - done := make(chan struct{}) - data := make([]byte, 10*1024*1024) - var wg sync.WaitGroup - - wg.Add(1) - go func() { - defer wg.Done() - defer c.Close() - defer close(done) - for { - mt, message, err := c.ReadMessage() - if err != nil { - log.Println("c.ReadMessage:", err) - break - } - - if mt != websocket.BinaryMessage { - log.Println("mt != websocket.BinaryMessage") - break - } - - n, err := websocketSide.Write(message) - if err != nil { - log.Println("pipe.Write:", err) - break - } - - if len(message) != n { - log.Printf("whooot! len(data) != n => %d != %d!\n", len(message), n) - break - } - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - for { - n, err := websocketSide.Read(data) - if err != nil { - log.Println("pipe.Read:", err) - break - } - - err = c.WriteMessage(websocket.BinaryMessage, data[:n]) - if err != nil { - log.Println("c.WriteMessage:", err) - break - } - } - }() + onClose = func() { + websocketSide.Close() + grpcSide.Close() + ws.Close() + } + + // This function is responsible for reading from the websocket and writing to the pipe (grpc). + wsToGrpc := func() { ForwardToGrpc(ws, grpcSide) } + + // This function is responsible for reading from the pipe (grpc) and writing to the websocket. + fwToWs := func() { ForwardToWebsocket(grpcSide, ws) } + companionFuncs = append(companionFuncs, wsToGrpc, fwToWs) } else if cb.Token == "" { opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) } else { @@ -155,6 +119,14 @@ func (cb *ChannelBuilder) Build() (*grpc.ClientConn, error) { if err != nil { return nil, fmt.Errorf("failed to connect to remote %s: %w", remote, err) } + for _, f := range companionFuncs { + go f() + } + go func() { + for conn.GetState() != connectivity.Shutdown { + } + onClose() + }() return conn, nil } @@ -233,3 +205,42 @@ func NewBuilder(connection string) (*ChannelBuilder, error) { } return cb, nil } + +func ForwardToWebsocket(c net.Conn, ws *websocket.Conn) error { + input := bufio.NewScanner(c) + for input.Scan() { + err := ws.WriteMessage(websocket.BinaryMessage, input.Bytes()) + if err != nil { + log.Println("ws.WriteMessage:", err) + return err + } + } + return nil +} + +func ForwardToGrpc(ws *websocket.Conn, c net.Conn) error { + for { + mt, message, err := ws.ReadMessage() + if err != nil { + log.Println("ws.ReadMessage:", err) + return err + } + + if mt != websocket.BinaryMessage { + log.Println("mt != websocket.BinaryMessage") + continue + } + + n, err := c.Write(message) + if err != nil { + log.Println("pipe.Write:", err) + return err + } + + if len(message) != n { + errMsg := fmt.Sprintf("Write Failure: data length mismatch (expected %d, got %d)", len(message), n) + log.Printf("%s\n", errMsg) + return errors.New(errMsg) + } + } +} diff --git a/client/sql/sparksession.go b/client/sql/sparksession.go index 31c368c..38cd788 100644 --- a/client/sql/sparksession.go +++ b/client/sql/sparksession.go @@ -26,6 +26,7 @@ import ( "github.com/google/uuid" "github.com/sigmarkarl/ocean-spark-connect-go/v34/client/channel" proto "github.com/sigmarkarl/ocean-spark-connect-go/v34/internal/generated" + "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) @@ -72,6 +73,7 @@ func (s SparkSessionBuilder) Build() (sparkSession, error) { client := proto.NewSparkConnectServiceClient(conn) return &sparkSessionImpl{ sessionId: uuid.NewString(), + conn: conn, client: client, metadata: meta, }, nil @@ -79,6 +81,7 @@ func (s SparkSessionBuilder) Build() (sparkSession, error) { type sparkSessionImpl struct { sessionId string + conn *grpc.ClientConn client proto.SparkConnectServiceClient metadata metadata.MD } @@ -123,7 +126,7 @@ func (s *sparkSessionImpl) Sql(query string) (DataFrame, error) { } func (s *sparkSessionImpl) Stop() error { - return nil + return s.conn.Close() } func (s *sparkSessionImpl) executePlan(plan *proto.Plan) (proto.SparkConnectService_ExecutePlanClient, error) {