From 8c6985bd47501ce4e14dcde56ac5936a486582b9 Mon Sep 17 00:00:00 2001 From: eugene Date: Fri, 13 Sep 2024 10:15:58 -0400 Subject: [PATCH] [fixes #625] implement support for receiving multipart edge messages update version [fixes #625] --- ziti/edge/conn.go | 11 ++++++- ziti/edge/messages.go | 9 ++++++ ziti/edge/network/conn.go | 52 +++++++++++++++++++++++++++------- ziti/edge/network/conn_test.go | 51 +++++++++++++++++++++++++++++++++ ziti/sdkinfo/build_info.go | 2 +- 5 files changed, 113 insertions(+), 12 deletions(-) diff --git a/ziti/edge/conn.go b/ziti/edge/conn.go index 63e88262..bfb378f8 100644 --- a/ziti/edge/conn.go +++ b/ziti/edge/conn.go @@ -153,7 +153,8 @@ func (ec *MsgChannel) WriteTraced(data []byte, msgUUID []byte, hdrs map[int32][] copyBuf := make([]byte, len(data)) copy(copyBuf, data) - msg := NewDataMsg(ec.id, ec.msgIdSeq.Next(), copyBuf) + seq := ec.msgIdSeq.Next() + msg := NewDataMsg(ec.id, seq, copyBuf) if msgUUID != nil { msg.Headers[UUIDHeader] = msgUUID } @@ -161,6 +162,14 @@ func (ec *MsgChannel) WriteTraced(data []byte, msgUUID []byte, hdrs map[int32][] for k, v := range hdrs { msg.Headers[k] = v } + + // indicate that we can accept multipart messages + // with the first message + if seq == 1 { + flags, _ := msg.GetUint32Header(FlagsHeader) + flags = flags | MULTIPART + msg.PutUint32Header(FlagsHeader, flags) + } ec.TraceMsg("write", msg) pfxlog.Logger().WithFields(GetLoggerFields(msg)).Debugf("writing %v bytes", len(copyBuf)) diff --git a/ziti/edge/messages.go b/ziti/edge/messages.go index 1dec98e2..6006a907 100644 --- a/ziti/edge/messages.go +++ b/ziti/edge/messages.go @@ -106,6 +106,15 @@ const ( // FIN is an edge payload flag used to signal communication ends FIN = 0x1 + // TRACE_UUID indicates that peer will send data messages with specially constructed UUID headers + TRACE_UUID = 1 << 1 + // MULTIPART indicates that peer can accept multipart data messages + MULTIPART = 1 << 2 + // STREAM indicates connection with stream semantics + // this allows consolidation of payloads to lower overhead + STREAM = 1 << 3 + // MULTIPART_MSG set on data message with multiple payloads + MULTIPART_MSG = 1 << 4 ) type CryptoMethod byte diff --git a/ziti/edge/network/conn.go b/ziti/edge/network/conn.go index 19f1cc2e..d8dc973d 100644 --- a/ziti/edge/network/conn.go +++ b/ziti/edge/network/conn.go @@ -26,6 +26,7 @@ import ( "crypto/rand" "encoding/base64" + "encoding/binary" "github.com/michaelquigley/pfxlog" "github.com/openziti/channel/v3" "github.com/openziti/edge-api/rest_model" @@ -52,9 +53,10 @@ var _ edge.Conn = &edgeConn{} type edgeConn struct { edge.MsgChannel readQ *noopSeq[*channel.Message] - leftover []byte + inBuffer [][]byte msgMux edge.MsgMux hosting cmap.ConcurrentMap[string, *edgeListener] + flags uint32 closed atomic.Bool readFIN atomic.Bool sentFIN atomic.Bool @@ -458,10 +460,16 @@ func (conn *edgeConn) Read(p []byte) (int, error) { } log.Tracef("read buffer = %d bytes", len(p)) - if len(conn.leftover) > 0 { - log.Tracef("found %d leftover bytes", len(conn.leftover)) - n := copy(p, conn.leftover) - conn.leftover = conn.leftover[n:] + if len(conn.inBuffer) > 0 { + first := conn.inBuffer[0] + log.Tracef("found %d buffered bytes", len(first)) + n := copy(p, first) + first = first[n:] + if len(first) == 0 { + conn.inBuffer = conn.inBuffer[1:] + } else { + conn.inBuffer[0] = first + } return n, nil } @@ -471,7 +479,7 @@ func (conn *edgeConn) Read(p []byte) (int, error) { } msg, err := conn.readQ.GetNext() - if err == ErrClosed { + if errors.Is(err, ErrClosed) { log.Debug("sequencer closed, closing connection") conn.closed.Store(true) return 0, io.EOF @@ -484,6 +492,7 @@ func (conn *edgeConn) Read(p []byte) (int, error) { if flags&edge.FIN != 0 { conn.readFIN.Store(true) } + conn.flags = conn.flags | (flags & (edge.STREAM | edge.MULTIPART)) switch msg.ContentType { @@ -499,6 +508,8 @@ func (conn *edgeConn) Read(p []byte) (int, error) { return 0, io.EOF } + multipart := (flags & edge.MULTIPART_MSG) != 0 + // first data message should contain crypto header if conn.rxKey != nil { if len(d) != secretstream.StreamHeaderBytes { @@ -519,11 +530,32 @@ func (conn *edgeConn) Read(p []byte) (int, error) { return 0, err } } - n := copy(p, d) - conn.leftover = d[n:] + n := 0 + if multipart && len(d) > 0 { + var parts [][]byte + for len(d) > 0 { + l := binary.LittleEndian.Uint16(d[0:2]) + d = d[2:] + part := d[0:l] + d = d[l:] + parts = append(parts, part) + } + n = copy(p, parts[0]) + parts[0] = parts[0][n:] + if len(parts[0]) == 0 { + parts = parts[1:] + } + conn.inBuffer = append(conn.inBuffer, parts...) + } else { + n = copy(p, d) + d = d[n:] + if len(d) > 0 { + conn.inBuffer = append(conn.inBuffer, d) + } + } - log.Tracef("saving %d bytes for leftover", len(conn.leftover)) - log.Debugf("reading %v bytes", n) + log.Tracef("%d chunks in incoming buffer", len(conn.inBuffer)) + log.Debugf("read %v bytes", n) return n, nil default: diff --git a/ziti/edge/network/conn_test.go b/ziti/edge/network/conn_test.go index 7baf0c66..4def4524 100644 --- a/ziti/edge/network/conn_test.go +++ b/ziti/edge/network/conn_test.go @@ -2,10 +2,12 @@ package network import ( "crypto/x509" + "encoding/binary" "github.com/openziti/channel/v3" "github.com/openziti/foundation/v2/sequencer" "github.com/openziti/sdk-golang/ziti/edge" "github.com/stretchr/testify/require" + "io" "sync/atomic" "testing" "time" @@ -121,6 +123,55 @@ func BenchmarkSequencer(b *testing.B) { } } +func TestReadMultipart(t *testing.T) { + req := require.New(t) + mux := edge.NewCowMapMsgMux() + testChannel := &NoopTestChannel{} + + readQ := NewNoopSequencer[*channel.Message](4) + conn := &edgeConn{ + MsgChannel: *edge.NewEdgeMsgChannel(testChannel, 1), + readQ: readQ, + msgMux: mux, + serviceName: "test", + } + + var stop atomic.Bool + defer stop.Store(true) + + var multipart []byte + words := []string{"Hello", "World", "of", "ziti"} + for _, w := range words { + multipart = binary.LittleEndian.AppendUint16(multipart, uint16(len(w))) + multipart = append(multipart, []byte(w)...) + } + msg := edge.NewDataMsg(1, uint32(0), multipart) + msg.Headers.PutUint32Header(edge.FlagsHeader, uint32(edge.MULTIPART_MSG)) + _ = readQ.PutSequenced(msg) + msg = edge.NewDataMsg(1, uint32(0), nil) + msg.Headers.PutUint32Header(edge.FlagsHeader, uint32(edge.FIN)) + err := readQ.PutSequenced(msg) + if err != nil { + panic(err) + } + + var read []string + for { + data := make([]byte, 1024) + req.NoError(conn.SetReadDeadline(time.Now().Add(1 * time.Second))) + n, e := conn.Read(data) + if e == io.EOF { + break + } + + req.NoError(e) + + read = append(read, string(data[:n])) + } + + req.Equal(words, read) +} + type NoopTestChannel struct { } diff --git a/ziti/sdkinfo/build_info.go b/ziti/sdkinfo/build_info.go index 5a44604c..5bc6177d 100644 --- a/ziti/sdkinfo/build_info.go +++ b/ziti/sdkinfo/build_info.go @@ -20,5 +20,5 @@ package sdkinfo const ( - Version = "v0.23.41" + Version = "v0.23.42" )