From 5fe5237d0a2501965b76ff98f12110b13af48af5 Mon Sep 17 00:00:00 2001 From: michaelyou Date: Tue, 12 Dec 2017 12:49:42 +0800 Subject: [PATCH] Fix ReadPacket and WritePacket payload length is a multiple MaxPayloadLen --- packet/conn.go | 152 +++++++++++++++++++------------------------------ 1 file changed, 57 insertions(+), 95 deletions(-) diff --git a/packet/conn.go b/packet/conn.go index 3772e1a33..48dc40096 100644 --- a/packet/conn.go +++ b/packet/conn.go @@ -2,7 +2,6 @@ package packet import ( "bufio" - "bytes" "io" "net" @@ -30,86 +29,51 @@ func NewConn(conn net.Conn) *Conn { } func (c *Conn) ReadPacket() ([]byte, error) { - var buf bytes.Buffer - - if err := c.ReadPacketTo(&buf); err != nil { - return nil, errors.Trace(err) - } else { - return buf.Bytes(), nil - } - - // header := []byte{0, 0, 0, 0} - - // if _, err := io.ReadFull(c.br, header); err != nil { - // return nil, ErrBadConn - // } - - // length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) - // if length < 1 { - // return nil, fmt.Errorf("invalid payload length %d", length) - // } - - // sequence := uint8(header[3]) - - // if sequence != c.Sequence { - // return nil, fmt.Errorf("invalid sequence %d != %d", sequence, c.Sequence) - // } - - // c.Sequence++ - - // data := make([]byte, length) - // if _, err := io.ReadFull(c.br, data); err != nil { - // return nil, ErrBadConn - // } else { - // if length < MaxPayloadLen { - // return data, nil - // } - - // var buf []byte - // buf, err = c.ReadPacket() - // if err != nil { - // return nil, ErrBadConn - // } else { - // return append(data, buf...), nil - // } - // } -} - -func (c *Conn) ReadPacketTo(w io.Writer) error { - header := []byte{0, 0, 0, 0} - - if _, err := io.ReadFull(c.br, header); err != nil { - return ErrBadConn - } + var prevData []byte + for { + // read packet header + header := []byte{0, 0, 0, 0} + if _, err := io.ReadFull(c.br, header); err != nil { + return nil, ErrBadConn + } - length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) - if length < 1 { - return errors.Errorf("invalid payload length %d", length) - } + // packet length [24 bit] + length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) - sequence := uint8(header[3]) + // check packet sync [8 bit] + sequence := uint8(header[3]) + if sequence != c.Sequence { + return nil, errors.Errorf("invalid sequence %d != %d", sequence, c.Sequence) + } + c.Sequence++ - if sequence != c.Sequence { - return errors.Errorf("invalid sequence %d != %d", sequence, c.Sequence) - } + // packets with length 0 terminate a previous packet which is a + // multiple of (2^24)−1 bytes long + if length == 0 { + // there was no previous packet + if prevData == nil { + return nil, errors.Errorf("invalid payload length %d", length) + } + return prevData, nil + } - c.Sequence++ + // read packet body [length bytes] + data := make([]byte, length) + if _, err := io.ReadFull(c.br, data); err != nil { + return nil, ErrBadConn + } - if n, err := io.CopyN(w, c.br, int64(length)); err != nil { - return ErrBadConn - } else if n != int64(length) { - return ErrBadConn - } else { + // return data if this was the last packet if length < MaxPayloadLen { - return nil - } + // zero allocations for non-split packets + if prevData == nil { + return data, nil + } - if err := c.ReadPacketTo(w); err != nil { - return err + return append(prevData, data...), nil } + prevData = append(prevData, data...) } - - return nil } // data already has 4 bytes header @@ -117,37 +81,35 @@ func (c *Conn) ReadPacketTo(w io.Writer) error { func (c *Conn) WritePacket(data []byte) error { length := len(data) - 4 - for length >= MaxPayloadLen { - data[0] = 0xff - data[1] = 0xff - data[2] = 0xff - + for { + var size int + if length >= MaxPayloadLen { + data[0] = 0xff + data[1] = 0xff + data[2] = 0xff + size = MaxPayloadLen + } else { + data[0] = byte(length) + data[1] = byte(length >> 8) + data[2] = byte(length >> 16) + size = length + } data[3] = c.Sequence - if n, err := c.Write(data[:4+MaxPayloadLen]); err != nil { + if n, err := c.Write(data[:4+size]); err != nil { return ErrBadConn - } else if n != (4 + MaxPayloadLen) { + } else if n != (4 + size) { return ErrBadConn } else { c.Sequence++ - length -= MaxPayloadLen - data = data[MaxPayloadLen:] + if size != MaxPayloadLen { + return nil + } + length -= size + data = data[size:] + continue } } - - data[0] = byte(length) - data[1] = byte(length >> 8) - data[2] = byte(length >> 16) - data[3] = c.Sequence - - if n, err := c.Write(data); err != nil { - return ErrBadConn - } else if n != len(data) { - return ErrBadConn - } else { - c.Sequence++ - return nil - } } func (c *Conn) ResetSequence() {