From dc0fa808666f5a510cb8cadc1137ef1cc7bed8d3 Mon Sep 17 00:00:00 2001 From: Alex Vanderpot Date: Mon, 22 Jun 2020 20:42:48 -0700 Subject: [PATCH] Allow setting ConnectDial on context --- ctx.go | 4 ++++ go.mod | 2 ++ https.go | 15 +++++++++------ websocket.go | 2 +- 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/ctx.go b/ctx.go index b372f7d4..afe9ab0c 100644 --- a/ctx.go +++ b/ctx.go @@ -2,6 +2,7 @@ package goproxy import ( "crypto/tls" + "net" "net/http" "regexp" ) @@ -13,7 +14,10 @@ type ProxyCtx struct { Req *http.Request // Will contain the remote server's response (if available. nil if the request wasn't send yet) Resp *http.Response + // Can be used to override the RoundTripper for a single request. RoundTripper RoundTripper + // Can be used to override the ConnectDial for a single request. + ConnectDial func(network, addr string) (net.Conn, error) // will contain the recent error that occurred while trying to send receive or parse traffic Error error // A handle for the user to keep data in the context, from the call of ReqHandler to the diff --git a/go.mod b/go.mod index 30554d58..2361ec15 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/elazarl/goproxy +go 1.14 + require github.com/elazarl/goproxy/ext v0.0.0-20190711103511-473e67f1d7d2 diff --git a/https.go b/https.go index c9e76e9e..564b47d8 100644 --- a/https.go +++ b/https.go @@ -61,11 +61,14 @@ func (proxy *ProxyHttpServer) dial(network, addr string) (c net.Conn, err error) return net.Dial(network, addr) } -func (proxy *ProxyHttpServer) connectDial(network, addr string) (c net.Conn, err error) { - if proxy.ConnectDial == nil { - return proxy.dial(network, addr) +func (proxy *ProxyHttpServer) connectDial(ctx *ProxyCtx, network, addr string) (c net.Conn, err error) { + if ctx != nil && ctx.ConnectDial != nil{ + return ctx.ConnectDial(network, addr) } - return proxy.ConnectDial(network, addr) + if proxy.ConnectDial != nil { + return proxy.ConnectDial(network, addr) + } + return proxy.dial(network, addr) } type halfClosable interface { @@ -106,7 +109,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request if !hasPort.MatchString(host) { host += ":80" } - targetSiteCon, err := proxy.connectDial("tcp", host) + targetSiteCon, err := proxy.connectDial(ctx, "tcp", host) if err != nil { httpError(proxyClient, ctx, err) return @@ -137,7 +140,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request case ConnectHTTPMitm: proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) ctx.Logf("Assuming CONNECT is plain HTTP tunneling, mitm proxying it") - targetSiteCon, err := proxy.connectDial("tcp", host) + targetSiteCon, err := proxy.connectDial(ctx, "tcp", host) if err != nil { ctx.Warnf("Error dialing to %s: %s", host, err.Error()) return diff --git a/websocket.go b/websocket.go index 2a969911..522b88e3 100644 --- a/websocket.go +++ b/websocket.go @@ -49,7 +49,7 @@ func (proxy *ProxyHttpServer) serveWebsocketTLS(ctx *ProxyCtx, w http.ResponseWr func (proxy *ProxyHttpServer) serveWebsocket(ctx *ProxyCtx, w http.ResponseWriter, req *http.Request) { targetURL := url.URL{Scheme: "ws", Host: req.URL.Host, Path: req.URL.Path} - targetConn, err := proxy.connectDial("tcp", targetURL.Host) + targetConn, err := proxy.connectDial(ctx, "tcp", targetURL.Host) if err != nil { ctx.Warnf("Error dialing target site: %v", err) return