Skip to content

Commit

Permalink
fix handling of proxying requests for IPv6 addresses (#69)
Browse files Browse the repository at this point in the history
net.ResolveUPDAddr expects an IPv6 address to be wrapped in [], e.g.
[::1]:443.
  • Loading branch information
marten-seemann authored Sep 20, 2024
1 parent e1c12a9 commit e270820
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 16 deletions.
2 changes: 1 addition & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (c *Client) Dial(ctx context.Context, raddr *net.UDPAddr) (net.PacketConn,
return nil, nil, errors.New("masque: no template")
}
str, err := c.Template.Expand(uritemplate.Values{
uriTemplateTargetHost: uritemplate.String(url.QueryEscape(raddr.IP.String())),
uriTemplateTargetHost: uritemplate.String(escape(raddr.IP.String())),
uriTemplateTargetPort: uritemplate.String(strconv.Itoa(raddr.Port)),
})
if err != nil {
Expand Down
15 changes: 10 additions & 5 deletions connect-udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}

func runEchoServer(t *testing.T) *net.UDPConn {
func runEchoServer(t *testing.T, addr *net.UDPAddr) *net.UDPConn {
t.Helper()
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
conn, err := net.ListenUDP("udp", addr)
require.NoError(t, err)
go func() {
for {
Expand All @@ -43,7 +43,12 @@ func runEchoServer(t *testing.T) *net.UDPConn {
}

func TestProxyToIP(t *testing.T) {
remoteServerConn := runEchoServer(t)
t.Run("IPv4", func(t *testing.T) { testProxyToIP(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) })
t.Run("IPv6", func(t *testing.T) { testProxyToIP(t, &net.UDPAddr{IP: net.IPv6loopback, Port: 0}) })
}

func testProxyToIP(t *testing.T, addr *net.UDPAddr) {
remoteServerConn := runEchoServer(t, addr)
defer remoteServerConn.Close()

conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
Expand Down Expand Up @@ -94,7 +99,7 @@ func TestProxyToIP(t *testing.T) {
}

func TestProxyToHostname(t *testing.T) {
remoteServerConn := runEchoServer(t)
remoteServerConn := runEchoServer(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
defer remoteServerConn.Close()

conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
Expand Down Expand Up @@ -197,7 +202,7 @@ func TestProxyToHostnameMissingPort(t *testing.T) {
}

func TestProxyShutdown(t *testing.T) {
remoteServerConn := runEchoServer(t)
remoteServerConn := runEchoServer(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
defer remoteServerConn.Close()

conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
Expand Down
18 changes: 9 additions & 9 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package masque
import (
"fmt"
"net/http"
"net/url"
"reflect"
"strconv"
"strings"

"github.com/dunglas/httpsfv"
"github.com/yosida95/uritemplate/v3"
Expand Down Expand Up @@ -86,20 +86,17 @@ func ParseRequest(r *http.Request, template *uritemplate.Template) (*Request, er
}

match := template.Match(r.URL.String())
targetHostEncoded := match.Get(uriTemplateTargetHost).String()
targetHost := unescape(match.Get(uriTemplateTargetHost).String())
targetPortStr := match.Get(uriTemplateTargetPort).String()
if targetHostEncoded == "" || targetPortStr == "" {
if targetHost == "" || targetPortStr == "" {
return nil, &RequestParseError{
HTTPStatus: http.StatusBadRequest,
Err: fmt.Errorf("expected target_host and target_port"),
}
}
targetHost, err := url.QueryUnescape(targetHostEncoded)
if err != nil {
return nil, &RequestParseError{
HTTPStatus: http.StatusBadRequest,
Err: fmt.Errorf("failed to decode target_host: %w", err),
}
// IPv6 addresses need to be enclosed in [], otherwise resolving the address will fail.
if strings.Contains(targetHost, ":") {
targetHost = "[" + targetHost + "]"
}
targetPort, err := strconv.Atoi(targetPortStr)
if err != nil {
Expand All @@ -110,3 +107,6 @@ func ParseRequest(r *http.Request, template *uritemplate.Template) (*Request, er
}
return &Request{Target: fmt.Sprintf("%s:%d", targetHost, targetPort)}, nil
}

func escape(s string) string { return strings.ReplaceAll(s, ":", "%3A") }
func unescape(s string) string { return strings.ReplaceAll(s, "%3A", ":") }
17 changes: 16 additions & 1 deletion request_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package masque

import (
"fmt"
"net/http"
"testing"

Expand All @@ -12,13 +13,27 @@ import (
func TestRequestParsing(t *testing.T) {
template := uritemplate.MustNew("https://localhost:1234/masque?h={target_host}&p={target_port}")

t.Run("invalid target port", func(t *testing.T) {
t.Run("valid request for a hostname", func(t *testing.T) {
req := newRequest("https://localhost:1234/masque?h=localhost&p=1337")
r, err := ParseRequest(req, template)
require.NoError(t, err)
require.Equal(t, r.Target, "localhost:1337")
})

t.Run("valid request for an IPv4 address", func(t *testing.T) {
req := newRequest("https://localhost:1234/masque?h=1.2.3.4&p=9999")
r, err := ParseRequest(req, template)
require.NoError(t, err)
require.Equal(t, r.Target, "1.2.3.4:9999")
})

t.Run("valid request for an IPv6 address", func(t *testing.T) {
req := newRequest(fmt.Sprintf("https://localhost:1234/masque?h=%s&p=1234", escape("::1")))
r, err := ParseRequest(req, template)
require.NoError(t, err)
require.Equal(t, r.Target, "[::1]:1234")
})

t.Run("wrong request method", func(t *testing.T) {
req := newRequest("https://localhost:1234/masque")
req.Method = http.MethodHead
Expand Down

0 comments on commit e270820

Please sign in to comment.