Skip to content

Commit

Permalink
martian: handle req/res body close error
Browse files Browse the repository at this point in the history
  • Loading branch information
Choraden committed Nov 14, 2023
1 parent 5f1d5a4 commit 59633df
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 12 deletions.
17 changes: 12 additions & 5 deletions internal/martian/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (p proxyHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
outreq.Body = http.NoBody
}
if outreq.Body != nil {
defer outreq.Body.Close()
defer mustCloseBody(outreq.Context(), outreq.Body)
}
outreq.Close = false

Expand Down Expand Up @@ -124,7 +124,7 @@ func (p proxyHandler) handleConnectRequest(ctx *Context, rw http.ResponseWriter,
if p.ConnectPassthrough { //nolint:nestif // to be fixed in #445
pr, pw := io.Pipe()
req.Body = pr
defer req.Body.Close()
defer mustCloseBody(req.Context(), req.Body)

// perform the HTTP roundtrip
res, cerr = p.roundTrip(ctx, req)
Expand Down Expand Up @@ -164,7 +164,7 @@ func (p proxyHandler) handleConnectRequest(ctx *Context, rw http.ResponseWriter,
res = p.errorResponse(req, cerr)
p.warning(res.Header, cerr)
}
defer res.Body.Close()
defer mustCloseBody(res.Request.Context(), res.Body)

if err := p.resmod.ModifyResponse(res); err != nil {
log.Errorf(req.Context(), "error modifying CONNECT response: %v", err)
Expand Down Expand Up @@ -311,7 +311,7 @@ func (p proxyHandler) handleRequest(ctx *Context, rw http.ResponseWriter, req *h
res = p.errorResponse(req, err)
p.warning(res.Header, err)
}
defer res.Body.Close()
defer mustCloseBody(res.Request.Context(), res.Body)

// set request to original request manually, res.Request may be changed in transport.
// see https://github.com/google/martian/issues/298
Expand Down Expand Up @@ -408,7 +408,7 @@ func writeResponse(rw http.ResponseWriter, res *http.Response) {
panic(http.ErrAbortHandler)
}

res.Body.Close() // close now, instead of defer, to populate res.Trailer
mustCloseBody(res.Request.Context(), res.Body) // close now, instead of defer, to populate res.Trailer
if len(res.Trailer) == announcedTrailers {
copyHeader(rw.Header(), res.Trailer)
} else {
Expand All @@ -420,3 +420,10 @@ func writeResponse(rw http.ResponseWriter, res *http.Response) {
}
}
}

func mustCloseBody(ctx context.Context, body io.Closer) {
if err := body.Close(); err != nil {
log.Errorf(ctx, "got error while closing body: %v", err)
panic(http.ErrAbortHandler)
}
}
24 changes: 17 additions & 7 deletions internal/martian/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/saucelabs/forwarder/internal/martian/mitm"
"github.com/saucelabs/forwarder/internal/martian/nosigpipe"
"github.com/saucelabs/forwarder/internal/martian/proxyutil"
"go.uber.org/multierr"
"golang.org/x/net/http/httpguts"
)

Expand Down Expand Up @@ -531,7 +532,7 @@ func (p *Proxy) handleMITM(ctx *Context, req *http.Request, session *Session, br
return p.handle(ctx, conn, brw)
}

func (p *Proxy) handleConnectRequest(ctx *Context, req *http.Request, session *Session, brw *bufio.ReadWriter, conn net.Conn) error {
func (p *Proxy) handleConnectRequest(ctx *Context, req *http.Request, session *Session, brw *bufio.ReadWriter, conn net.Conn) (retErr error) {
if err := p.reqmod.ModifyRequest(req); err != nil {
log.Errorf(req.Context(), "error modifying CONNECT request: %v", err)
p.warning(req.Header, err)
Expand All @@ -555,7 +556,7 @@ func (p *Proxy) handleConnectRequest(ctx *Context, req *http.Request, session *S
if p.ConnectPassthrough { //nolint:nestif // to be fixed in #445
pr, pw := io.Pipe()
req.Body = pr
defer req.Body.Close()
defer closeBody(req.Context(), req.Body, &retErr)

// perform the HTTP roundtrip
res, cerr = p.roundTrip(ctx, req)
Expand Down Expand Up @@ -595,7 +596,7 @@ func (p *Proxy) handleConnectRequest(ctx *Context, req *http.Request, session *S
res = p.errorResponse(req, cerr)
p.warning(res.Header, cerr)
}
defer res.Body.Close()
defer closeBody(res.Request.Context(), res.Body, &retErr)

if err := p.resmod.ModifyResponse(res); err != nil {
log.Errorf(req.Context(), "error modifying CONNECT response: %v", err)
Expand Down Expand Up @@ -709,7 +710,7 @@ func copySync(ctx context.Context, name string, w io.Writer, r io.Reader, donec
donec <- true
}

func (p *Proxy) handle(ctx *Context, conn net.Conn, brw *bufio.ReadWriter) error {
func (p *Proxy) handle(ctx *Context, conn net.Conn, brw *bufio.ReadWriter) (retErr error) {
log.Debugf(context.TODO(), "waiting for request: %v", conn.RemoteAddr())

session := ctx.Session()
Expand All @@ -728,7 +729,7 @@ func (p *Proxy) handle(ctx *Context, conn net.Conn, brw *bufio.ReadWriter) error
}
return errClose
}
defer req.Body.Close()
defer closeBody(req.Context(), req.Body, &retErr)

if p.Closing() {
return errClose
Expand Down Expand Up @@ -789,7 +790,7 @@ func (p *Proxy) handle(ctx *Context, conn net.Conn, brw *bufio.ReadWriter) error
res = p.errorResponse(req, err)
p.warning(res.Header, err)
}
defer res.Body.Close()
defer closeBody(res.Request.Context(), res.Body, &retErr)

// set request to original request manually, res.Request may be changed in transport.
// see https://github.com/google/martian/issues/298
Expand Down Expand Up @@ -945,7 +946,9 @@ func (p *Proxy) connectHTTP(req *http.Request, proxyURL *url.URL) (res *http.Res

if res != nil {
if res.StatusCode/100 == 2 {
res.Body.Close()
if err := res.Body.Close(); err != nil {
log.Errorf(req.Context(), "error closing CONNECT response body: %v", err)
}
return proxyutil.NewResponse(200, http.NoBody, req), conn, nil
}

Expand Down Expand Up @@ -984,3 +987,10 @@ func upgradeType(h http.Header) string {
}
return h.Get("Upgrade")
}

func closeBody(ctx context.Context, body io.Closer, err *error) {
if cerr := body.Close(); cerr != nil {
log.Errorf(ctx, "error closing body: %v", cerr)
*err = multierr.Append(*err, cerr)
}
}

0 comments on commit 59633df

Please sign in to comment.