Skip to content

Commit

Permalink
Rely on stdlib header parsing implementation to extract header names
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikPelli committed Dec 28, 2024
1 parent b9fc9e5 commit f18e376
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 157 deletions.
176 changes: 27 additions & 149 deletions internal/http1parser/header.go
Original file line number Diff line number Diff line change
@@ -1,165 +1,43 @@
package http1parser

import "errors"

var (
ErrBadProto = errors.New("bad protocol")
ErrMissingData = errors.New("missing data")
import (
"errors"
"net/textproto"
"strings"
)

const (
_eNextHeader int = iota
_eNextHeaderN
_eHeader
_eHeaderValueSpace
_eHeaderValue
_eHeaderValueN
_eMLHeaderStart
_eMLHeaderValue
)
var ErrBadProto = errors.New("bad protocol")

// Http1ExtractHeaders is an HTTP/1.0 and HTTP/1.1 header-only parser,
// to extract the original header names for the received request.
// Fully inspired by https://github.com/evanphx/wildcat
func Http1ExtractHeaders(input []byte) ([]string, error) {
total := len(input)
var path, version, headers int
var headerNames []string

// First line: METHOD PATH VERSION
var methodOk bool
for i := 0; i < total; i++ {
switch input[i] {
case ' ', '\t':
methodOk = true
path = i + 1
}
if methodOk {
break
}
}

if !methodOk {
return nil, ErrMissingData
}

var pathOk bool
for i := path; i < total; i++ {
switch input[i] {
case ' ', '\t':
pathOk = true
version = i + 1
}
if pathOk {
break
}
// Fully inspired by readMIMEHeader() in
// https://github.com/golang/go/blob/master/src/net/textproto/reader.go
func Http1ExtractHeaders(r *textproto.Reader) ([]string, error) {
// Discard first line, it doesn't contain useful information, and it has
// already been validated in http.ReadRequest()
if _, err := r.ReadLine(); err != nil {
return nil, err
}

if !pathOk {
return nil, ErrMissingData
// The first line cannot start with a leading space.
if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') {
return nil, ErrBadProto
}

var versionOk bool
var readN bool
for i := version; i < total; i++ {
c := input[i]

switch readN {
case false:
switch c {
case '\r':
readN = true
case '\n':
headers = i + 1
versionOk = true
}
case true:
if c != '\n' {
return nil, ErrBadProto
}
headers = i + 1
versionOk = true
}
if versionOk {
break
var headerNames []string
for {
kv, err := r.ReadContinuedLine()
if len(kv) == 0 {
// We have finished parsing the headers if we receive empty
// data without an error.
return headerNames, err
}
}

if !versionOk {
return nil, ErrMissingData
}

// Header parsing
state := _eNextHeader
start := headers

for i := headers; i < total; i++ {
switch state {
case _eNextHeader:
switch input[i] {
case '\r':
state = _eNextHeaderN
case '\n':
return headerNames, nil
case ' ', '\t':
state = _eMLHeaderStart
default:
start = i
state = _eHeader
}
case _eNextHeaderN:
if input[i] != '\n' {
return nil, ErrBadProto
}

return headerNames, nil
case _eHeader:
if input[i] == ':' {
headerName := input[start:i]
headerNames = append(headerNames, string(headerName))
state = _eHeaderValueSpace
}
case _eHeaderValueSpace:
switch input[i] {
case ' ', '\t':
continue
}

start = i
state = _eHeaderValue
case _eHeaderValue:
switch input[i] {
case '\r':
state = _eHeaderValueN
case '\n':
state = _eNextHeader
default:
continue
}
case _eHeaderValueN:
if input[i] != '\n' {
return nil, ErrBadProto
}
state = _eNextHeader
case _eMLHeaderStart:
switch input[i] {
case ' ', '\t':
continue
}

start = i
state = _eMLHeaderValue
case _eMLHeaderValue:
switch input[i] {
case '\r':
state = _eHeaderValueN
case '\n':
state = _eNextHeader
default:
continue
}
// Key ends at first colon.
k, _, ok := strings.Cut(kv, ":")
if !ok {
return nil, ErrBadProto
}
headerNames = append(headerNames, k)
}

return nil, ErrMissingData
}
16 changes: 12 additions & 4 deletions internal/http1parser/header_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package http1parser_test

import (
"bufio"
"bytes"
"net/textproto"
"testing"

"github.com/elazarl/goproxy/internal/http1parser"
Expand All @@ -11,21 +14,24 @@ import (
func TestHttp1ExtractHeaders_Empty(t *testing.T) {
http1Data := "POST /index.html HTTP/1.1\r\n" +
"\r\n"
headers, err := http1parser.Http1ExtractHeaders([]byte(http1Data))

textParser := textproto.NewReader(bufio.NewReader(bytes.NewReader([]byte(http1Data))))
headers, err := http1parser.Http1ExtractHeaders(textParser)
require.NoError(t, err)
assert.Empty(t, headers)
}

func TestHttp1ExtractHeaders(t *testing.T) {
http1Data := "POST /index.html HTTP/1.1\r\n" +
"Host: www.test.com\r\n" +
"Accept: */*\r\n" +
"Accept: */ /*\r\n" +
"Content-Length: 17\r\n" +
"lowercase: 3z\r\n" +
"\r\n" +
`{"hello":"world"}`

headers, err := http1parser.Http1ExtractHeaders([]byte(http1Data))
textParser := textproto.NewReader(bufio.NewReader(bytes.NewReader([]byte(http1Data))))
headers, err := http1parser.Http1ExtractHeaders(textParser)
require.NoError(t, err)
assert.Len(t, headers, 4)
assert.Contains(t, headers, "Content-Length")
Expand All @@ -35,6 +41,8 @@ func TestHttp1ExtractHeaders(t *testing.T) {
func TestHttp1ExtractHeaders_InvalidData(t *testing.T) {
http1Data := "POST /index.html HTTP/1.1\r\n" +
`{"hello":"world"}`
_, err := http1parser.Http1ExtractHeaders([]byte(http1Data))

textParser := textproto.NewReader(bufio.NewReader(bytes.NewReader([]byte(http1Data))))
_, err := http1parser.Http1ExtractHeaders(textParser)
require.Error(t, err)
}
18 changes: 14 additions & 4 deletions internal/http1parser/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,17 @@ func NewRequestReader(preventCanonicalization bool, conn io.Reader) *RequestRead
}
}

// IsEOF reports whether the connection has been exhausted: no bytes remain
// in the internal buffer and the underlying reader has reached end-of-file.
func (r *RequestReader) IsEOF() bool {
	_, peekErr := r.reader.Peek(1)
	return errors.Is(peekErr, io.EOF)
}

// Reader exposes the buffered connection reader so the caller can take
// over any remaining connection data (e.g. HTTP/2 frames).
// After calling this, the caller is responsible for consuming every byte
// that belongs to the current request.
func (r *RequestReader) Reader() *bufio.Reader {
	buffered := r.reader
	return buffered
}
Expand All @@ -54,8 +60,9 @@ func (r *RequestReader) ReadRequest() (*http.Request, error) {
return nil, err
}

httpData := getRequestData(r.reader, r.cloned)
headers, _ := Http1ExtractHeaders(httpData)
httpDataReader := getRequestReader(r.reader, r.cloned)
headers, _ := Http1ExtractHeaders(httpDataReader)

for _, headerName := range headers {
canonicalizedName := textproto.CanonicalMIMEHeaderKey(headerName)
if canonicalizedName == headerName {
Expand All @@ -73,12 +80,15 @@ func (r *RequestReader) ReadRequest() (*http.Request, error) {
return req, nil
}

func getRequestData(r *bufio.Reader, cloned *bytes.Buffer) []byte {
func getRequestReader(r *bufio.Reader, cloned *bytes.Buffer) *textproto.Reader {
// "Cloned" buffer uses the raw connection as the data source.
// However, the *bufio.Reader can read also bytes of another unrelated
// request on the same connection, since it's buffered, so we have to
// ignore them before passing the data to our headers parser.
// Data related to the next request will remain inside the buffer for
// later usage.
return cloned.Next(cloned.Len() - r.Buffered())
data := cloned.Next(cloned.Len() - r.Buffered())
return &textproto.Reader{
R: bufio.NewReader(bytes.NewReader(data)),
}
}

0 comments on commit f18e376

Please sign in to comment.