From e27082010ce364e058216a463bbc5aea37f1fd51 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 20 Sep 2024 22:53:06 +0800 Subject: [PATCH] fix handling of proxying requests for IPv6 addresses (#69) net.ResolveUPDAddr expects an IPv6 address to be wrapped in [], e.g. [::1]:443. --- client.go | 2 +- connect-udp_test.go | 15 ++++++++++----- request.go | 18 +++++++++--------- request_test.go | 17 ++++++++++++++++- 4 files changed, 36 insertions(+), 16 deletions(-) diff --git a/client.go b/client.go index a5bce35..a2cfefa 100644 --- a/client.go +++ b/client.go @@ -67,7 +67,7 @@ func (c *Client) Dial(ctx context.Context, raddr *net.UDPAddr) (net.PacketConn, return nil, nil, errors.New("masque: no template") } str, err := c.Template.Expand(uritemplate.Values{ - uriTemplateTargetHost: uritemplate.String(url.QueryEscape(raddr.IP.String())), + uriTemplateTargetHost: uritemplate.String(escape(raddr.IP.String())), uriTemplateTargetPort: uritemplate.String(strconv.Itoa(raddr.Port)), }) if err != nil { diff --git a/connect-udp_test.go b/connect-udp_test.go index 25c01f3..47c0a40 100644 --- a/connect-udp_test.go +++ b/connect-udp_test.go @@ -23,9 +23,9 @@ func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } -func runEchoServer(t *testing.T) *net.UDPConn { +func runEchoServer(t *testing.T, addr *net.UDPAddr) *net.UDPConn { t.Helper() - conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) + conn, err := net.ListenUDP("udp", addr) require.NoError(t, err) go func() { for { @@ -43,7 +43,12 @@ func runEchoServer(t *testing.T) *net.UDPConn { } func TestProxyToIP(t *testing.T) { - remoteServerConn := runEchoServer(t) + t.Run("IPv4", func(t *testing.T) { testProxyToIP(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) }) + t.Run("IPv6", func(t *testing.T) { testProxyToIP(t, &net.UDPAddr{IP: net.IPv6loopback, Port: 0}) }) +} + +func testProxyToIP(t *testing.T, addr *net.UDPAddr) { + remoteServerConn := runEchoServer(t, addr) defer remoteServerConn.Close() conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) @@ -94,7 +99,7 @@ func TestProxyToIP(t *testing.T) { } func TestProxyToHostname(t *testing.T) { - remoteServerConn := runEchoServer(t) + remoteServerConn := runEchoServer(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) defer remoteServerConn.Close() conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) @@ -197,7 +202,7 @@ func TestProxyToHostnameMissingPort(t *testing.T) { } func TestProxyShutdown(t *testing.T) { - remoteServerConn := runEchoServer(t) + remoteServerConn := runEchoServer(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) defer remoteServerConn.Close() conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) diff --git a/request.go b/request.go index 645139f..71a9b85 100644 --- a/request.go +++ b/request.go @@ -3,9 +3,9 @@ package masque import ( "fmt" "net/http" - "net/url" "reflect" "strconv" + "strings" "github.com/dunglas/httpsfv" "github.com/yosida95/uritemplate/v3" @@ -86,20 +86,17 @@ func ParseRequest(r *http.Request, template *uritemplate.Template) (*Request, er } match := template.Match(r.URL.String()) - targetHostEncoded := match.Get(uriTemplateTargetHost).String() + targetHost := unescape(match.Get(uriTemplateTargetHost).String()) targetPortStr := match.Get(uriTemplateTargetPort).String() - if targetHostEncoded == "" || targetPortStr == "" { + if targetHost == "" || targetPortStr == "" { return nil, &RequestParseError{ HTTPStatus: http.StatusBadRequest, Err: fmt.Errorf("expected target_host and target_port"), } } - targetHost, err := url.QueryUnescape(targetHostEncoded) - if err != nil { - return nil, &RequestParseError{ - HTTPStatus: http.StatusBadRequest, - Err: fmt.Errorf("failed to decode target_host: %w", err), - } + // IPv6 addresses need to be enclosed in [], otherwise resolving the address will fail. + if strings.Contains(targetHost, ":") { + targetHost = "[" + targetHost + "]" } targetPort, err := strconv.Atoi(targetPortStr) if err != nil { @@ -110,3 +107,6 @@ func ParseRequest(r *http.Request, template *uritemplate.Template) (*Request, er } return &Request{Target: fmt.Sprintf("%s:%d", targetHost, targetPort)}, nil } + +func escape(s string) string { return strings.ReplaceAll(s, ":", "%3A") } +func unescape(s string) string { return strings.ReplaceAll(s, "%3A", ":") } diff --git a/request_test.go b/request_test.go index 8e46f4a..a09fe13 100644 --- a/request_test.go +++ b/request_test.go @@ -1,6 +1,7 @@ package masque import ( + "fmt" "net/http" "testing" @@ -12,13 +13,27 @@ import ( func TestRequestParsing(t *testing.T) { template := uritemplate.MustNew("https://localhost:1234/masque?h={target_host}&p={target_port}") - t.Run("invalid target port", func(t *testing.T) { + t.Run("valid request for a hostname", func(t *testing.T) { req := newRequest("https://localhost:1234/masque?h=localhost&p=1337") r, err := ParseRequest(req, template) require.NoError(t, err) require.Equal(t, r.Target, "localhost:1337") }) + t.Run("valid request for an IPv4 address", func(t *testing.T) { + req := newRequest("https://localhost:1234/masque?h=1.2.3.4&p=9999") + r, err := ParseRequest(req, template) + require.NoError(t, err) + require.Equal(t, r.Target, "1.2.3.4:9999") + }) + + t.Run("valid request for an IPv6 address", func(t *testing.T) { + req := newRequest(fmt.Sprintf("https://localhost:1234/masque?h=%s&p=1234", escape("::1"))) + r, err := ParseRequest(req, template) + require.NoError(t, err) + require.Equal(t, r.Target, "[::1]:1234") + }) + t.Run("wrong request method", func(t *testing.T) { req := newRequest("https://localhost:1234/masque") req.Method = http.MethodHead