diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100755 index 000000000..29eacf3c0 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,16 @@ +version: 2 +updates: + # GitHub Actions Pipeline + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" + + # Go packages + - package-ecosystem: "gomod" + directories: + - / + - /ext + - /examples + schedule: + interval: "daily" diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 67809e6ee..80a56e72c 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -1,30 +1,38 @@ -name: Go +name: Code Check on: - push: - branches: [ master ] + workflow_dispatch: pull_request: - branches: [ master ] + types: + - opened + - synchronize + - reopened jobs: - build: - name: Build + name: Build And Test Go code runs-on: ubuntu-latest + env: + CGO_ENABLED: 0 steps: + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.20' - - name: Set up Go 1.20 - uses: actions/setup-go@v1 - with: - go-version: 1.20 - id: go + - name: Check out code into the Go module directory + uses: actions/checkout@v4 - - name: Check out code into the Go module directory - uses: actions/checkout@v2 + # https://stackoverflow.com/questions/76269119/github-actions-go-lambda-project-different-sha256sums + - name: Build + run: go build -v -buildvcs=false ./... - - name: Get dependencies - run: | - go get -v -t -d ./... + - name: Test + run: go test -p 1 -v -shuffle=on ./... - - name: Test - run: go test -v ./... + - name: Linter + uses: golangci/golangci-lint-action@v6 + with: + version: latest + env: + GOFLAGS: "-buildvcs=false" diff --git a/.golangci.yml b/.golangci.yml new file mode 100755 index 000000000..1deff156e --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,162 @@ +run: + timeout: 5m + modules-download-mode: readonly + go: '1.20' + +# List from https://golangci-lint.run/usage/linters/ +linters: + enable: + # Default linters + - errcheck + - gosimple + - govet + - ineffassign + - staticcheck + - unused + # Other linters + - asasalint + - asciicheck + - bidichk + - containedctx + - decorder + - dogsled + - dupl + - durationcheck + - errchkjson + - errname + - errorlint + - exhaustive + - fatcontext + - forbidigo + - forcetypeassert + - gci + - gocheckcompilerdirectives + - gochecksumtype + - gocritic + - godot + - gofmt + - gofumpt + - goheader + - gomodguard + - goprintffuncname + - gosec + - gosmopolitan + - grouper + - iface + - importas + - interfacebloat + - lll + - loggercheck + - makezero + - mirror + - misspell + - nakedret + - nilerr + - noctx + - nolintlint + - perfsprint + - prealloc + - predeclared + - reassign + - revive + - stylecheck + - tagalign + - tenv + - testableexamples + - testifylint + - testpackage + - thelper + - tparallel + - unconvert + - usestdlibvars + - wastedassign + - whitespace + - exportloopref + + disable: + - bodyclose + - canonicalheader + - contextcheck # Re-enable in V2 + - copyloopvar + - cyclop + - depguard + - dupword + - err113 + - exhaustruct + - funlen + - ginkgolinter + - gochecknoglobals + - gochecknoinits + - gocognit + - goconst + - gocyclo + - godox + - goimports + - gomoddirectives + - inamedparam + - intrange + - ireturn + - maintidx + - mnd + - musttag + - nestif # TODO: Re-enable in V2 + - nilnil + - nlreturn + - nonamedreturns + - nosprintfhostport + - paralleltest + - promlinter + - protogetter + - rowserrcheck + - sloglint + - spancheck + - sqlclosecheck + - tagliatelle + - unparam + - varnamelen + - wrapcheck + - wsl + - zerologlint + +linters-settings: + gci: + sections: + - standard + - default + skip-generated: false + custom-order: true + gosec: + excludes: + - G402 # InsecureSkipVerify + - G102 # Binds to all network interfaces + - G403 # RSA keys should be at least 2048 bits + - G115 # Integer overflow conversion (uint64 -> int64) + - G404 # Use of weak random number generator (math/rand) + - G204 # Subprocess launched with a potential tainted input or cmd arguments + +issues: + exclude-rules: + - linters: + - gocritic + text: "ifElseChain" + - linters: + - lll + source: "^// " + - linters: + - revive + text: "add-constant: " + - linters: + - revive + text: "unused-parameter: " + - linters: + - revive + text: "empty-block: " + - linters: + - revive + text: "var-naming: " # TODO: Re-enable in V2 + - linters: + - stylecheck + text: " should be " # TODO: Re-enable in V2 + - linters: + - stylecheck + text: "ST1003: should not use ALL_CAPS in Go names; use CamelCase instead" # TODO: Re-enable in V2 diff --git a/README.md b/README.md index 89c7f0884..e73c33934 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,23 @@ we won't merge it... `:D` The code for this project is released under the `BSD 3-Clause` license, making it useful for `commercial` uses as well. +### Linter +The codebase uses an automatic lint check over your Pull Request code. +Before opening it, you should check if your changes respect it, by running +the linter in your local machine, so you won't have any surprise. + +To install the linter: +```sh +go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest +``` + +This will create an executable in your `$GOPATH/bin` folder +(`$GOPATH` is an environment variable, usually +its value is equivalent to `~/go`, check its value in your machine if you +aren't sure about it). +Make sure to include the bin folder in the path of your shell, to be able to +directly use the `golangci-lint run` command. + ## A taste of GoProxy To get a taste of `goproxy`, here you are a basic HTTP/HTTPS proxy diff --git a/actions.go b/actions.go index e1a3e7ff1..94eb90c1f 100644 --- a/actions.go +++ b/actions.go @@ -11,10 +11,10 @@ type ReqHandler interface { Handle(req *http.Request, ctx *ProxyCtx) (*http.Request, *http.Response) } -// A wrapper that would convert a function to a ReqHandler interface type +// A wrapper that would convert a function to a ReqHandler interface type. type FuncReqHandler func(req *http.Request, ctx *ProxyCtx) (*http.Request, *http.Response) -// FuncReqHandler.Handle(req,ctx) <=> FuncReqHandler(req,ctx) +// FuncReqHandler.Handle(req,ctx) <=> FuncReqHandler(req,ctx). func (f FuncReqHandler) Handle(req *http.Request, ctx *ProxyCtx) (*http.Request, *http.Response) { return f(req, ctx) } @@ -22,15 +22,15 @@ func (f FuncReqHandler) Handle(req *http.Request, ctx *ProxyCtx) (*http.Request, // after the proxy have sent the request to the destination server, it will // "filter" the response through the RespHandlers it has. // The proxy server will send to the client the response returned by the RespHandler. -// In case of error, resp will be nil, and ctx.RoundTrip.Error will contain the error +// In case of error, resp will be nil, and ctx.RoundTrip.Error will contain the error. type RespHandler interface { Handle(resp *http.Response, ctx *ProxyCtx) *http.Response } -// A wrapper that would convert a function to a RespHandler interface type +// A wrapper that would convert a function to a RespHandler interface type. type FuncRespHandler func(resp *http.Response, ctx *ProxyCtx) *http.Response -// FuncRespHandler.Handle(req,ctx) <=> FuncRespHandler(req,ctx) +// FuncRespHandler.Handle(req,ctx) <=> FuncRespHandler(req,ctx). func (f FuncRespHandler) Handle(resp *http.Response, ctx *ProxyCtx) *http.Response { return f(resp, ctx) } @@ -43,15 +43,15 @@ func (f FuncRespHandler) Handle(resp *http.Response, ctx *ProxyCtx) *http.Respon // send back and forth all messages from the server to the client and vice versa. // The request and responses sent in this Man In the Middle channel are filtered // through the usual flow (request and response filtered through the ReqHandlers -// and RespHandlers) +// and RespHandlers). type HttpsHandler interface { HandleConnect(req string, ctx *ProxyCtx) (*ConnectAction, string) } -// A wrapper that would convert a function to a HttpsHandler interface type +// A wrapper that would convert a function to a HttpsHandler interface type. type FuncHttpsHandler func(host string, ctx *ProxyCtx) (*ConnectAction, string) -// FuncHttpsHandler should implement the RespHandler interface +// FuncHttpsHandler should implement the RespHandler interface. func (f FuncHttpsHandler) HandleConnect(host string, ctx *ProxyCtx) (*ConnectAction, string) { return f(host, ctx) } diff --git a/chunked.go b/chunked.go index 83654f658..8b34b684e 100644 --- a/chunked.go +++ b/chunked.go @@ -28,9 +28,8 @@ type chunkedWriter struct { // Write the contents of data as one chunk to Wire. // NOTE: Note that the corresponding chunk-writing procedure in Conn.Write has -// a bug since it does not check for success of io.WriteString +// a bug since it does not check for success of io.WriteString. func (cw *chunkedWriter) Write(data []byte) (n int, err error) { - // Don't send 0-length data. It looks like EOF for chunked encoding. if len(data) == 0 { return 0, nil @@ -42,15 +41,14 @@ func (cw *chunkedWriter) Write(data []byte) (n int, err error) { return 0, err } if n, err = cw.Wire.Write(data); err != nil { - return + return n, err } if n != len(data) { err = io.ErrShortWrite - return + return n, err } _, err = io.WriteString(cw.Wire, "\r\n") - - return + return n, err } func (cw *chunkedWriter) Close() error { diff --git a/dispatcher.go b/dispatcher.go index 34af5d788..9161fa06f 100644 --- a/dispatcher.go +++ b/dispatcher.go @@ -10,7 +10,7 @@ import ( ) // ReqCondition.HandleReq will decide whether or not to use the ReqHandler on an HTTP request -// before sending it to the remote server +// before sending it to the remote server. type ReqCondition interface { RespCondition HandleReq(req *http.Request, ctx *ProxyCtx) bool @@ -23,10 +23,10 @@ type RespCondition interface { HandleResp(resp *http.Response, ctx *ProxyCtx) bool } -// ReqConditionFunc.HandleReq(req,ctx) <=> ReqConditionFunc(req,ctx) +// ReqConditionFunc.HandleReq(req,ctx) <=> ReqConditionFunc(req,ctx). type ReqConditionFunc func(req *http.Request, ctx *ProxyCtx) bool -// RespConditionFunc.HandleResp(resp,ctx) <=> RespConditionFunc(resp,ctx) +// RespConditionFunc.HandleResp(resp,ctx) <=> RespConditionFunc(resp,ctx). type RespConditionFunc func(resp *http.Response, ctx *ProxyCtx) bool func (c ReqConditionFunc) HandleReq(req *http.Request, ctx *ProxyCtx) bool { @@ -93,7 +93,7 @@ func ReqHostMatches(regexps ...*regexp.Regexp) ReqConditionFunc { } // ReqHostIs returns a ReqCondition, testing whether the host to which the request is directed to equal -// to one of the given strings +// to one of the given strings. func ReqHostIs(hosts ...string) ReqConditionFunc { hostSet := make(map[string]bool) for _, h := range hosts { @@ -124,7 +124,7 @@ var IsLocalHost ReqConditionFunc = func(req *http.Request, ctx *ProxyCtx) bool { } // UrlMatches returns a ReqCondition testing whether the destination URL -// of the request matches the given regexp, with or without prefix +// of the request matches the given regexp, with or without prefix. func UrlMatches(re *regexp.Regexp) ReqConditionFunc { return func(req *http.Request, ctx *ProxyCtx) bool { return re.MatchString(req.URL.Path) || @@ -132,7 +132,7 @@ func UrlMatches(re *regexp.Regexp) ReqConditionFunc { } } -// DstHostIs returns a ReqCondition testing wether the host in the request url is the given string +// DstHostIs returns a ReqCondition testing wether the host in the request url is the given string. func DstHostIs(host string) ReqConditionFunc { host = strings.ToLower(host) return func(req *http.Request, ctx *ProxyCtx) bool { @@ -140,7 +140,7 @@ func DstHostIs(host string) ReqConditionFunc { } } -// SrcIpIs returns a ReqCondition testing whether the source IP of the request is one of the given strings +// SrcIpIs returns a ReqCondition testing whether the source IP of the request is one of the given strings. func SrcIpIs(ips ...string) ReqCondition { return ReqConditionFunc(func(req *http.Request, ctx *ProxyCtx) bool { for _, ip := range ips { @@ -152,7 +152,7 @@ func SrcIpIs(ips ...string) ReqCondition { }) } -// Not returns a ReqCondition negating the given ReqCondition +// Not returns a ReqCondition negating the given ReqCondition. func Not(r ReqCondition) ReqConditionFunc { return func(req *http.Request, ctx *ProxyCtx) bool { return !r.HandleReq(req, ctx) @@ -178,7 +178,7 @@ func ContentTypeIs(typ string, types ...string) RespCondition { } // StatusCodeIs returns a RespCondition, testing whether or not the HTTP status -// code is one of the given ints +// code is one of the given ints. func StatusCodeIs(codes ...int) RespCondition { codeSet := make(map[int]bool) for _, c := range codes { @@ -203,14 +203,15 @@ func (proxy *ProxyHttpServer) OnRequest(conds ...ReqCondition) *ReqProxyConds { return &ReqProxyConds{proxy, conds} } -// ReqProxyConds aggregate ReqConditions for a ProxyHttpServer. Upon calling Do, it will register a ReqHandler that would +// ReqProxyConds aggregate ReqConditions for a ProxyHttpServer. +// Upon calling Do, it will register a ReqHandler that would // handle the request if all conditions on the HTTP request are met. type ReqProxyConds struct { proxy *ProxyHttpServer reqConds []ReqCondition } -// DoFunc is equivalent to proxy.OnRequest().Do(FuncReqHandler(f)) +// DoFunc is equivalent to proxy.OnRequest().Do(FuncReqHandler(f)). func (pcond *ReqProxyConds) DoFunc(f func(req *http.Request, ctx *ProxyCtx) (*http.Request, *http.Response)) { pcond.Do(FuncReqHandler(f)) } @@ -297,7 +298,7 @@ type ProxyConds struct { respCond []RespCondition } -// ProxyConds.DoFunc is equivalent to proxy.OnResponse().Do(FuncRespHandler(f)) +// ProxyConds.DoFunc is equivalent to proxy.OnResponse().Do(FuncRespHandler(f)). func (pcond *ProxyConds) DoFunc(f func(resp *http.Response, ctx *ProxyCtx) *http.Response) { pcond.Do(FuncRespHandler(f)) } diff --git a/dispatcher_test.go b/dispatcher_test.go index a08cb5b5e..c78d88eae 100644 --- a/dispatcher_test.go +++ b/dispatcher_test.go @@ -1,10 +1,13 @@ -package goproxy +package goproxy_test import ( + "context" "net" "net/http" "strings" "testing" + + "github.com/elazarl/goproxy" ) func TestIsLocalHost(t *testing.T) { @@ -34,11 +37,11 @@ func TestIsLocalHost(t *testing.T) { addr = net.JoinHostPort(host, port) } t.Run(addr, func(t *testing.T) { - req, err := http.NewRequest(http.MethodGet, "http://"+addr, http.NoBody) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://"+addr, http.NoBody) if err != nil { t.Fatal(err) } - if !IsLocalHost(req, nil) { + if !goproxy.IsLocalHost(req, nil) { t.Fatal("expected true") } }) diff --git a/doc.go b/doc.go index 6f44317b9..28b1ba07b 100644 --- a/doc.go +++ b/doc.go @@ -23,7 +23,7 @@ Adding a header to each request return r, nil }) -Note that the function is called before the proxy sends the request to the server +# Note that the function is called before the proxy sends the request to the server For printing the content type of all incoming responses @@ -60,7 +60,9 @@ Finally, we have convenience function to throw a quick response proxy.OnResponse(hasGoProxyHeader).DoFunc(func(r*http.Response,ctx *goproxy.ProxyCtx)*http.Response { r.Body.Close() - return goproxy.NewResponse(ctx.Req, goproxy.ContentTypeText, http.StatusForbidden, "Can't see response with X-GoProxy header!") + return goproxy.NewResponse( + ctx.Req, goproxy.ContentTypeText, http.StatusForbidden, "Can't see response with X-GoProxy header!" + ) }) we close the body of the original response, and return a new 403 response with a short message. @@ -95,6 +97,5 @@ Will warn if multiple versions of jquery are used in the same domain. 6. https://github.com/elazarl/goproxy/blob/master/examples/goproxy-upside-down-ternet/ Modifies image files in an HTTP response via goproxy's image extension found in ext/. - */ package goproxy diff --git a/h2.go b/h2.go index 7c0f35790..6d50948ee 100644 --- a/h2.go +++ b/h2.go @@ -12,6 +12,8 @@ import ( "golang.org/x/net/http2" ) +var ErrInvalidH2Frame = errors.New("invalid H2 frame") + // H2Transport is an implementation of RoundTripper that abstracts an entire // HTTP/2 session, sending all client frames to the server and responses back // to the client. @@ -25,10 +27,10 @@ type H2Transport struct { // RoundTrip executes an HTTP/2 session (including all contained streams). // The request and response are ignored but any error encountered during the // proxying from the session is returned as a result of the invocation. -func (r *H2Transport) RoundTrip(prefaceReq *http.Request) (*http.Response, error) { +func (r *H2Transport) RoundTrip(_ *http.Request) (*http.Response, error) { raddr := r.Host if !strings.Contains(raddr, ":") { - raddr = raddr + ":443" + raddr += ":443" } rawServerTLS, err := dial("tcp", raddr) if err != nil { @@ -39,11 +41,15 @@ func (r *H2Transport) RoundTrip(prefaceReq *http.Request) (*http.Response, error r.TLSConfig.NextProtos = []string{http2.NextProtoTLS} // Initiate TLS and check remote host name against certificate. rawServerTLS = tls.Client(rawServerTLS, r.TLSConfig) - if err = rawServerTLS.(*tls.Conn).Handshake(); err != nil { + rawTLSConn, ok := rawServerTLS.(*tls.Conn) + if !ok { + return nil, errors.New("invalid TLS connection") + } + if err = rawTLSConn.Handshake(); err != nil { return nil, err } if r.TLSConfig == nil || !r.TLSConfig.InsecureSkipVerify { - if err = rawServerTLS.(*tls.Conn).VerifyHostname(raddr[:strings.LastIndex(raddr, ":")]); err != nil { + if err = rawTLSConn.VerifyHostname(raddr[:strings.LastIndex(raddr, ":")]); err != nil { return nil, err } } @@ -75,11 +81,11 @@ func (r *H2Transport) RoundTrip(prefaceReq *http.Request) (*http.Response, error for i := 0; i < 2; i++ { select { case err := <-errSToC: - if err != io.EOF { + if !errors.Is(err, io.EOF) { return nil, err } case err := <-errCToS: - if err != io.EOF { + if !errors.Is(err, io.EOF) { return nil, err } } @@ -105,14 +111,20 @@ func proxyFrame(fr *http2.Framer) error { } switch f.Header().Type { case http2.FrameData: - tf := f.(*http2.DataFrame) + tf, ok := f.(*http2.DataFrame) + if !ok { + return ErrInvalidH2Frame + } terr := fr.WriteData(tf.StreamID, tf.StreamEnded(), tf.Data()) if terr == nil && tf.StreamEnded() { terr = io.EOF } return terr case http2.FrameHeaders: - tf := f.(*http2.HeadersFrame) + tf, ok := f.(*http2.HeadersFrame) + if !ok { + return ErrInvalidH2Frame + } terr := fr.WriteHeaders(http2.HeadersFrameParam{ StreamID: tf.StreamID, BlockFragment: tf.HeaderBlockFragment(), @@ -126,19 +138,34 @@ func proxyFrame(fr *http2.Framer) error { } return terr case http2.FrameContinuation: - tf := f.(*http2.ContinuationFrame) + tf, ok := f.(*http2.ContinuationFrame) + if !ok { + return ErrInvalidH2Frame + } return fr.WriteContinuation(tf.StreamID, tf.HeadersEnded(), tf.HeaderBlockFragment()) case http2.FrameGoAway: - tf := f.(*http2.GoAwayFrame) + tf, ok := f.(*http2.GoAwayFrame) + if !ok { + return ErrInvalidH2Frame + } return fr.WriteGoAway(tf.StreamID, tf.ErrCode, tf.DebugData()) case http2.FramePing: - tf := f.(*http2.PingFrame) + tf, ok := f.(*http2.PingFrame) + if !ok { + return ErrInvalidH2Frame + } return fr.WritePing(tf.IsAck(), tf.Data) case http2.FrameRSTStream: - tf := f.(*http2.RSTStreamFrame) + tf, ok := f.(*http2.RSTStreamFrame) + if !ok { + return ErrInvalidH2Frame + } return fr.WriteRSTStream(tf.StreamID, tf.ErrCode) case http2.FrameSettings: - tf := f.(*http2.SettingsFrame) + tf, ok := f.(*http2.SettingsFrame) + if !ok { + return ErrInvalidH2Frame + } if tf.IsAck() { return fr.WriteSettingsAck() } @@ -151,13 +178,22 @@ func proxyFrame(fr *http2.Framer) error { } return fr.WriteSettings(settings...) case http2.FrameWindowUpdate: - tf := f.(*http2.WindowUpdateFrame) + tf, ok := f.(*http2.WindowUpdateFrame) + if !ok { + return ErrInvalidH2Frame + } return fr.WriteWindowUpdate(tf.StreamID, tf.Increment) case http2.FramePriority: - tf := f.(*http2.PriorityFrame) + tf, ok := f.(*http2.PriorityFrame) + if !ok { + return ErrInvalidH2Frame + } return fr.WritePriority(tf.StreamID, tf.PriorityParam) case http2.FramePushPromise: - tf := f.(*http2.PushPromiseFrame) + tf, ok := f.(*http2.PushPromiseFrame) + if !ok { + return ErrInvalidH2Frame + } return fr.WritePushPromise(http2.PushPromiseParam{ StreamID: tf.StreamID, PromiseID: tf.PromiseID, diff --git a/http.go b/http.go index 8d426dc94..63050270e 100644 --- a/http.go +++ b/http.go @@ -51,11 +51,11 @@ func (proxy *ProxyHttpServer) handleHttp(w http.ResponseWriter, r *http.Request) if ctx.Error != nil { errorString = "error read response " + r.URL.Host + " : " + ctx.Error.Error() ctx.Logf(errorString) - http.Error(w, ctx.Error.Error(), 500) + http.Error(w, ctx.Error.Error(), http.StatusInternalServerError) } else { errorString = "error read response " + r.URL.Host ctx.Logf(errorString) - http.Error(w, errorString, 500) + http.Error(w, errorString, http.StatusInternalServerError) } return } diff --git a/https.go b/https.go index 11fa47478..c7b127496 100644 --- a/https.go +++ b/https.go @@ -2,6 +2,7 @@ package goproxy import ( "bufio" + "context" "crypto/tls" "errors" "fmt" @@ -13,6 +14,8 @@ import ( "strconv" "strings" "sync/atomic" + + "github.com/elazarl/goproxy/internal/signer" ) type ConnectActionLiteral int @@ -33,10 +36,12 @@ var ( RejectConnect = &ConnectAction{Action: ConnectReject, TLSConfig: TLSConfigFromCA(&GoproxyCa)} ) +var _errorRespMaxLength int64 = 500 + // ConnectAction enables the caller to override the standard connect flow. // When Action is ConnectHijack, it is up to the implementer to send the // HTTP 200, or any other valid http response back to the client from within the -// Hijack func +// Hijack func. type ConnectAction struct { Action ConnectActionLiteral Hijack func(req *http.Request, client net.Conn, ctx *ProxyCtx) @@ -46,9 +51,8 @@ type ConnectAction struct { func stripPort(s string) string { var ix int if strings.Contains(s, "[") && strings.Contains(s, "]") { - //ipv6 : for example : [2606:4700:4700::1111]:443 - - //strip '[' and ']' + // ipv6 address example: [2606:4700:4700::1111]:443 + // strip '[' and ']' s = strings.ReplaceAll(s, "[", "") s = strings.ReplaceAll(s, "]", "") @@ -57,26 +61,25 @@ func stripPort(s string) string { return s } } else { - //ipv4 + // ipv4 ix = strings.IndexRune(s, ':') if ix == -1 { return s } - } return s[:ix] } -func (proxy *ProxyHttpServer) dial(network, addr string) (c net.Conn, err error) { - if proxy.Tr.Dial != nil { - return proxy.Tr.Dial(network, addr) +func (proxy *ProxyHttpServer) dial(ctx context.Context, network, addr string) (c net.Conn, err error) { + if proxy.Tr.DialContext != nil { + return proxy.Tr.DialContext(ctx, network, addr) } return net.Dial(network, addr) } func (proxy *ProxyHttpServer) connectDial(ctx *ProxyCtx, network, addr string) (c net.Conn, err error) { if proxy.ConnectDialWithReq == nil && proxy.ConnectDial == nil { - return proxy.dial(network, addr) + return proxy.dial(ctx.Req.Context(), network, addr) } if proxy.ConnectDialWithReq != nil { @@ -131,7 +134,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request return } ctx.Logf("Accepting CONNECT to %s", host) - proxyClient.Write([]byte("HTTP/1.0 200 Connection established\r\n\r\n")) + _, _ = proxyClient.Write([]byte("HTTP/1.0 200 Connection established\r\n\r\n")) targetTCP, targetOK := targetSiteCon.(halfClosable) proxyClientTCP, clientOK := proxyClient.(halfClosable) @@ -147,7 +150,8 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request // by the use of a closed network connection. // // 2020/05/28 23:42:17 [001] WARN: Error copying to client: read tcp 127.0.0.1:33742->127.0.0.1:34763: i/o timeout - // 2020/05/28 23:42:17 [001] WARN: Error copying to client: read tcp 127.0.0.1:45145->127.0.0.1:60494: use of closed network connection + // 2020/05/28 23:42:17 [001] WARN: Error copying to client: read tcp 127.0.0.1:45145->127.0.0.1:60494: use of closed + // network connection // // It's also not possible to synchronize these connection closures due to // TCP connections which are half-closed. When this happens, only the one @@ -171,7 +175,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request case ConnectHijack: todo.Hijack(r, proxyClient, ctx) case ConnectHTTPMitm: - proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) + _, _ = proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) ctx.Logf("Assuming CONNECT is plain HTTP tunneling, mitm proxying it") var targetSiteCon net.Conn @@ -180,7 +184,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request for { client := bufio.NewReader(proxyClient) req, err := http.ReadRequest(client) - if err != nil && err != io.EOF { + if err != nil && !errors.Is(err, io.EOF) { ctx.Warnf("cannot read request of MITM HTTP client: %+#v", err) } if err != nil { @@ -217,7 +221,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request } } case ConnectMitm: - proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) + _, _ = proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) ctx.Logf("Assuming CONNECT is TLS, mitm proxying it") // this goes in a separate goroutine, so that the net/http server won't think we're // still handling the request even after hijacking the connection. Those HTTP CONNECT @@ -233,7 +237,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request } } go func() { - //TODO: cache connections to the remote website + // TODO: cache connections to the remote website rawClientTls := tls.Server(proxyClient, tlsConfig) defer rawClientTls.Close() if err := rawClientTls.Handshake(); err != nil { @@ -241,17 +245,26 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request return } clientTlsReader := bufio.NewReader(rawClientTls) - for !isEof(clientTlsReader) { + for !isEOF(clientTlsReader) { req, err := http.ReadRequest(clientTlsReader) - var ctx = &ProxyCtx{Req: req, Session: atomic.AddInt64(&proxy.sess, 1), Proxy: proxy, UserData: ctx.UserData, RoundTripper: ctx.RoundTripper} - if err != nil && err != io.EOF { + ctx := &ProxyCtx{ + Req: req, + Session: atomic.AddInt64(&proxy.sess, 1), + Proxy: proxy, + UserData: ctx.UserData, + RoundTripper: ctx.RoundTripper, + } + if err != nil && !errors.Is(err, io.EOF) { return } if err != nil { ctx.Warnf("Cannot read TLS request from mitm'd client %v %v", r.Host, err) return } - req.RemoteAddr = r.RemoteAddr // since we're converting the request, need to carry over the original connecting IP as well + + // since we're converting the request, need to carry over the + // original connecting IP as well + req.RemoteAddr = r.RemoteAddr ctx.Logf("req %v", r.Host) if !strings.HasPrefix(req.URL.String(), "https://") { @@ -329,9 +342,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request text := resp.Status statusCode := strconv.Itoa(resp.StatusCode) + " " - if strings.HasPrefix(text, statusCode) { - text = text[len(statusCode):] - } + text = strings.TrimPrefix(text, statusCode) // always use 1.1 to support chunked encoding if _, err := io.WriteString(rawClientTls, "HTTP/1.1"+" "+statusCode+text+"\r\n"); err != nil { ctx.Warnf("Cannot write TLS response HTTP status from mitm'd client: %v", err) @@ -392,7 +403,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request ctx.Logf("Exiting on EOF") }() case ConnectProxyAuthHijack: - proxyClient.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n")) + _, _ = proxyClient.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n")) todo.Hijack(r, proxyClient, ctx) case ConnectReject: if ctx.Resp != nil { @@ -400,7 +411,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request ctx.Warnf("Cannot write response that reject http CONNECT: %v", err) } } - proxyClient.Close() + _ = proxyClient.Close() } } @@ -408,7 +419,12 @@ func httpError(w io.WriteCloser, ctx *ProxyCtx, err error) { if ctx.Proxy.ConnectionErrHandler != nil { ctx.Proxy.ConnectionErrHandler(w, ctx, err) } else { - errStr := fmt.Sprintf("HTTP/1.1 502 Bad Gateway\r\nContent-Type: text/plain\r\nContent-Length: %d\r\n\r\n%s", len(err.Error()), err.Error()) + errorMessage := err.Error() + errStr := fmt.Sprintf( + "HTTP/1.1 502 Bad Gateway\r\nContent-Type: text/plain\r\nContent-Length: %d\r\n\r\n%s", + len(errorMessage), + errorMessage, + ) if _, err := io.WriteString(w, errStr); err != nil { ctx.Warnf("Error responding to client: %s", err) } @@ -435,27 +451,30 @@ func copyAndClose(ctx *ProxyCtx, dst, src halfClosable) { ctx.Warnf("Error copying to client: %s", err.Error()) } - dst.CloseWrite() - src.CloseRead() + _ = dst.CloseWrite() + _ = src.CloseRead() } func dialerFromEnv(proxy *ProxyHttpServer) func(network, addr string) (net.Conn, error) { - https_proxy := os.Getenv("HTTPS_PROXY") - if https_proxy == "" { - https_proxy = os.Getenv("https_proxy") + httpsProxy := os.Getenv("HTTPS_PROXY") + if httpsProxy == "" { + httpsProxy = os.Getenv("https_proxy") } - if https_proxy == "" { + if httpsProxy == "" { return nil } - return proxy.NewConnectDialToProxy(https_proxy) + return proxy.NewConnectDialToProxy(httpsProxy) } -func (proxy *ProxyHttpServer) NewConnectDialToProxy(https_proxy string) func(network, addr string) (net.Conn, error) { - return proxy.NewConnectDialToProxyWithHandler(https_proxy, nil) +func (proxy *ProxyHttpServer) NewConnectDialToProxy(httpsProxy string) func(network, addr string) (net.Conn, error) { + return proxy.NewConnectDialToProxyWithHandler(httpsProxy, nil) } -func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(https_proxy string, connectReqHandler func(req *http.Request)) func(network, addr string) (net.Conn, error) { - u, err := url.Parse(https_proxy) +func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler( + httpsProxy string, + connectReqHandler func(req *http.Request), +) func(network, addr string) (net.Conn, error) { + u, err := url.Parse(httpsProxy) if err != nil { return nil } @@ -473,27 +492,27 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(https_proxy strin if connectReqHandler != nil { connectReqHandler(connectReq) } - c, err := proxy.dial(network, u.Host) + c, err := proxy.dial(context.Background(), network, u.Host) if err != nil { return nil, err } - connectReq.Write(c) + _ = connectReq.Write(c) // Read response. // Okay to use and discard buffered reader here, because // TLS server will not speak until spoken to. br := bufio.NewReader(c) resp, err := http.ReadResponse(br, connectReq) if err != nil { - c.Close() + _ = c.Close() return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - resp, err := io.ReadAll(io.LimitReader(resp.Body, 500)) + resp, err := io.ReadAll(io.LimitReader(resp.Body, _errorRespMaxLength)) if err != nil { return nil, err } - c.Close() + _ = c.Close() return nil, errors.New("proxy refused connection" + string(resp)) } return c, nil @@ -504,7 +523,7 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(https_proxy strin u.Host += ":443" } return func(network, addr string) (net.Conn, error) { - c, err := proxy.dial(network, u.Host) + c, err := proxy.dial(context.Background(), network, u.Host) if err != nil { return nil, err } @@ -518,23 +537,23 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(https_proxy strin if connectReqHandler != nil { connectReqHandler(connectReq) } - connectReq.Write(c) + _ = connectReq.Write(c) // Read response. // Okay to use and discard buffered reader here, because // TLS server will not speak until spoken to. br := bufio.NewReader(c) resp, err := http.ReadResponse(br, connectReq) if err != nil { - c.Close() + _ = c.Close() return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(io.LimitReader(resp.Body, 500)) + body, err := io.ReadAll(io.LimitReader(resp.Body, _errorRespMaxLength)) if err != nil { return nil, err } - c.Close() + _ = c.Close() return nil, errors.New("proxy refused connection" + string(body)) } return c, nil @@ -553,7 +572,7 @@ func TLSConfigFromCA(ca *tls.Certificate) func(host string, ctx *ProxyCtx) (*tls ctx.Logf("signing for %s", stripPort(host)) genCert := func() (*tls.Certificate, error) { - return signHost(*ca, []string{hostname}) + return signer.SignHost(*ca, []string{hostname}) } if ctx.certStore != nil { cert, err = ctx.certStore.Fetch(hostname, genCert) diff --git a/counterecryptor.go b/internal/signer/counterecryptor.go similarity index 92% rename from counterecryptor.go rename to internal/signer/counterecryptor.go index ef8511d1d..acb9925e7 100644 --- a/counterecryptor.go +++ b/internal/signer/counterecryptor.go @@ -1,4 +1,4 @@ -package goproxy +package signer import ( "crypto/aes" @@ -32,12 +32,11 @@ func NewCounterEncryptorRandFromKey(key any, seed []byte) (r CounterEncryptorRan return } default: - err = errors.New("only RSA, ED25519 and ECDSA keys supported") - return + return r, errors.New("only RSA, ED25519 and ECDSA keys supported") } h := sha256.New() if r.cipher, err = aes.NewCipher(h.Sum(keyBytes)[:aes.BlockSize]); err != nil { - return + return r, err } r.counter = make([]byte, r.cipher.BlockSize()) if seed != nil { @@ -45,7 +44,7 @@ func NewCounterEncryptorRandFromKey(key any, seed []byte) (r CounterEncryptorRan } r.rand = make([]byte, r.cipher.BlockSize()) r.ix = len(r.rand) - return + return r, nil } func (c *CounterEncryptorRand) Seed(b []byte) { diff --git a/counterecryptor_test.go b/internal/signer/counterecryptor_test.go similarity index 58% rename from counterecryptor_test.go rename to internal/signer/counterecryptor_test.go index 7362180c0..3cd35d2cd 100644 --- a/counterecryptor_test.go +++ b/internal/signer/counterecryptor_test.go @@ -1,4 +1,4 @@ -package goproxy_test +package signer_test import ( "bytes" @@ -9,7 +9,7 @@ import ( "math/rand" "testing" - "github.com/elazarl/goproxy" + "github.com/elazarl/goproxy/internal/signer" ) type RandSeedReader struct { @@ -23,15 +23,22 @@ func (r *RandSeedReader) Read(b []byte) (n int, err error) { return len(b), nil } +func fatalOnErr(t *testing.T, err error, msg string) { + t.Helper() + if err != nil { + t.Fatal(msg, err) + } +} + func TestCounterEncDifferentConsecutive(t *testing.T) { k, err := rsa.GenerateKey(&RandSeedReader{*rand.New(rand.NewSource(0xFF43109))}, 128) - fatalOnErr(err, "rsa.GenerateKey", t) - c, err := goproxy.NewCounterEncryptorRandFromKey(k, []byte("the quick brown fox run over the lazy dog")) - fatalOnErr(err, "NewCounterEncryptorRandFromKey", t) + fatalOnErr(t, err, "rsa.GenerateKey") + c, err := signer.NewCounterEncryptorRandFromKey(k, []byte("the quick brown fox run over the lazy dog")) + fatalOnErr(t, err, "NewCounterEncryptorRandFromKey") for i := 0; i < 100*1000; i++ { var a, b int64 - binary.Read(&c, binary.BigEndian, &a) - binary.Read(&c, binary.BigEndian, &b) + fatalOnErr(t, binary.Read(&c, binary.BigEndian, &a), "read a") + fatalOnErr(t, binary.Read(&c, binary.BigEndian, &b), "read b") if a == b { t.Fatal("two consecutive equal int64", a, b) } @@ -40,23 +47,22 @@ func TestCounterEncDifferentConsecutive(t *testing.T) { func TestCounterEncIdenticalStreams(t *testing.T) { k, err := rsa.GenerateKey(&RandSeedReader{*rand.New(rand.NewSource(0xFF43109))}, 128) - fatalOnErr(err, "rsa.GenerateKey", t) - c1, err := goproxy.NewCounterEncryptorRandFromKey(k, []byte("the quick brown fox run over the lazy dog")) - fatalOnErr(err, "NewCounterEncryptorRandFromKey", t) - c2, err := goproxy.NewCounterEncryptorRandFromKey(k, []byte("the quick brown fox run over the lazy dog")) - fatalOnErr(err, "NewCounterEncryptorRandFromKey", t) - nout := 1000 - out1, out2 := make([]byte, nout), make([]byte, nout) - io.ReadFull(&c1, out1) - tmp := out2[:] - rand.Seed(0xFF43109) + fatalOnErr(t, err, "rsa.GenerateKey") + c1, err := signer.NewCounterEncryptorRandFromKey(k, []byte("the quick brown fox run over the lazy dog")) + fatalOnErr(t, err, "NewCounterEncryptorRandFromKey") + c2, err := signer.NewCounterEncryptorRandFromKey(k, []byte("the quick brown fox run over the lazy dog")) + fatalOnErr(t, err, "NewCounterEncryptorRandFromKey") + const nOut = 1000 + out1, out2 := make([]byte, nOut), make([]byte, nOut) + _, _ = io.ReadFull(&c1, out1) + tmp := out2 for len(tmp) > 0 { n := 1 + rand.Intn(256) if n > len(tmp) { n = len(tmp) } n, err := c2.Read(tmp[:n]) - fatalOnErr(err, "CounterEncryptorRand.Read", t) + fatalOnErr(t, err, "CounterEncryptorRand.Read") tmp = tmp[n:] } if !bytes.Equal(out1, out2) { @@ -65,24 +71,24 @@ func TestCounterEncIdenticalStreams(t *testing.T) { } func stddev(data []int) float64 { - var sum, sum_sqr float64 = 0, 0 + var sum, sumSqr float64 = 0, 0 for _, h := range data { sum += float64(h) - sum_sqr += float64(h) * float64(h) + sumSqr += float64(h) * float64(h) } n := float64(len(data)) - variance := (sum_sqr - ((sum * sum) / n)) / (n - 1) + variance := (sumSqr - ((sum * sum) / n)) / (n - 1) return math.Sqrt(variance) } func TestCounterEncStreamHistogram(t *testing.T) { k, err := rsa.GenerateKey(&RandSeedReader{*rand.New(rand.NewSource(0xFF43109))}, 128) - fatalOnErr(err, "rsa.GenerateKey", t) - c, err := goproxy.NewCounterEncryptorRandFromKey(k, []byte("the quick brown fox run over the lazy dog")) - fatalOnErr(err, "NewCounterEncryptorRandFromKey", t) + fatalOnErr(t, err, "rsa.GenerateKey") + c, err := signer.NewCounterEncryptorRandFromKey(k, []byte("the quick brown fox run over the lazy dog")) + fatalOnErr(t, err, "NewCounterEncryptorRandFromKey") nout := 100 * 1000 out := make([]byte, nout) - io.ReadFull(&c, out) + _, _ = io.ReadFull(&c, out) refhist := make([]int, 512) for i := 0; i < nout; i++ { refhist[rand.Intn(256)]++ diff --git a/signer.go b/internal/signer/signer.go similarity index 80% rename from signer.go rename to internal/signer/signer.go index e8704c8d9..f32eeb5a3 100644 --- a/signer.go +++ b/internal/signer/signer.go @@ -1,4 +1,4 @@ -package goproxy +package signer import ( "crypto" @@ -6,7 +6,7 @@ import ( "crypto/ed25519" "crypto/elliptic" "crypto/rsa" - "crypto/sha1" + "crypto/sha256" "crypto/tls" "crypto/x509" "fmt" @@ -18,26 +18,20 @@ import ( "time" ) +const _goproxySignerVersion = ":goproxy2" + func hashSorted(lst []string) []byte { c := make([]string, len(lst)) copy(c, lst) sort.Strings(c) - h := sha1.New() + h := sha256.New() for _, s := range c { h.Write([]byte(s + ",")) } return h.Sum(nil) } -func hashSortedBigInt(lst []string) *big.Int { - rv := new(big.Int) - rv.SetBytes(hashSorted(lst)) - return rv -} - -var goproxySignerVersion = ":goroxy1" - -func signHost(ca tls.Certificate, hosts []string) (cert *tls.Certificate, err error) { +func SignHost(ca tls.Certificate, hosts []string) (cert *tls.Certificate, err error) { // Use the provided CA for certificate generation. // Use already parsed Leaf certificate when present. x509ca := ca.Leaf @@ -47,8 +41,9 @@ func signHost(ca tls.Certificate, hosts []string) (cert *tls.Certificate, err er } } - start := time.Unix(time.Now().Unix()-2592000, 0) // 2592000 = 30 day - end := time.Unix(time.Now().Unix()+31536000, 0) // 31536000 = 365 day + now := time.Now() + start := now.Add(-30 * 24 * time.Hour) // -30 days + end := now.Add(365 * 24 * time.Hour) // 365 days // Always generate a positive int value // (Two complement is not enabled when the first bit is 0) @@ -75,28 +70,28 @@ func signHost(ca tls.Certificate, hosts []string) (cert *tls.Certificate, err er } } - hash := hashSorted(append(hosts, goproxySignerVersion, ":"+runtime.Version())) + hash := hashSorted(append(hosts, _goproxySignerVersion, ":"+runtime.Version())) var csprng CounterEncryptorRand if csprng, err = NewCounterEncryptorRandFromKey(ca.PrivateKey, hash); err != nil { - return + return nil, err } var certpriv crypto.Signer switch ca.PrivateKey.(type) { case *rsa.PrivateKey: if certpriv, err = rsa.GenerateKey(&csprng, 2048); err != nil { - return + return nil, err } case *ecdsa.PrivateKey: if certpriv, err = ecdsa.GenerateKey(elliptic.P256(), &csprng); err != nil { - return + return nil, err } case ed25519.PrivateKey: if _, certpriv, err = ed25519.GenerateKey(&csprng); err != nil { - return + return nil, err } default: - err = fmt.Errorf("unsupported key type %T", ca.PrivateKey) + return nil, fmt.Errorf("unsupported key type %T", ca.PrivateKey) } derBytes, err := x509.CreateCertificate(&csprng, &template, x509ca, certpriv.Public(), ca.PrivateKey) diff --git a/signer_test.go b/internal/signer/signer_test.go similarity index 63% rename from signer_test.go rename to internal/signer/signer_test.go index d584b52cd..f5b40d4fb 100644 --- a/signer_test.go +++ b/internal/signer/signer_test.go @@ -1,6 +1,7 @@ -package goproxy +package signer_test import ( + "context" "crypto/tls" "crypto/x509" "io" @@ -11,9 +12,13 @@ import ( "strings" "testing" "time" + + "github.com/elazarl/goproxy" + "github.com/elazarl/goproxy/internal/signer" ) -func orFatal(msg string, err error, t *testing.T) { +func orFatal(t *testing.T, msg string, err error) { + t.Helper() if err != nil { t.Fatal(msg, err) } @@ -21,8 +26,8 @@ func orFatal(msg string, err error, t *testing.T) { type ConstantHanlder string -func (h ConstantHanlder) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(h)) +func (h ConstantHanlder) ServeHTTP(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, string(h)) } func getBrowser(args []string) string { @@ -38,100 +43,99 @@ func getBrowser(args []string) string { } func testSignerX509(t *testing.T, ca tls.Certificate) { - cert, err := signHost(ca, []string{"example.com", "1.1.1.1", "localhost"}) - orFatal("singHost", err, t) + t.Helper() + cert, err := signer.SignHost(ca, []string{"example.com", "1.1.1.1", "localhost"}) + orFatal(t, "singHost", err) cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) - orFatal("ParseCertificate", err, t) + orFatal(t, "ParseCertificate", err) certpool := x509.NewCertPool() certpool.AddCert(ca.Leaf) - orFatal("VerifyHostname", cert.Leaf.VerifyHostname("example.com"), t) - orFatal("CheckSignatureFrom", cert.Leaf.CheckSignatureFrom(ca.Leaf), t) + orFatal(t, "VerifyHostname", cert.Leaf.VerifyHostname("example.com")) + orFatal(t, "CheckSignatureFrom", cert.Leaf.CheckSignatureFrom(ca.Leaf)) _, err = cert.Leaf.Verify(x509.VerifyOptions{ DNSName: "example.com", Roots: certpool, }) - orFatal("Verify", err, t) + orFatal(t, "Verify", err) } -func testSignerTls(t *testing.T, ca tls.Certificate) { - cert, err := signHost(ca, []string{"example.com", "1.1.1.1", "localhost"}) - orFatal("singHost", err, t) +func testSignerTLS(t *testing.T, ca tls.Certificate) { + t.Helper() + cert, err := signer.SignHost(ca, []string{"example.com", "1.1.1.1", "localhost"}) + orFatal(t, "singHost", err) cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) - orFatal("ParseCertificate", err, t) + orFatal(t, "ParseCertificate", err) expected := "key verifies with Go" server := httptest.NewUnstartedServer(ConstantHanlder(expected)) defer server.Close() - server.TLS = &tls.Config{Certificates: []tls.Certificate{*cert, ca}} - server.TLS.BuildNameToCertificate() + server.TLS = &tls.Config{ + Certificates: []tls.Certificate{*cert, ca}, + MinVersion: tls.VersionTLS12, + } server.StartTLS() certpool := x509.NewCertPool() certpool.AddCert(ca.Leaf) tr := &http.Transport{ TLSClientConfig: &tls.Config{RootCAs: certpool}, } - asLocalhost := strings.Replace(server.URL, "127.0.0.1", "localhost", -1) - req, err := http.NewRequest(http.MethodGet, asLocalhost, nil) - orFatal("NewRequest", err, t) + asLocalhost := strings.ReplaceAll(server.URL, "127.0.0.1", "localhost") + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, asLocalhost, nil) + orFatal(t, "NewRequest", err) resp, err := tr.RoundTrip(req) - orFatal("RoundTrip", err, t) + orFatal(t, "RoundTrip", err) txt, err := io.ReadAll(resp.Body) - orFatal("io.ReadAll", err, t) + orFatal(t, "io.ReadAll", err) if string(txt) != expected { t.Errorf("Expected '%s' got '%s'", expected, string(txt)) } browser := getBrowser(os.Args) if browser != "" { - exec.Command(browser, asLocalhost).Run() + _ = exec.Command(browser, asLocalhost).Run() time.Sleep(10 * time.Second) } } func TestSignerRsaTls(t *testing.T) { - testSignerTls(t, GoproxyCa) + testSignerTLS(t, goproxy.GoproxyCa) } func TestSignerRsaX509(t *testing.T) { - testSignerX509(t, GoproxyCa) + testSignerX509(t, goproxy.GoproxyCa) } func TestSignerEcdsaTls(t *testing.T) { - testSignerTls(t, EcdsaCa) + testSignerTLS(t, EcdsaCa) } func TestSignerEcdsaX509(t *testing.T) { testSignerX509(t, EcdsaCa) } -var c *tls.Certificate -var e error - func BenchmarkSignRsa(b *testing.B) { var cert *tls.Certificate var err error for n := 0; n < b.N; n++ { - cert, err = signHost(GoproxyCa, []string{"example.com", "1.1.1.1", "localhost"}) - + cert, err = signer.SignHost(goproxy.GoproxyCa, []string{"example.com", "1.1.1.1", "localhost"}) } - c = cert - e = err + _ = cert + _ = err } func BenchmarkSignEcdsa(b *testing.B) { var cert *tls.Certificate var err error for n := 0; n < b.N; n++ { - cert, err = signHost(EcdsaCa, []string{"example.com", "1.1.1.1", "localhost"}) - + cert, err = signer.SignHost(EcdsaCa, []string{"example.com", "1.1.1.1", "localhost"}) } - c = cert - e = err + _ = cert + _ = err } // // Eliptic Curve certificate and key for testing // -var ECDSA_CA_CERT = []byte(`-----BEGIN CERTIFICATE----- +var EcdsaCaCert = []byte(`-----BEGIN CERTIFICATE----- MIICGDCCAb8CFEkSgqYhlT0+Yyr9anQNJgtclTL0MAoGCCqGSM49BAMDMIGOMQsw CQYDVQQGEwJJTDEPMA0GA1UECAwGQ2VudGVyMQwwCgYDVQQHDANMb2QxEDAOBgNV BAoMB0dvUHJveHkxEDAOBgNVBAsMB0dvUHJveHkxGjAYBgNVBAMMEWdvcHJveHku @@ -146,13 +150,13 @@ svyoAcrcDsynClO9aQtsC9ivZ+Pmr3MwCgYIKoZIzj0EAwMDRwAwRAIgGRSSJVSE 98Bb3nddk2xys6a9 -----END CERTIFICATE-----`) -var ECDSA_CA_KEY = []byte(`-----BEGIN PRIVATE KEY----- +var EcdsaCaKey = []byte(`-----BEGIN PRIVATE KEY----- MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgEsc8m+2aZfagnesg qMgXe8ph4LtVu2VOUYhHttuEDsChRANCAAQ5R+GK3bpDxQI2zvMfoEfRfCA+3glP Dq4W2vzCG5Uka0VXnaY9PJSvtrL8qAHK3A7MpwpTvWkLbAvYr2fj5q9z -----END PRIVATE KEY-----`) -var EcdsaCa, ecdsaCaErr = tls.X509KeyPair(ECDSA_CA_CERT, ECDSA_CA_KEY) +var EcdsaCa, ecdsaCaErr = tls.X509KeyPair(EcdsaCaCert, EcdsaCaKey) func init() { if ecdsaCaErr != nil { diff --git a/proxy.go b/proxy.go index 24474e717..21add93e9 100644 --- a/proxy.go +++ b/proxy.go @@ -2,6 +2,7 @@ package goproxy import ( "bufio" + "errors" "io" "log" "net" @@ -62,12 +63,9 @@ func copyHeaders(dst, src http.Header, keepDestHeaders bool) { } } -func isEof(r *bufio.Reader) bool { +func isEOF(r *bufio.Reader) bool { _, err := r.Peek(1) - if err == io.EOF { - return true - } - return false + return errors.Is(err, io.EOF) } func (proxy *ProxyHttpServer) filterRequest(r *http.Request, ctx *ProxyCtx) (req *http.Request, resp *http.Response) { @@ -82,6 +80,7 @@ func (proxy *ProxyHttpServer) filterRequest(r *http.Request, ctx *ProxyCtx) (req } return } + func (proxy *ProxyHttpServer) filterResponse(respOrig *http.Response, ctx *ProxyCtx) (resp *http.Response) { resp = respOrig for _, h := range proxy.respHandlers { @@ -91,7 +90,7 @@ func (proxy *ProxyHttpServer) filterResponse(respOrig *http.Response, ctx *Proxy return } -// RemoveProxyHeaders removes all proxy headers which should not propagate to the next hop +// RemoveProxyHeaders removes all proxy headers which should not propagate to the next hop. func RemoveProxyHeaders(ctx *ProxyCtx, r *http.Request) { r.RequestURI = "" // this must be reset when serving a request with the client ctx.Logf("Sending request %v %v", r.Method, r.URL.String()) @@ -140,7 +139,6 @@ func (fw flushWriter) Write(p []byte) (int, error) { // Standard net/http function. Shouldn't be used directly, http.Serve will use it. func (proxy *ProxyHttpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - //r.Header["X-Forwarded-For"] = w.RemoteAddr() if r.Method == http.MethodConnect { proxy.handleHttps(w, r) } else { @@ -148,12 +146,12 @@ func (proxy *ProxyHttpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) } } -// NewProxyHttpServer creates and returns a proxy server, logging to stderr by default +// NewProxyHttpServer creates and returns a proxy server, logging to stderr by default. func NewProxyHttpServer() *ProxyHttpServer { proxy := ProxyHttpServer{ Logger: log.New(os.Stderr, "", log.LstdFlags), NonproxyHandler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - http.Error(w, "This is a proxy server. Does not respond to non-proxy requests.", 500) + http.Error(w, "This is a proxy server. Does not respond to non-proxy requests.", http.StatusInternalServerError) }), Tr: &http.Transport{TLSClientConfig: tlsClientSkipVerify, Proxy: http.ProxyFromEnvironment}, } diff --git a/proxy_test.go b/proxy_test.go index 6d164e816..a52d8092e 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -7,8 +7,8 @@ import ( "crypto/tls" "crypto/x509" "encoding/base64" - "fmt" "io" + "log" "net" "net/http" "net/http/httptest" @@ -23,13 +23,11 @@ import ( "github.com/elazarl/goproxy" ) -var acceptAllCerts = &tls.Config{InsecureSkipVerify: true} - -var noProxyClient = &http.Client{Transport: &http.Transport{TLSClientConfig: acceptAllCerts}} - -var https = httptest.NewTLSServer(nil) -var srv = httptest.NewServer(nil) -var fs = httptest.NewServer(http.FileServer(http.Dir("."))) +var ( + https = httptest.NewTLSServer(nil) + srv = httptest.NewServer(nil) + fs = httptest.NewServer(http.FileServer(http.Dir("."))) +) type QueryHandler struct{} @@ -37,7 +35,7 @@ func (QueryHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { if err := req.ParseForm(); err != nil { panic(err) } - io.WriteString(w, req.Form.Get("result")) + _, _ = io.WriteString(w, req.Form.Get("result")) } func init() { @@ -48,11 +46,15 @@ func init() { type ConstantHanlder string func (h ConstantHanlder) ServeHTTP(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, string(h)) + _, _ = io.WriteString(w, string(h)) } func get(url string, client *http.Client) ([]byte, error) { - resp, err := client.Get(url) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil) + if err != nil { + return nil, err + } + resp, err := client.Do(req) if err != nil { return nil, err } @@ -64,7 +66,8 @@ func get(url string, client *http.Client) ([]byte, error) { return txt, nil } -func getOrFail(url string, client *http.Client, t *testing.T) []byte { +func getOrFail(t *testing.T, url string, client *http.Client) []byte { + t.Helper() txt, err := get(url, client) if err != nil { t.Fatal("Can't fetch url", url, err) @@ -72,44 +75,60 @@ func getOrFail(url string, client *http.Client, t *testing.T) []byte { return txt } -func localFile(url string) string { return fs.URL + "/" + url } -func localTls(url string) string { return https.URL + url } +func getCert(t *testing.T, c *tls.Conn) []byte { + t.Helper() + if err := c.Handshake(); err != nil { + t.Fatal("cannot handshake", err) + } + return c.ConnectionState().PeerCertificates[0].Raw +} + +func localFile(url string) string { + return fs.URL + "/" + url +} func TestSimpleHttpReqWithProxy(t *testing.T) { - client, s := oneShotProxy(goproxy.NewProxyHttpServer(), t) + client, s := oneShotProxy(goproxy.NewProxyHttpServer()) defer s.Close() - if r := string(getOrFail(srv.URL+"/bobo", client, t)); r != "bobo" { + if r := string(getOrFail(t, srv.URL+"/bobo", client)); r != "bobo" { t.Error("proxy server does not serve constant handlers", r) } - if r := string(getOrFail(srv.URL+"/bobo", client, t)); r != "bobo" { + if r := string(getOrFail(t, srv.URL+"/bobo", client)); r != "bobo" { t.Error("proxy server does not serve constant handlers", r) } - if string(getOrFail(https.URL+"/bobo", client, t)) != "bobo" { + if string(getOrFail(t, https.URL+"/bobo", client)) != "bobo" { t.Error("TLS server does not serve constant handlers, when proxy is used") } } -func oneShotProxy(proxy *goproxy.ProxyHttpServer, t *testing.T) (client *http.Client, s *httptest.Server) { +func oneShotProxy(proxy *goproxy.ProxyHttpServer) (client *http.Client, s *httptest.Server) { s = httptest.NewServer(proxy) proxyUrl, _ := url.Parse(s.URL) - tr := &http.Transport{TLSClientConfig: acceptAllCerts, Proxy: http.ProxyURL(proxyUrl)} + tr := &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + Proxy: http.ProxyURL(proxyUrl), + } client = &http.Client{Transport: tr} return } func TestSimpleHook(t *testing.T) { proxy := goproxy.NewProxyHttpServer() - proxy.OnRequest(goproxy.SrcIpIs("127.0.0.1")).DoFunc(func(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) { - req.URL.Path = "/bobo" - return req, nil - }) - client, l := oneShotProxy(proxy, t) + proxy.OnRequest(goproxy.SrcIpIs("127.0.0.1")).DoFunc( + func(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) { + req.URL.Path = "/bobo" + return req, nil + }, + ) + client, l := oneShotProxy(proxy) defer l.Close() - if result := string(getOrFail(srv.URL+("/momo"), client, t)); result != "bobo" { + if result := string(getOrFail(t, srv.URL+("/momo"), client)); result != "bobo" { t.Error("Redirecting all requests from 127.0.0.1 to bobo, didn't work." + " (Might break if Go's client sets RemoteAddr to IPv6 address). Got: " + result) @@ -122,10 +141,10 @@ func TestAlwaysHook(t *testing.T) { req.URL.Path = "/bobo" return req, nil }) - client, l := oneShotProxy(proxy, t) + client, l := oneShotProxy(proxy) defer l.Close() - if result := string(getOrFail(srv.URL+("/momo"), client, t)); result != "bobo" { + if result := string(getOrFail(t, srv.URL+("/momo"), client)); result != "bobo" { t.Error("Redirecting all requests from 127.0.0.1 to bobo, didn't work." + " (Might break if Go's client sets RemoteAddr to IPv6 address). Got: " + result) @@ -140,10 +159,10 @@ func TestReplaceResponse(t *testing.T) { return resp }) - client, l := oneShotProxy(proxy, t) + client, l := oneShotProxy(proxy) defer l.Close() - if result := string(getOrFail(srv.URL+("/momo"), client, t)); result != "chico" { + if result := string(getOrFail(t, srv.URL+("/momo"), client)); result != "chico" { t.Error("hooked response, should be chico, instead:", result) } } @@ -156,19 +175,19 @@ func TestReplaceReponseForUrl(t *testing.T) { return resp }) - client, l := oneShotProxy(proxy, t) + client, l := oneShotProxy(proxy) defer l.Close() - if result := string(getOrFail(srv.URL+("/koko"), client, t)); result != "chico" { + if result := string(getOrFail(t, srv.URL+("/koko"), client)); result != "chico" { t.Error("hooked 'koko', should be chico, instead:", result) } - if result := string(getOrFail(srv.URL+("/bobo"), client, t)); result != "bobo" { + if result := string(getOrFail(t, srv.URL+("/bobo"), client)); result != "bobo" { t.Error("still, bobo should stay as usual, instead:", result) } } func TestOneShotFileServer(t *testing.T) { - client, l := oneShotProxy(goproxy.NewProxyHttpServer(), t) + client, l := oneShotProxy(goproxy.NewProxyHttpServer()) defer l.Close() file := "test_data/panda.png" @@ -176,7 +195,12 @@ func TestOneShotFileServer(t *testing.T) { if err != nil { t.Fatal("Cannot find", file) } - if resp, err := client.Get(fs.URL + "/" + file); err == nil { + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fs.URL+"/"+file, nil) + if err != nil { + t.Fatal("Cannot create request", err) + } + if resp, err := client.Do(req); err == nil { b, err := io.ReadAll(resp.Body) if err != nil { t.Fatal("got", string(b)) @@ -191,16 +215,22 @@ func TestOneShotFileServer(t *testing.T) { func TestContentType(t *testing.T) { proxy := goproxy.NewProxyHttpServer() - proxy.OnResponse(goproxy.ContentTypeIs("image/png")).DoFunc(func(resp *http.Response, ctx *goproxy.ProxyCtx) *http.Response { - resp.Header.Set("X-Shmoopi", "1") - return resp - }) + proxy.OnResponse(goproxy.ContentTypeIs("image/png")).DoFunc( + func(resp *http.Response, ctx *goproxy.ProxyCtx) *http.Response { + resp.Header.Set("X-Shmoopi", "1") + return resp + }, + ) - client, l := oneShotProxy(proxy, t) + client, l := oneShotProxy(proxy) defer l.Close() for _, file := range []string{"test_data/panda.png", "test_data/football.png"} { - if resp, err := client.Get(localFile(file)); err != nil || resp.Header.Get("X-Shmoopi") != "1" { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, localFile(file), nil) + if err != nil { + t.Fatal("Cannot create request", err) + } + if resp, err := client.Do(req); err != nil || resp.Header.Get("X-Shmoopi") != "1" { if err == nil { t.Error("pngs should have X-Shmoopi header = 1, actually", resp.Header.Get("X-Shmoopi")) } else { @@ -209,8 +239,11 @@ func TestContentType(t *testing.T) { } } - file := "baby.jpg" - if resp, err := client.Get(localFile(file)); err != nil || resp.Header.Get("X-Shmoopi") != "" { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, localFile("baby.jpg"), nil) + if err != nil { + t.Fatal("Cannot create request", err) + } + if resp, err := client.Do(req); err != nil || resp.Header.Get("X-Shmoopi") != "" { if err == nil { t.Error("Non png images should NOT have X-Shmoopi header at all", resp.Header.Get("X-Shmoopi")) } else { @@ -219,59 +252,40 @@ func TestContentType(t *testing.T) { } } -func readAll(r io.Reader, t *testing.T) []byte { - b, err := io.ReadAll(r) - if err != nil { - t.Fatal("Cannot read", err) - } - return b -} -func readFile(file string, t *testing.T) []byte { - b, err := os.ReadFile(file) - if err != nil { - t.Fatal("Cannot read", err) - } - return b -} -func fatalOnErr(err error, msg string, t *testing.T) { - if err != nil { - t.Fatal(msg, err) - } -} func panicOnErr(err error, msg string) { if err != nil { - println(err.Error() + ":-" + msg) - os.Exit(-1) + log.Fatal(err.Error() + ":-" + msg) } } func TestChangeResp(t *testing.T) { proxy := goproxy.NewProxyHttpServer() proxy.OnResponse().DoFunc(func(resp *http.Response, ctx *goproxy.ProxyCtx) *http.Response { - resp.Body.Read([]byte{0}) + _, _ = resp.Body.Read([]byte{0}) resp.Body = io.NopCloser(new(bytes.Buffer)) return resp }) - client, l := oneShotProxy(proxy, t) + client, l := oneShotProxy(proxy) defer l.Close() - resp, err := client.Get(localFile("test_data/panda.png")) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, localFile("test_data/panda.png"), nil) if err != nil { - t.Fatal(err) + t.Fatal("Cannot create request", err) } - io.ReadAll(resp.Body) - _, err = client.Get(localFile("/bobo")) + resp, err := client.Do(req) if err != nil { t.Fatal(err) } -} - -func getCert(c *tls.Conn, t *testing.T) []byte { - if err := c.Handshake(); err != nil { - t.Fatal("cannot handshake", err) + _, _ = io.ReadAll(resp.Body) + req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, localFile("/bobo"), nil) + if err != nil { + t.Fatal("Cannot create request", err) + } + _, err = client.Do(req) + if err != nil { + t.Fatal(err) } - return c.ConnectionState().PeerCertificates[0].Raw } func TestSimpleMitm(t *testing.T) { @@ -279,33 +293,34 @@ func TestSimpleMitm(t *testing.T) { proxy.OnRequest(goproxy.ReqHostIs(https.Listener.Addr().String())).HandleConnect(goproxy.AlwaysMitm) proxy.OnRequest(goproxy.ReqHostIs("no such host exists")).HandleConnect(goproxy.AlwaysMitm) - client, l := oneShotProxy(proxy, t) + client, l := oneShotProxy(proxy) defer l.Close() c, err := tls.Dial("tcp", https.Listener.Addr().String(), &tls.Config{InsecureSkipVerify: true}) if err != nil { t.Fatal("cannot dial to tcp server", err) } - origCert := getCert(c, t) - c.Close() + origCert := getCert(t, c) + _ = c.Close() c2, err := net.Dial("tcp", l.Listener.Addr().String()) if err != nil { t.Fatal("dialing to proxy", err) } - creq, err := http.NewRequest("CONNECT", https.URL, nil) - //creq,err := http.NewRequest("CONNECT","https://google.com:443",nil) + creq, err := http.NewRequestWithContext(context.Background(), http.MethodConnect, https.URL, nil) if err != nil { t.Fatal("create new request", creq) } - creq.Write(c2) + _ = creq.Write(c2) c2buf := bufio.NewReader(c2) resp, err := http.ReadResponse(c2buf, creq) if err != nil || resp.StatusCode != http.StatusOK { t.Fatal("Cannot CONNECT through proxy", err) } - c2tls := tls.Client(c2, &tls.Config{InsecureSkipVerify: true}) - proxyCert := getCert(c2tls, t) + c2tls := tls.Client(c2, &tls.Config{ + InsecureSkipVerify: true, + }) + proxyCert := getCert(t, c2tls) if bytes.Equal(proxyCert, origCert) { t.Errorf("Certificate after mitm is not different\n%v\n%v", @@ -313,10 +328,10 @@ func TestSimpleMitm(t *testing.T) { base64.StdEncoding.EncodeToString(proxyCert)) } - if resp := string(getOrFail(https.URL+"/bobo", client, t)); resp != "bobo" { + if resp := string(getOrFail(t, https.URL+"/bobo", client)); resp != "bobo" { t.Error("Wrong response when mitm", resp, "expected bobo") } - if resp := string(getOrFail(https.URL+"/query?result=bar", client, t)); resp != "bar" { + if resp := string(getOrFail(t, https.URL+"/query?result=bar", client)); resp != "bar" { t.Error("Wrong response when mitm", resp, "expected bar") } } @@ -329,29 +344,30 @@ func TestConnectHandler(t *testing.T) { return goproxy.OkConnect, u.Host }) - client, l := oneShotProxy(proxy, t) + client, l := oneShotProxy(proxy) defer l.Close() - if resp := string(getOrFail(https.URL+"/alturl", client, t)); resp != "althttps" { + if resp := string(getOrFail(t, https.URL+"/alturl", client)); resp != "althttps" { t.Error("Proxy should redirect CONNECT requests to local althttps server, expected 'althttps' got ", resp) } } func TestMitmIsFiltered(t *testing.T) { proxy := goproxy.NewProxyHttpServer() - //proxy.Verbose = true proxy.OnRequest(goproxy.ReqHostIs(https.Listener.Addr().String())).HandleConnect(goproxy.AlwaysMitm) - proxy.OnRequest(goproxy.UrlIs("/momo")).DoFunc(func(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) { - return nil, goproxy.TextResponse(req, "koko") - }) + proxy.OnRequest(goproxy.UrlIs("/momo")).DoFunc( + func(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) { + return nil, goproxy.TextResponse(req, "koko") + }, + ) - client, l := oneShotProxy(proxy, t) + client, l := oneShotProxy(proxy) defer l.Close() - if resp := string(getOrFail(https.URL+"/momo", client, t)); resp != "koko" { + if resp := string(getOrFail(t, https.URL+"/momo", client)); resp != "koko" { t.Error("Proxy should capture /momo to be koko and not", resp) } - if resp := string(getOrFail(https.URL+"/bobo", client, t)); resp != "bobo" { + if resp := string(getOrFail(t, https.URL+"/bobo", client)); resp != "bobo" { t.Error("But still /bobo should be bobo and not", resp) } } @@ -365,30 +381,14 @@ func TestFirstHandlerMatches(t *testing.T) { panic("should never get here, previous response is no null") }) - client, l := oneShotProxy(proxy, t) + client, l := oneShotProxy(proxy) defer l.Close() - if resp := string(getOrFail(srv.URL+"/", client, t)); resp != "koko" { + if resp := string(getOrFail(t, srv.URL+"/", client)); resp != "koko" { t.Error("should return always koko and not", resp) } } -func constantHttpServer(content []byte) (addr string) { - l, err := net.Listen("tcp", "localhost:0") - panicOnErr(err, "listen") - go func() { - c, err := l.Accept() - panicOnErr(err, "accept") - buf := bufio.NewReader(c) - _, err = http.ReadRequest(buf) - panicOnErr(err, "readReq") - c.Write(content) - c.Close() - l.Close() - }() - return l.Addr().String() -} - func TestIcyResponse(t *testing.T) { // TODO: fix this test /*s := constantHttpServer([]byte("ICY 200 OK\r\n\r\nblablabla")) @@ -424,35 +424,40 @@ func (v VerifyNoProxyHeaders) ServeHTTP(w http.ResponseWriter, r *http.Request) func TestNoProxyHeaders(t *testing.T) { s := httptest.NewServer(VerifyNoProxyHeaders{t}) - client, l := oneShotProxy(goproxy.NewProxyHttpServer(), t) + client, l := oneShotProxy(goproxy.NewProxyHttpServer()) defer l.Close() - req, err := http.NewRequest(http.MethodGet, s.URL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, s.URL, nil) panicOnErr(err, "bad request") req.Header.Add("Connection", "close") req.Header.Add("Proxy-Connection", "close") req.Header.Add("Proxy-Authenticate", "auth") req.Header.Add("Proxy-Authorization", "auth") - client.Do(req) + _, _ = client.Do(req) } func TestNoProxyHeadersHttps(t *testing.T) { s := httptest.NewTLSServer(VerifyNoProxyHeaders{t}) proxy := goproxy.NewProxyHttpServer() proxy.OnRequest().HandleConnect(goproxy.AlwaysMitm) - client, l := oneShotProxy(proxy, t) + client, l := oneShotProxy(proxy) defer l.Close() - req, err := http.NewRequest(http.MethodGet, s.URL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, s.URL, nil) panicOnErr(err, "bad request") req.Header.Add("Connection", "close") req.Header.Add("Proxy-Connection", "close") - client.Do(req) + _, _ = client.Do(req) } func TestHeadReqHasContentLength(t *testing.T) { - client, l := oneShotProxy(goproxy.NewProxyHttpServer(), t) + client, l := oneShotProxy(goproxy.NewProxyHttpServer()) defer l.Close() - resp, err := client.Head(localFile("test_data/panda.png")) + req, err := http.NewRequestWithContext(context.Background(), http.MethodHead, localFile("test_data/panda.png"), nil) + if err != nil { + t.Fatal("Cannot create request", err) + } + + resp, err := client.Do(req) panicOnErr(err, "resp to HEAD") if resp.Header.Get("Content-Length") == "" { t.Error("Content-Length should exist on HEAD requests") @@ -469,7 +474,7 @@ func TestChunkedResponse(t *testing.T) { panicOnErr(err, "accept") _, err = http.ReadRequest(bufio.NewReader(c)) panicOnErr(err, "readrequest") - io.WriteString(c, "HTTP/1.1 200 OK\r\n"+ + _, _ = io.WriteString(c, "HTTP/1.1 200 OK\r\n"+ "Content-Type: text/plain\r\n"+ "Transfer-Encoding: chunked\r\n\r\n"+ "25\r\n"+ @@ -480,15 +485,15 @@ func TestChunkedResponse(t *testing.T) { "con\r\n"+ "8\r\n"+ "sequence\r\n0\r\n\r\n") - c.Close() + _ = c.Close() } }() c, err := net.Dial("tcp", "localhost:10234") panicOnErr(err, "dial") defer c.Close() - req, _ := http.NewRequest(http.MethodGet, "/", nil) - req.Write(c) + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/", nil) + _ = req.Write(c) resp, err := http.ReadResponse(bufio.NewReader(c), req) panicOnErr(err, "readresp") b, err := io.ReadAll(resp.Body) @@ -502,23 +507,28 @@ func TestChunkedResponse(t *testing.T) { proxy.OnResponse().DoFunc(func(resp *http.Response, ctx *goproxy.ProxyCtx) *http.Response { panicOnErr(ctx.Error, "error reading output") b, err := io.ReadAll(resp.Body) - resp.Body.Close() + _ = resp.Body.Close() panicOnErr(err, "readall onresp") if enc := resp.Header.Get("Transfer-Encoding"); enc != "" { t.Fatal("Chunked response should be received as plaintext", enc) } - resp.Body = io.NopCloser(bytes.NewBufferString(strings.Replace(string(b), "e", "E", -1))) + resp.Body = io.NopCloser(bytes.NewBufferString(strings.ReplaceAll(string(b), "e", "E"))) return resp }) - client, s := oneShotProxy(proxy, t) + client, s := oneShotProxy(proxy) defer s.Close() - resp, err = client.Get("http://localhost:10234/") + req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, "http://localhost:10234/", nil) + if err != nil { + t.Fatal("Cannot create request", err) + } + + resp, err = client.Do(req) panicOnErr(err, "client.Get") b, err = io.ReadAll(resp.Body) panicOnErr(err, "readall proxy") - if string(b) != strings.Replace(expected, "e", "E", -1) { + if string(b) != strings.ReplaceAll(expected, "e", "E") { t.Error("expected", expected, "w/ e->E. Got", string(b)) } } @@ -535,17 +545,16 @@ func TestGoproxyThroughProxy(t *testing.T) { proxy.OnRequest().HandleConnect(goproxy.AlwaysMitm) proxy.OnResponse().DoFunc(doubleString) - _, l := oneShotProxy(proxy, t) + _, l := oneShotProxy(proxy) defer l.Close() proxy2.ConnectDial = proxy2.NewConnectDialToProxy(l.URL) - client, l2 := oneShotProxy(proxy2, t) + client, l2 := oneShotProxy(proxy2) defer l2.Close() - if r := string(getOrFail(https.URL+"/bobo", client, t)); r != "bobo bobo" { + if r := string(getOrFail(t, https.URL+"/bobo", client)); r != "bobo bobo" { t.Error("Expected bobo doubled twice, got", r) } - } func TestGoproxyHijackConnect(t *testing.T) { @@ -553,13 +562,19 @@ func TestGoproxyHijackConnect(t *testing.T) { proxy.OnRequest(goproxy.ReqHostIs(srv.Listener.Addr().String())). HijackConnect(func(req *http.Request, client net.Conn, ctx *goproxy.ProxyCtx) { t.Logf("URL %+#v\nSTR %s", req.URL, req.URL.String()) - resp, err := http.Get("http:" + req.URL.String() + "/bobo") + req.URL.Scheme = "http" + req.URL.Path = "/bobo" + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, req.URL.String(), nil) + if err != nil { + t.Fatal("Cannot create request", err) + } + resp, err := http.DefaultClient.Do(req) panicOnErr(err, "http.Get(CONNECT url)") panicOnErr(resp.Write(client), "resp.Write(client)") - resp.Body.Close() - client.Close() + _ = resp.Body.Close() + _ = client.Close() }) - client, l := oneShotProxy(proxy, t) + client, l := oneShotProxy(proxy) defer l.Close() proxyAddr := l.Listener.Addr().String() conn, err := net.Dial("tcp", proxyAddr) @@ -570,13 +585,13 @@ func TestGoproxyHijackConnect(t *testing.T) { t.Error("Expected bobo for CONNECT /foo, got", txt) } - if r := string(getOrFail(https.URL+"/bobo", client, t)); r != "bobo" { + if r := string(getOrFail(t, https.URL+"/bobo", client)); r != "bobo" { t.Error("Expected bobo would keep working with CONNECT", r) } } func readResponse(buf *bufio.Reader) string { - req, err := http.NewRequest(http.MethodGet, srv.URL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) panicOnErr(err, "NewRequest") resp, err := http.ReadResponse(buf, req) panicOnErr(err, "resp.Read") @@ -594,9 +609,9 @@ func writeConnect(w io.Writer) { // here: https://github.com/golang/go/issues/18824 validSrvURL := srv.URL[len("http:"):] - req, err := http.NewRequest("CONNECT", validSrvURL, nil) + req, err := http.NewRequest(http.MethodConnect, validSrvURL, nil) panicOnErr(err, "NewRequest") - req.Write(w) + _ = req.Write(w) panicOnErr(err, "req(CONNECT).Write") } @@ -610,7 +625,7 @@ func TestCurlMinusP(t *testing.T) { called = true return req, nil }) - _, l := oneShotProxy(proxy, t) + _, l := oneShotProxy(proxy) defer l.Close() cmd := exec.Command("curl", "-p", "-sS", "--proxy", l.URL, srv.URL+"/bobo") output, err := cmd.CombinedOutput() @@ -627,9 +642,9 @@ func TestCurlMinusP(t *testing.T) { func TestSelfRequest(t *testing.T) { proxy := goproxy.NewProxyHttpServer() - _, l := oneShotProxy(proxy, t) + _, l := oneShotProxy(proxy) defer l.Close() - if !strings.Contains(string(getOrFail(l.URL, http.DefaultClient, t)), "non-proxy") { + if !strings.Contains(string(getOrFail(t, l.URL, http.DefaultClient)), "non-proxy") { t.Fatal("non proxy requests should fail") } } @@ -646,7 +661,7 @@ func TestHasGoproxyCA(t *testing.T) { tr := &http.Transport{TLSClientConfig: &tls.Config{RootCAs: goproxyCA}, Proxy: http.ProxyURL(proxyUrl)} client := &http.Client{Transport: tr} - if resp := string(getOrFail(https.URL+"/bobo", client, t)); resp != "bobo" { + if resp := string(getOrFail(t, https.URL+"/bobo", client)); resp != "bobo" { t.Error("Wrong response when mitm", resp, "expected bobo") } } @@ -662,14 +677,14 @@ func (tcs *TestCertStorage) Fetch(hostname string, gen func() (*tls.Certificate, var err error cert, ok := tcs.certs[hostname] if ok { - fmt.Printf("hit %v\n", cert == nil) + log.Printf("hit %v\n", cert == nil) tcs.hits++ } else { cert, err = gen() if err != nil { return nil, err } - fmt.Printf("miss %v\n", cert == nil) + log.Printf("miss %v\n", cert == nil) tcs.certs[hostname] = cert tcs.misses++ } @@ -711,7 +726,7 @@ func TestProxyWithCertStorage(t *testing.T) { tr := &http.Transport{TLSClientConfig: &tls.Config{RootCAs: goproxyCA}, Proxy: http.ProxyURL(proxyUrl)} client := &http.Client{Transport: tr} - if resp := string(getOrFail(https.URL+"/bobo", client, t)); resp != "bobo" { + if resp := string(getOrFail(t, https.URL+"/bobo", client)); resp != "bobo" { t.Error("Wrong response when mitm", resp, "expected bobo") } @@ -723,7 +738,7 @@ func TestProxyWithCertStorage(t *testing.T) { } // Another round - this time the certificate can be loaded - if resp := string(getOrFail(https.URL+"/bobo", client, t)); resp != "bobo" { + if resp := string(getOrFail(t, https.URL+"/bobo", client)); resp != "bobo" { t.Error("Wrong response when mitm", resp, "expected bobo") } @@ -767,11 +782,11 @@ func TestHttpsMitmURLRewrite(t *testing.T) { return nil, goproxy.TextResponse(req, "Dummy response") }) - client, s := oneShotProxy(proxy, t) + client, s := oneShotProxy(proxy) defer s.Close() fullURL := scheme + "://" + tc.Host + tc.RawPath - req, err := http.NewRequest(http.MethodGet, fullURL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fullURL, nil) if err != nil { t.Fatal(err) } @@ -782,13 +797,12 @@ func TestHttpsMitmURLRewrite(t *testing.T) { } resp, err := client.Do(req) - if err != nil { t.Fatal(err) } b, err := io.ReadAll(resp.Body) - defer resp.Body.Close() + _ = resp.Body.Close() if err != nil { t.Fatal(err) } @@ -809,10 +823,11 @@ func TestSimpleHttpRequest(t *testing.T) { var server *http.Server go func() { - fmt.Println("serving end proxy server at localhost:5000") + t.Log("serving end proxy server at localhost:5000") server = &http.Server{ - Addr: "localhost:5000", - Handler: proxy, + Addr: "localhost:5000", + Handler: proxy, + ReadHeaderTimeout: 10 * time.Second, } err := server.ListenAndServe() if err == nil { @@ -829,13 +844,24 @@ func TestSimpleHttpRequest(t *testing.T) { } client := http.Client{Transport: tr} - resp, err := client.Get("http://example.com") + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com", nil) + if err != nil { + t.Fatal("Cannot create request", err) + } + + resp, err := client.Do(req) if err != nil { t.Error("Error requesting http site", err) } else if resp.StatusCode != http.StatusOK { t.Error("Non-OK status requesting http site", err) } - resp, _ = client.Get("http://example.invalid") + + req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.invalid", nil) + if err != nil { + t.Fatal("Cannot create request", err) + } + + resp, _ = client.Do(req) if resp == nil { t.Error("No response requesting invalid http site") } @@ -845,12 +871,12 @@ func TestSimpleHttpRequest(t *testing.T) { } proxy.OnResponse(goproxy.UrlMatches(regexp.MustCompile(".*"))).DoFunc(returnNil) - resp, _ = client.Get("http://example.invalid") + resp, _ = client.Do(req) if resp == nil { t.Error("No response requesting invalid http site") } - server.Shutdown(context.TODO()) + _ = server.Shutdown(context.TODO()) } func TestResponseContentLength(t *testing.T) { @@ -864,7 +890,7 @@ func TestResponseContentLength(t *testing.T) { proxy := goproxy.NewProxyHttpServer() proxy.OnResponse().DoFunc(func(resp *http.Response, ctx *goproxy.ProxyCtx) *http.Response { buf := &bytes.Buffer{} - buf.Write([]byte("change")) + buf.WriteString("change") resp.Body = io.NopCloser(buf) return resp }) @@ -877,11 +903,11 @@ func TestResponseContentLength(t *testing.T) { return url.Parse(proxySrv.URL) }, } - req, _ := http.NewRequest(http.MethodGet, srv.URL, nil) + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) resp, _ := http.DefaultClient.Do(req) body, _ := io.ReadAll(resp.Body) - defer resp.Body.Close() + _ = resp.Body.Close() if int64(len(body)) != resp.ContentLength { t.Logf("response body: %s", string(body)) diff --git a/regretable/regretreader.go b/regretable/regretreader.go index 3769a9c12..0d5c3a400 100644 --- a/regretable/regretreader.go +++ b/regretable/regretreader.go @@ -4,8 +4,8 @@ import ( "io" ) -// A RegretableReader will allow you to read from a reader, and then -// to "regret" reading it, and push back everything you've read. +// Reader in regretable package will allow you to read from a reader, +// and then to "regret" reading it, and push back everything you've read. // For example: // // rb := NewRegretableReader(bytes.NewBuffer([]byte{1,2,3})) @@ -13,39 +13,18 @@ import ( // rb.Read(b) // b[0] = 1 // rb.Regret() // ioutil.ReadAll(rb.Read) // returns []byte{1,2,3},nil -type RegretableReader struct { +type Reader struct { reader io.Reader overflow bool r, w int buf []byte } -var defaultBufferSize = 500 - -// Same as RegretableReader, but allows closing the underlying reader -type RegretableReaderCloser struct { - RegretableReader - c io.Closer -} - -// Closes the underlying readCloser, you cannot regret after closing the stream -func (rbc *RegretableReaderCloser) Close() error { - return rbc.c.Close() -} - -// initialize a RegretableReaderCloser with underlying readCloser rc -func NewRegretableReaderCloser(rc io.ReadCloser) *RegretableReaderCloser { - return &RegretableReaderCloser{*NewRegretableReader(rc), rc} -} - -// initialize a RegretableReaderCloser with underlying readCloser rc -func NewRegretableReaderCloserSize(rc io.ReadCloser, size int) *RegretableReaderCloser { - return &RegretableReaderCloser{*NewRegretableReaderSize(rc, size), rc} -} +const _defaultBufferSize = 500 // The next read from the RegretableReader will be as if the underlying reader // was never read (or from the last point forget is called). -func (rb *RegretableReader) Regret() { +func (rb *Reader) Regret() { if rb.overflow { panic("regretting after overflow makes no sense") } @@ -61,7 +40,7 @@ func (rb *RegretableReader) Regret() { // rb.Read(b) // b[0] = 2 // rb.Regret() // ioutil.ReadAll(rb.Read) // returns []byte{2,3},nil -func (rb *RegretableReader) Forget() { +func (rb *Reader) Forget() { if rb.overflow { panic("forgetting after overflow makes no sense") } @@ -69,18 +48,18 @@ func (rb *RegretableReader) Forget() { rb.w = 0 } -// initialize a RegretableReader with underlying reader r, whose buffer is size bytes long -func NewRegretableReaderSize(r io.Reader, size int) *RegretableReader { - return &RegretableReader{reader: r, buf: make([]byte, size)} +// initialize a RegretableReader with underlying reader r, whose buffer is size bytes long. +func NewRegretableReaderSize(r io.Reader, size int) *Reader { + return &Reader{reader: r, buf: make([]byte, size)} } -// initialize a RegretableReader with underlying reader r -func NewRegretableReader(r io.Reader) *RegretableReader { - return NewRegretableReaderSize(r, defaultBufferSize) +// initialize a RegretableReader with underlying reader r. +func NewRegretableReader(r io.Reader) *Reader { + return NewRegretableReaderSize(r, _defaultBufferSize) } // reads from the underlying reader. Will buffer all input until Regret is called. -func (rb *RegretableReader) Read(p []byte) (n int, err error) { +func (rb *Reader) Read(p []byte) (n int, err error) { if rb.overflow { return rb.reader.Read(p) } @@ -97,3 +76,24 @@ func (rb *RegretableReader) Read(p []byte) (n int, err error) { } return } + +// ReaderCloser is the same as Reader, but allows closing the underlying reader. +type ReaderCloser struct { + Reader + c io.Closer +} + +// initialize a RegretableReaderCloser with underlying readCloser rc. +func NewRegretableReaderCloser(rc io.ReadCloser) *ReaderCloser { + return &ReaderCloser{*NewRegretableReader(rc), rc} +} + +// initialize a RegretableReaderCloser with underlying readCloser rc. +func NewRegretableReaderCloserSize(rc io.ReadCloser, size int) *ReaderCloser { + return &ReaderCloser{*NewRegretableReaderSize(rc, size), rc} +} + +// Closes the underlying readCloser, you cannot regret after closing the stream. +func (rbc *ReaderCloser) Close() error { + return rbc.c.Close() +} diff --git a/regretable/regretreader_test.go b/regretable/regretreader_test.go index e5def117a..7230c2742 100644 --- a/regretable/regretreader_test.go +++ b/regretable/regretreader_test.go @@ -2,20 +2,37 @@ package regretable_test import ( "bytes" - . "github.com/elazarl/goproxy/regretable" "io" "strings" "testing" + + "github.com/elazarl/goproxy/regretable" ) +func assertEqual(t *testing.T, expected, actual string) { + t.Helper() + if expected != actual { + t.Fatal("Expected", expected, "actual", actual) + } +} + +func assertReadAll(t *testing.T, r io.Reader) string { + t.Helper() + s, err := io.ReadAll(r) + if err != nil { + t.Fatal("error when reading", err) + } + return string(s) +} + func TestRegretableReader(t *testing.T) { buf := new(bytes.Buffer) - mb := NewRegretableReader(buf) + mb := regretable.NewRegretableReader(buf) word := "12345678" buf.WriteString(word) fivebytes := make([]byte, 5) - mb.Read(fivebytes) + _, _ = mb.Read(fivebytes) mb.Regret() s, _ := io.ReadAll(mb) @@ -26,12 +43,12 @@ func TestRegretableReader(t *testing.T) { func TestRegretableEmptyRead(t *testing.T) { buf := new(bytes.Buffer) - mb := NewRegretableReader(buf) + mb := regretable.NewRegretableReader(buf) word := "12345678" buf.WriteString(word) zero := make([]byte, 0) - mb.Read(zero) + _, _ = mb.Read(zero) mb.Regret() s, err := io.ReadAll(mb) @@ -42,16 +59,16 @@ func TestRegretableEmptyRead(t *testing.T) { func TestRegretableAlsoEmptyRead(t *testing.T) { buf := new(bytes.Buffer) - mb := NewRegretableReader(buf) + mb := regretable.NewRegretableReader(buf) word := "12345678" buf.WriteString(word) one := make([]byte, 1) zero := make([]byte, 0) five := make([]byte, 5) - mb.Read(one) - mb.Read(zero) - mb.Read(five) + _, _ = mb.Read(one) + _, _ = mb.Read(zero) + _, _ = mb.Read(five) mb.Regret() s, _ := io.ReadAll(mb) @@ -62,13 +79,13 @@ func TestRegretableAlsoEmptyRead(t *testing.T) { func TestRegretableRegretBeforeRead(t *testing.T) { buf := new(bytes.Buffer) - mb := NewRegretableReader(buf) + mb := regretable.NewRegretableReader(buf) word := "12345678" buf.WriteString(word) five := make([]byte, 5) mb.Regret() - mb.Read(five) + _, _ = mb.Read(five) s, err := io.ReadAll(mb) if string(s) != "678" { @@ -78,12 +95,12 @@ func TestRegretableRegretBeforeRead(t *testing.T) { func TestRegretableFullRead(t *testing.T) { buf := new(bytes.Buffer) - mb := NewRegretableReader(buf) + mb := regretable.NewRegretableReader(buf) word := "12345678" buf.WriteString(word) twenty := make([]byte, 20) - mb.Read(twenty) + _, _ = mb.Read(twenty) mb.Regret() s, _ := io.ReadAll(mb) @@ -92,23 +109,9 @@ func TestRegretableFullRead(t *testing.T) { } } -func assertEqual(t *testing.T, expected, actual string) { - if expected != actual { - t.Fatal("Expected", expected, "actual", actual) - } -} - -func assertReadAll(t *testing.T, r io.Reader) string { - s, err := io.ReadAll(r) - if err != nil { - t.Fatal("error when reading", err) - } - return string(s) -} - func TestRegretableRegretTwice(t *testing.T) { buf := new(bytes.Buffer) - mb := NewRegretableReader(buf) + mb := regretable.NewRegretableReader(buf) word := "12345678" buf.WriteString(word) @@ -133,39 +136,39 @@ func (cc *CloseCounter) Close() error { return nil } -func assert(t *testing.T, b bool, msg string) { - if !b { - t.Errorf("Assertion Error: %s", msg) - } -} - func TestRegretableCloserSizeRegrets(t *testing.T) { defer func() { - if r := recover(); r == nil || !strings.Contains(r.(string), "regret") { + r := recover() + if r == nil { t.Error("Did not panic when regretting overread buffer:", r) } + + stringValue, ok := r.(string) + if !ok || !strings.Contains(stringValue, "regret") { + t.Error("Invalid panic value when regretting overread buffer:", r) + } }() buf := new(bytes.Buffer) buf.WriteString("123456") - mb := NewRegretableReaderCloserSize(io.NopCloser(buf), 3) - mb.Read(make([]byte, 4)) + mb := regretable.NewRegretableReaderCloserSize(io.NopCloser(buf), 3) + _, _ = mb.Read(make([]byte, 4)) mb.Regret() } func TestRegretableCloserRegretsClose(t *testing.T) { buf := new(bytes.Buffer) cc := &CloseCounter{buf, 0} - mb := NewRegretableReaderCloser(cc) + mb := regretable.NewRegretableReaderCloser(cc) word := "12345678" buf.WriteString(word) - mb.Read([]byte{0}) - mb.Close() + _, _ = mb.Read([]byte{0}) + _ = mb.Close() if cc.closed != 1 { t.Error("RegretableReaderCloser ignores Close") } mb.Regret() - mb.Close() + _ = mb.Close() if cc.closed != 2 { t.Error("RegretableReaderCloser does ignore Close after regret") } diff --git a/responses.go b/responses.go index 4081c7e42..78b93a512 100644 --- a/responses.go +++ b/responses.go @@ -33,7 +33,7 @@ const ( ContentTypeHtml = "text/html" ) -// Alias for NewResponse(r,ContentTypeText,http.StatusAccepted,text) +// Alias for NewResponse(r,ContentTypeText,http.StatusAccepted,text). func TextResponse(r *http.Request, text string) *http.Response { return NewResponse(r, ContentTypeText, http.StatusAccepted, text) } diff --git a/transport/transport.go b/transport/transport.go index 0706035a1..06dd57646 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -91,13 +91,10 @@ func ProxyFromEnvironment(req *http.Request) (*url.URL, error) { } proxyURL, err := url.Parse(proxy) if err != nil || proxyURL.Scheme == "" { - if u, err := url.Parse("http://" + proxy); err == nil { - proxyURL = u - err = nil - } + proxyURL, err = url.Parse("http://" + proxy) } if err != nil { - return nil, fmt.Errorf("invalid proxy address %q: %v", proxy, err) + return nil, fmt.Errorf("invalid proxy address %q: %w", proxy, err) } return proxyURL, nil } @@ -267,11 +264,11 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { return false } key := pconn.cacheKey - max := t.MaxIdleConnsPerHost - if max == 0 { - max = DefaultMaxIdleConnsPerHost + maxIdleConns := t.MaxIdleConnsPerHost + if maxIdleConns == 0 { + maxIdleConns = DefaultMaxIdleConnsPerHost } - if len(t.idleConn[key]) >= max { + if len(t.idleConn[key]) >= maxIdleConns { pconn.close() return false } @@ -338,7 +335,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { conn, raddr, ip, err := t.dial("tcp", cm.addr()) if err != nil { if cm.proxyURL != nil { - err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err) + err = fmt.Errorf("http: error connecting to proxy %s: %w", cm.proxyURL, err) } return nil, err } @@ -374,7 +371,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { if pa != "" { connectReq.Header.Set("Proxy-Authorization", pa) } - connectReq.Write(conn) + _ = connectReq.Write(conn) // Read response. // Okay to use and discard buffered reader here, because @@ -395,11 +392,15 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { if cm.targetScheme == "https" { // Initiate TLS and check remote host name against certificate. conn = tls.Client(conn, t.TLSClientConfig) - if err = conn.(*tls.Conn).Handshake(); err != nil { + tlsConn, ok := conn.(*tls.Conn) + if !ok { + return nil, errors.New("invalid TLS connection") + } + if err = tlsConn.Handshake(); err != nil { return nil, err } if t.TLSClientConfig == nil || !t.TLSClientConfig.InsecureSkipVerify { - if err = conn.(*tls.Conn).VerifyHostname(cm.tlsHost()); err != nil { + if err = tlsConn.VerifyHostname(cm.tlsHost()); err != nil { return nil, err } } @@ -503,7 +504,7 @@ func (cm *connectMethod) tlsHost() string { } // persistConn wraps a connection, usually a persistent one -// (but may be used for non-keep-alive requests as well) +// (but may be used for non-keep-alive requests as well). type persistConn struct { t *Transport cacheKey string // its connectMethod.String() @@ -532,18 +533,6 @@ func (pc *persistConn) isBroken() bool { return pc.broken } -var remoteSideClosedFunc func(error) bool // or nil to use default - -func remoteSideClosed(err error) bool { - if err == io.EOF { - return true - } - if remoteSideClosedFunc != nil { - return remoteSideClosedFunc(err) - } - return false -} - func (pc *persistConn) readLoop() { alive := true var lastbody io.ReadCloser // last response body, if any, read on this connection @@ -600,9 +589,13 @@ func (pc *persistConn) readLoop() { var waitForBodyRead chan bool if alive { if hasBody { + bodyEof, ok := resp.Body.(*bodyEOFSignal) + if !ok { + alive = false + } lastbody = resp.Body waitForBodyRead = make(chan bool) - resp.Body.(*bodyEOFSignal).fn = func() { + bodyEof.fn = func() { if !pc.t.putIdleConn(pc) { alive = false } @@ -679,7 +672,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *http.Response, er } if err != nil { pc.close() - return + return nil, err } pc.bw.Flush() @@ -710,7 +703,7 @@ var portMap = map[string]string{ "https": "443", } -// canonicalAddr returns url.Host but always with a ":port" suffix +// canonicalAddr returns url.Host but always with a ":port" suffix. func canonicalAddr(url *url.URL) string { addr := url.Host if !hasPort(addr) { @@ -719,11 +712,6 @@ func canonicalAddr(url *url.URL) string { return addr } -func responseIsKeepAlive(res *http.Response) bool { - // TODO: implement. for now just always shutting down the connection. - return false -} - // bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most // once, right before the final Read() or Close() call returns, but after // EOF has been seen. @@ -738,7 +726,7 @@ func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { if es.isClosed && n > 0 { panic("http: unexpected bodyEOFSignal Read after Close; see issue 1725") } - if err == io.EOF && es.fn != nil { + if errors.Is(err, io.EOF) && es.fn != nil { es.fn() es.fn = nil } @@ -780,6 +768,6 @@ type discardOnCloseReadCloser struct { } func (d *discardOnCloseReadCloser) Close() error { - io.Copy(io.Discard, d.ReadCloser) // ignore errors; likely invalid or already closed + _, _ = io.Copy(io.Discard, d.ReadCloser) // ignore errors; likely invalid or already closed return d.ReadCloser.Close() } diff --git a/websocket.go b/websocket.go index 753a1e8d3..07c73a3ba 100644 --- a/websocket.go +++ b/websocket.go @@ -25,7 +25,13 @@ func isWebSocketRequest(r *http.Request) bool { headerContains(r.Header, "Upgrade", "websocket") } -func (proxy *ProxyHttpServer) serveWebsocketTLS(ctx *ProxyCtx, w http.ResponseWriter, req *http.Request, tlsConfig *tls.Config, clientConn *tls.Conn) { +func (proxy *ProxyHttpServer) serveWebsocketTLS( + ctx *ProxyCtx, + w http.ResponseWriter, + req *http.Request, + tlsConfig *tls.Config, + clientConn *tls.Conn, +) { targetURL := url.URL{Scheme: "wss", Host: req.URL.Host, Path: req.URL.Path} // Connect to upstream @@ -46,7 +52,12 @@ func (proxy *ProxyHttpServer) serveWebsocketTLS(ctx *ProxyCtx, w http.ResponseWr proxy.proxyWebsocket(ctx, targetConn, clientConn) } -func (proxy *ProxyHttpServer) serveWebsocketHttpOverTLS(ctx *ProxyCtx, w http.ResponseWriter, req *http.Request, clientConn *tls.Conn) { +func (proxy *ProxyHttpServer) serveWebsocketHttpOverTLS( + ctx *ProxyCtx, + w http.ResponseWriter, + req *http.Request, + clientConn *tls.Conn, +) { targetURL := url.URL{Scheme: "ws", Host: req.URL.Host, Path: req.URL.Path} // Connect to upstream @@ -98,7 +109,12 @@ func (proxy *ProxyHttpServer) serveWebsocket(ctx *ProxyCtx, w http.ResponseWrite proxy.proxyWebsocket(ctx, targetConn, clientConn) } -func (proxy *ProxyHttpServer) websocketHandshake(ctx *ProxyCtx, req *http.Request, targetSiteConn io.ReadWriter, clientConn io.ReadWriter) error { +func (proxy *ProxyHttpServer) websocketHandshake( + ctx *ProxyCtx, + req *http.Request, + targetSiteConn io.ReadWriter, + clientConn io.ReadWriter, +) error { // write handshake request to target err := req.Write(targetSiteConn) if err != nil {