diff --git a/headerv2.go b/headerv2.go index e3565bf..032f90e 100644 --- a/headerv2.go +++ b/headerv2.go @@ -50,7 +50,24 @@ func parseV2(r *bufio.Reader) (*HeaderV2, error) { } // highest 4 indicate address family - if (rawHdr.FamProto >> 4) > 3 { + switch rawHdr.FamProto >> 4 { + case 0: // local + if rawHdr.Len != 0 { + return nil, &InvalidHeaderErr{Read: buf[:16], error: errors.New("invalid length")} + } + case 1: // ipv4 + if rawHdr.Len != 12 { + return nil, &InvalidHeaderErr{Read: buf[:16], error: errors.New("invalid length")} + } + case 2: // ipv6 + if rawHdr.Len != 36 { + return nil, &InvalidHeaderErr{Read: buf[:16], error: errors.New("invalid length")} + } + case 3: // unix + if rawHdr.Len != 216 { + return nil, &InvalidHeaderErr{Read: buf[:16], error: errors.New("invalid length")} + } + default: return nil, &InvalidHeaderErr{Read: buf[:16], error: errors.New("invalid v2 address family")} } @@ -59,13 +76,7 @@ func parseV2(r *bufio.Reader) (*HeaderV2, error) { return nil, &InvalidHeaderErr{Read: buf[:16], error: errors.New("invalid v2 transport protocol")} } - if 16+int(rawHdr.Len) > len(buf) { - newBuf := make([]byte, 16+int(rawHdr.Len)) - copy(newBuf, buf[:16]) - buf = newBuf - } else { - buf = buf[:16+int(rawHdr.Len)] - } + buf = buf[:16+int(rawHdr.Len)] n, err = io.ReadFull(r, buf[16:]) if err != nil { diff --git a/headerv2_test.go b/headerv2_test.go index 352c8b5..42ac127 100644 --- a/headerv2_test.go +++ b/headerv2_test.go @@ -33,7 +33,9 @@ func TestHeaderV2(t *testing.T) { buf.Write(s.value) } hdr, err := Parse(bufio.NewReader(&buf)) - assert.NoError(t, err) + if !assert.NoError(t, err) { + return + } assert.IsType(t, &HeaderV2{}, hdr, "Header Type") p := hdr.(*HeaderV2) assert.Equal(t, h.Command, p.Command, "Command") diff --git a/parse_test.go b/parse_test.go new file mode 100644 index 0000000..48e2eaa --- /dev/null +++ b/parse_test.go @@ -0,0 +1,28 @@ +package proxyprotocol + +import ( + "bufio" + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParse_Malformed(t *testing.T) { + data := []byte{ + // PROXY protocol v2 magic header + 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, + // v2 version, PROXY cmd + 0x21, + // TCP, IPv4 (also works with 0x13,0x21,0x22,0x31,0x32) + 0x12, + // Length + 0x00, 0x00, + // src/dest address data _should_ be here but is omitted. + } + + _, err := Parse( + bufio.NewReader( + bytes.NewReader(data))) + assert.Error(t, err) +}