Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ReadPacket and WritePacket payload length is a multiple #209

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 57 additions & 95 deletions packet/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package packet

import (
"bufio"
"bytes"
"io"
"net"

Expand Down Expand Up @@ -30,124 +29,87 @@ func NewConn(conn net.Conn) *Conn {
}

func (c *Conn) ReadPacket() ([]byte, error) {
var buf bytes.Buffer

if err := c.ReadPacketTo(&buf); err != nil {
return nil, errors.Trace(err)
} else {
return buf.Bytes(), nil
}

// header := []byte{0, 0, 0, 0}

// if _, err := io.ReadFull(c.br, header); err != nil {
// return nil, ErrBadConn
// }

// length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
// if length < 1 {
// return nil, fmt.Errorf("invalid payload length %d", length)
// }

// sequence := uint8(header[3])

// if sequence != c.Sequence {
// return nil, fmt.Errorf("invalid sequence %d != %d", sequence, c.Sequence)
// }

// c.Sequence++

// data := make([]byte, length)
// if _, err := io.ReadFull(c.br, data); err != nil {
// return nil, ErrBadConn
// } else {
// if length < MaxPayloadLen {
// return data, nil
// }

// var buf []byte
// buf, err = c.ReadPacket()
// if err != nil {
// return nil, ErrBadConn
// } else {
// return append(data, buf...), nil
// }
// }
}

func (c *Conn) ReadPacketTo(w io.Writer) error {
header := []byte{0, 0, 0, 0}

if _, err := io.ReadFull(c.br, header); err != nil {
return ErrBadConn
}
var prevData []byte
for {
// read packet header
header := []byte{0, 0, 0, 0}
if _, err := io.ReadFull(c.br, header); err != nil {
return nil, ErrBadConn
}

length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
if length < 1 {
return errors.Errorf("invalid payload length %d", length)
}
// packet length [24 bit]
length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not checking length with max payload length here?


sequence := uint8(header[3])
// check packet sync [8 bit]
sequence := uint8(header[3])
if sequence != c.Sequence {
return nil, errors.Errorf("invalid sequence %d != %d", sequence, c.Sequence)
}
c.Sequence++

if sequence != c.Sequence {
return errors.Errorf("invalid sequence %d != %d", sequence, c.Sequence)
}
// packets with length 0 terminate a previous packet which is a
// multiple of (2^24)−1 bytes long
if length == 0 {
// there was no previous packet
if prevData == nil {
return nil, errors.Errorf("invalid payload length %d", length)
}
return prevData, nil
}

c.Sequence++
// read packet body [length bytes]
data := make([]byte, length)
if _, err := io.ReadFull(c.br, data); err != nil {
return nil, ErrBadConn
}

if n, err := io.CopyN(w, c.br, int64(length)); err != nil {
return ErrBadConn
} else if n != int64(length) {
return ErrBadConn
} else {
// return data if this was the last packet
if length < MaxPayloadLen {
return nil
}
// zero allocations for non-split packets
if prevData == nil {
return data, nil
}

if err := c.ReadPacketTo(w); err != nil {
return err
return append(prevData, data...), nil
}
prevData = append(prevData, data...)
}

return nil
}

// data already has 4 bytes header
// will modify data inplace
func (c *Conn) WritePacket(data []byte) error {
length := len(data) - 4

for length >= MaxPayloadLen {
data[0] = 0xff
data[1] = 0xff
data[2] = 0xff

for {
var size int
if length >= MaxPayloadLen {
data[0] = 0xff
data[1] = 0xff
data[2] = 0xff
size = MaxPayloadLen
} else {
data[0] = byte(length)
data[1] = byte(length >> 8)
data[2] = byte(length >> 16)
size = length
}
data[3] = c.Sequence

if n, err := c.Write(data[:4+MaxPayloadLen]); err != nil {
if n, err := c.Write(data[:4+size]); err != nil {
return ErrBadConn
} else if n != (4 + MaxPayloadLen) {
} else if n != (4 + size) {
return ErrBadConn
} else {
c.Sequence++
length -= MaxPayloadLen
data = data[MaxPayloadLen:]
if size != MaxPayloadLen {
return nil
}
length -= size
data = data[size:]
continue
}
}

data[0] = byte(length)
data[1] = byte(length >> 8)
data[2] = byte(length >> 16)
data[3] = c.Sequence

if n, err := c.Write(data); err != nil {
return ErrBadConn
} else if n != len(data) {
return ErrBadConn
} else {
c.Sequence++
return nil
}
}

func (c *Conn) ResetSequence() {
Expand Down