From dd259cdcc54e6910dd8c4a0bc8a8dcf1139ff5d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Matczuk?= Date: Mon, 7 Nov 2022 14:17:12 +0100 Subject: [PATCH] IsLocalHost: support host:port Without this patch IsLocalHost does not work for URLs with port specified i.e. it works for `http://localhost` but does not work for `http://localhost:80` or `http://localhost:10000`. Fixes #487 --- dispatcher.go | 23 +++++++++++++++-------- dispatcher_test.go | 47 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 8 deletions(-) create mode 100644 dispatcher_test.go diff --git a/dispatcher.go b/dispatcher.go index 25c949c0..9b79bcc3 100644 --- a/dispatcher.go +++ b/dispatcher.go @@ -97,15 +97,22 @@ func ReqHostIs(hosts ...string) ReqConditionFunc { } } -var localHostIpv4 = regexp.MustCompile(`127\.0\.0\.\d+`) - -// IsLocalHost checks whether the destination host is explicitly local host -// (buggy, there can be IPv6 addresses it doesn't catch) +// IsLocalHost checks whether the destination host is localhost. var IsLocalHost ReqConditionFunc = func(req *http.Request, ctx *ProxyCtx) bool { - return req.URL.Host == "::1" || - req.URL.Host == "0:0:0:0:0:0:0:1" || - localHostIpv4.MatchString(req.URL.Host) || - req.URL.Host == "localhost" + h := req.URL.Hostname() + if h == "localhost" { + return true + } + if ip := net.ParseIP(h); ip != nil { + return ip.IsLoopback() + } + + // In case of IPv6 without a port number Hostname() sometimes returns the invalid value. + if ip := net.ParseIP(req.URL.Host); ip != nil { + return ip.IsLoopback() + } + + return false } // UrlMatches returns a ReqCondition testing whether the destination URL diff --git a/dispatcher_test.go b/dispatcher_test.go new file mode 100644 index 00000000..c7e50245 --- /dev/null +++ b/dispatcher_test.go @@ -0,0 +1,47 @@ +package goproxy + +import ( + "net" + "net/http" + "strings" + "testing" +) + +func TestIsLocalHost(t *testing.T) { + hosts := []string{ + "localhost", + "127.0.0.1", + "127.0.0.7", + "::ffff:127.0.0.1", + "::ffff:127.0.0.7", + "::1", + "0:0:0:0:0:0:0:1", + } + ports := []string{ + "", + "80", + "443", + } + + for _, host := range hosts { + for _, port := range ports { + if port == "" && strings.HasPrefix(host, "::ffff:") { + continue + } + + addr := host + if port != "" { + addr = net.JoinHostPort(host, port) + } + t.Run(addr, func(t *testing.T) { + req, err := http.NewRequest("GET", "http://"+addr, http.NoBody) + if err != nil { + t.Fatal(err) + } + if !IsLocalHost(req, nil) { + t.Fatal("expected true") + } + }) + } + } +}