Skip to content

Commit

Permalink
Prevent Request headers canonicalization (#607)
Browse files Browse the repository at this point in the history
* Add header extraction with tests

* Add prevent canonicalization implementation

* Update documentation

* Fix linter issues

* Remove redundant body read

* Avoid duplicating read data if prevent canonicalization is set to false

* Copy https mitm fixes to http mitm

* Prevent response body memory leak by making sure to close it when we are done

* Add additional request parsing tests

* Rely on stdlib header parsing implementation to extract header names
  • Loading branch information
ErikPelli authored Dec 31, 2024
1 parent 408830d commit 3d6017a
Show file tree
Hide file tree
Showing 10 changed files with 508 additions and 72 deletions.
2 changes: 1 addition & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ linters:
- containedctx
- decorder
- dogsled
- dupl
- durationcheck
- errchkjson
- errname
Expand Down Expand Up @@ -80,6 +79,7 @@ linters:
- copyloopvar
- cyclop
- depguard
- dupl
- dupword
- err113
- exhaustruct
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ proxy is `localhost:8080`, which is the default one in our example.
- You can specify a `MITM certificates cache`, to reuse them later for other requests to the same host, thus saving CPU. Not enabled by default, but you should use it in production!
- Redirect normal HTTP traffic to a `custom handler`, when the target is a `relative path` (e.g. `/ping`)
- You can choose the logger to use, by implementing the `Logger` interface
- You can `disable` the `canonicalization` of HTTP request headers by setting `PreventCanonicalization` to true

## Proxy modes
1. Regular HTTP proxy
Expand Down
12 changes: 10 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@ module github.com/elazarl/goproxy

go 1.20

require golang.org/x/net v0.33.0
require (
github.com/stretchr/testify v1.10.0
golang.org/x/net v0.33.0
)

require golang.org/x/text v0.21.0 // indirect
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/text v0.21.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
10 changes: 10 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
153 changes: 91 additions & 62 deletions https.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"sync"
"sync/atomic"

"github.com/elazarl/goproxy/internal/http1parser"
"github.com/elazarl/goproxy/internal/signer"
)

Expand Down Expand Up @@ -192,15 +193,25 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
var targetSiteCon net.Conn
var remote *bufio.Reader

for {
client := bufio.NewReader(proxyClient)
req, err := http.ReadRequest(client)
client := http1parser.NewRequestReader(proxy.PreventCanonicalization, proxyClient)
for !client.IsEOF() {
req, err := client.ReadRequest()
if err != nil && !errors.Is(err, io.EOF) {
ctx.Warnf("cannot read request of MITM HTTP client: %+#v", err)
}
if err != nil {
return
}

// Take the original value before filtering the request
closeConn := req.Close

// since we're converting the request, need to carry over the
// original connecting IP as well
req.RemoteAddr = r.RemoteAddr
ctx.Logf("req %v", r.Host)
ctx.Req = req

req, resp := proxy.filterRequest(req, ctx)
if resp == nil {
// Establish a connection with the remote server only if the proxy
Expand All @@ -218,18 +229,27 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
httpError(proxyClient, ctx, err)
return
}
resp, err = http.ReadResponse(remote, req)
resp, err = func() (*http.Response, error) {
defer req.Body.Close()
return http.ReadResponse(remote, req)
}()
if err != nil {
httpError(proxyClient, ctx, err)
return
}
defer resp.Body.Close()
}
resp = proxy.filterResponse(resp, ctx)
if err := resp.Write(proxyClient); err != nil {
err = resp.Write(proxyClient)
_ = resp.Body.Close()
if err != nil {
httpError(proxyClient, ctx, err)
return
}

if closeConn {
ctx.Logf("Non-persistent connection; closing")
return
}
}
case ConnectMitm:
_, _ = proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n"))
Expand All @@ -255,9 +275,10 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
ctx.Warnf("Cannot handshake client %v %v", r.Host, err)
return
}
clientTlsReader := bufio.NewReader(rawClientTls)
for !isEOF(clientTlsReader) {
req, err := http.ReadRequest(clientTlsReader)

clientTlsReader := http1parser.NewRequestReader(proxy.PreventCanonicalization, rawClientTls)
for !clientTlsReader.IsEOF() {
req, err := clientTlsReader.ReadRequest()
ctx := &ProxyCtx{
Req: req,
Session: atomic.AddInt64(&proxy.sess, 1),
Expand All @@ -266,10 +287,9 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
RoundTripper: ctx.RoundTripper,
}
if err != nil && !errors.Is(err, io.EOF) {
return
ctx.Warnf("Cannot read TLS request from mitm'd client %v %v", r.Host, err)
}
if err != nil {
ctx.Warnf("Cannot read TLS request from mitm'd client %v %v", r.Host, err)
return
}

Expand Down Expand Up @@ -298,7 +318,8 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
// parse the HTTP Body for PRI requests. This leaves the body of
// the http2.ClientPreface ("SM\r\n\r\n") on the wire which we need
// to clear before setting up the connection.
_, err := clientTlsReader.Discard(6)
reader := clientTlsReader.Reader()
_, err := reader.Discard(6)
if err != nil {
ctx.Warnf("Failed to process HTTP2 client preface: %v", err)
return
Expand All @@ -307,7 +328,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
ctx.Warnf("HTTP2 connection failed: disallowed")
return
}
tr := H2Transport{clientTlsReader, rawClientTls, tlsConfig.Clone(), host}
tr := H2Transport{reader, rawClientTls, tlsConfig.Clone(), host}
if _, err := tr.RoundTrip(req); err != nil {
ctx.Warnf("HTTP2 connection failed: %v", err)
} else {
Expand Down Expand Up @@ -349,61 +370,69 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
ctx.Logf("resp %v", resp.Status)
}
resp = proxy.filterResponse(resp, ctx)
defer resp.Body.Close()

text := resp.Status
statusCode := strconv.Itoa(resp.StatusCode) + " "
text = strings.TrimPrefix(text, statusCode)
// always use 1.1 to support chunked encoding
if _, err := io.WriteString(rawClientTls, "HTTP/1.1"+" "+statusCode+text+"\r\n"); err != nil {
ctx.Warnf("Cannot write TLS response HTTP status from mitm'd client: %v", err)
return
}

if resp.Request.Method == http.MethodHead {
// don't change Content-Length for HEAD request
} else if (resp.StatusCode >= 100 && resp.StatusCode < 200) ||
resp.StatusCode == http.StatusNoContent {
// RFC7230: A server MUST NOT send a Content-Length header field in any response
// with a status code of 1xx (Informational) or 204 (No Content)
resp.Header.Del("Content-Length")
} else {
// Since we don't know the length of resp, return chunked encoded response
// TODO: use a more reasonable scheme
resp.Header.Del("Content-Length")
resp.Header.Set("Transfer-Encoding", "chunked")
}
// Force connection close otherwise chrome will keep CONNECT tunnel open forever
resp.Header.Set("Connection", "close")
if err := resp.Header.Write(rawClientTls); err != nil {
ctx.Warnf("Cannot write TLS response header from mitm'd client: %v", err)
return
}
if _, err = io.WriteString(rawClientTls, "\r\n"); err != nil {
ctx.Warnf("Cannot write TLS response header end from mitm'd client: %v", err)
return
}
// Run defer inside a custom function to prevent response body memory leak
if ok := func() bool {
defer resp.Body.Close()

text := resp.Status
statusCode := strconv.Itoa(resp.StatusCode) + " "
text = strings.TrimPrefix(text, statusCode)
// always use 1.1 to support chunked encoding
if _, err := io.WriteString(rawClientTls, "HTTP/1.1"+" "+statusCode+text+"\r\n"); err != nil {
ctx.Warnf("Cannot write TLS response HTTP status from mitm'd client: %v", err)
return false
}

if resp.Request.Method == http.MethodHead ||
(resp.StatusCode >= 100 && resp.StatusCode < 200) ||
resp.StatusCode == http.StatusNoContent ||
resp.StatusCode == http.StatusNotModified {
// Don't write out a response body, when it's not allowed
// in RFC7230
} else {
chunked := newChunkedWriter(rawClientTls)
if _, err := io.Copy(chunked, resp.Body); err != nil {
ctx.Warnf("Cannot write TLS response body from mitm'd client: %v", err)
return
if resp.Request.Method == http.MethodHead {
// don't change Content-Length for HEAD request
} else if (resp.StatusCode >= 100 && resp.StatusCode < 200) ||
resp.StatusCode == http.StatusNoContent {
// RFC7230: A server MUST NOT send a Content-Length header field in any response
// with a status code of 1xx (Informational) or 204 (No Content)
resp.Header.Del("Content-Length")
} else {
// Since we don't know the length of resp, return chunked encoded response
// TODO: use a more reasonable scheme
resp.Header.Del("Content-Length")
resp.Header.Set("Transfer-Encoding", "chunked")
}
if err := chunked.Close(); err != nil {
ctx.Warnf("Cannot write TLS chunked EOF from mitm'd client: %v", err)
return
// Force connection close otherwise chrome will keep CONNECT tunnel open forever
resp.Header.Set("Connection", "close")
if err := resp.Header.Write(rawClientTls); err != nil {
ctx.Warnf("Cannot write TLS response header from mitm'd client: %v", err)
return false
}
if _, err = io.WriteString(rawClientTls, "\r\n"); err != nil {
ctx.Warnf("Cannot write TLS response chunked trailer from mitm'd client: %v", err)
return
ctx.Warnf("Cannot write TLS response header end from mitm'd client: %v", err)
return false
}

if resp.Request.Method == http.MethodHead ||
(resp.StatusCode >= 100 && resp.StatusCode < 200) ||
resp.StatusCode == http.StatusNoContent ||
resp.StatusCode == http.StatusNotModified {
// Don't write out a response body, when it's not allowed
// in RFC7230
} else {
chunked := newChunkedWriter(rawClientTls)
if _, err := io.Copy(chunked, resp.Body); err != nil {
ctx.Warnf("Cannot write TLS response body from mitm'd client: %v", err)
return false
}
if err := chunked.Close(); err != nil {
ctx.Warnf("Cannot write TLS chunked EOF from mitm'd client: %v", err)
return false
}
if _, err = io.WriteString(rawClientTls, "\r\n"); err != nil {
ctx.Warnf("Cannot write TLS response chunked trailer from mitm'd client: %v", err)
return false
}
}

return true
}(); !ok {
return
}

if closeConn {
Expand Down
43 changes: 43 additions & 0 deletions internal/http1parser/header.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package http1parser

import (
"errors"
"net/textproto"
"strings"
)

// ErrBadProto is returned when the header section of a request is malformed:
// either the first header line begins with a space or tab, or a header line
// does not contain a colon.
var ErrBadProto = errors.New("bad protocol")

// Http1ExtractHeaders reads the head of an HTTP/1.0 or HTTP/1.1 request from r
// and returns the header names as they were written by the sender, in order
// of appearance and without any canonicalization.
//
// The request line is consumed and thrown away. Header lines are then read
// one by one until an empty line or a read error is encountered; in the
// latter case the names gathered so far are returned together with the error.
//
// The approach follows readMIMEHeader() in
// https://github.com/golang/go/blob/master/src/net/textproto/reader.go
func Http1ExtractHeaders(r *textproto.Reader) ([]string, error) {
	// Skip the request line: it carries no header names, and it has
	// already been validated in http.ReadRequest().
	_, err := r.ReadLine()
	if err != nil {
		return nil, err
	}

	// The first header line must not begin with a space or a tab.
	first, peekErr := r.R.Peek(1)
	if peekErr == nil {
		switch first[0] {
		case ' ', '\t':
			return nil, ErrBadProto
		}
	}

	var names []string
	for {
		line, readErr := r.ReadContinuedLine()
		if line == "" {
			// An empty line without an error marks the end of the headers;
			// an empty line with an error means the read failed.
			return names, readErr
		}

		// The header name is everything before the first colon.
		colon := strings.IndexByte(line, ':')
		if colon < 0 {
			return nil, ErrBadProto
		}
		names = append(names, line[:colon])
	}
}
48 changes: 48 additions & 0 deletions internal/http1parser/header_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package http1parser_test

import (
"bufio"
"bytes"
"net/textproto"
"testing"

"github.com/elazarl/goproxy/internal/http1parser"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestHttp1ExtractHeaders_Empty checks that a request consisting only of a
// request line and the terminating blank line yields no header names and no
// error.
func TestHttp1ExtractHeaders_Empty(t *testing.T) {
	const rawRequest = "POST /index.html HTTP/1.1\r\n" +
		"\r\n"

	source := bytes.NewReader([]byte(rawRequest))
	reader := textproto.NewReader(bufio.NewReader(source))

	names, err := http1parser.Http1ExtractHeaders(reader)
	require.NoError(t, err)
	assert.Empty(t, names)
}

// TestHttp1ExtractHeaders feeds a request with four headers followed by a
// body and checks that four names are returned, including one canonical name
// and one all-lowercase name kept as written.
func TestHttp1ExtractHeaders(t *testing.T) {
	const rawRequest = "POST /index.html HTTP/1.1\r\n" +
		"Host: www.test.com\r\n" +
		"Accept: */ /*\r\n" +
		"Content-Length: 17\r\n" +
		"lowercase: 3z\r\n" +
		"\r\n" +
		`{"hello":"world"}`

	source := bytes.NewReader([]byte(rawRequest))
	reader := textproto.NewReader(bufio.NewReader(source))

	names, err := http1parser.Http1ExtractHeaders(reader)
	require.NoError(t, err)
	assert.Len(t, names, 4)
	assert.Contains(t, names, "Content-Length")
	assert.Contains(t, names, "lowercase")
}

// TestHttp1ExtractHeaders_InvalidData checks that an error is reported when
// the request line is followed directly by body data, with no blank line
// terminating the header section.
func TestHttp1ExtractHeaders_InvalidData(t *testing.T) {
	const rawRequest = "POST /index.html HTTP/1.1\r\n" +
		`{"hello":"world"}`

	source := bytes.NewReader([]byte(rawRequest))
	reader := textproto.NewReader(bufio.NewReader(source))

	_, err := http1parser.Http1ExtractHeaders(reader)
	require.Error(t, err)
}
Loading

0 comments on commit 3d6017a

Please sign in to comment.