diff --git a/proxy.go b/proxy.go index 8e295aa..482cd1d 100644 --- a/proxy.go +++ b/proxy.go @@ -2,7 +2,6 @@ package masque import ( "context" - "fmt" "io" "log" "net" @@ -13,13 +12,6 @@ import ( "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" "github.com/quic-go/quic-go/quicvarint" - - "github.com/dunglas/httpsfv" -) - -const ( - requestProtocol = "connect-udp" - capsuleHeader = "Capsule-Protocol" ) const ( @@ -27,16 +19,6 @@ const ( uriTemplateTargetPort = "target_port" ) -var capsuleProtocolHeaderValue string - -func init() { - v, err := httpsfv.Marshal(httpsfv.NewItem(1)) - if err != nil { - panic(fmt.Sprintf("failed to marshal capsule protocol header value: %v", err)) - } - capsuleProtocolHeaderValue = v -} - type proxyEntry struct { str http3.Stream conn *net.UDPConn diff --git a/proxy_test.go b/proxy_test.go index 75cbacb..6dc75c0 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -30,7 +30,7 @@ func newRequest(target string) *http.Request { req := httptest.NewRequest(http.MethodGet, target, nil) req.Method = http.MethodConnect req.Proto = requestProtocol - req.Header.Add("Capsule-Protocol", "1") + req.Header.Add("Capsule-Protocol", capsuleProtocolHeaderValue) return req } diff --git a/request.go b/request.go index 73fe121..645139f 100644 --- a/request.go +++ b/request.go @@ -4,12 +4,28 @@ import ( "fmt" "net/http" "net/url" + "reflect" "strconv" "github.com/dunglas/httpsfv" "github.com/yosida95/uritemplate/v3" ) +const ( + requestProtocol = "connect-udp" + capsuleHeader = "Capsule-Protocol" +) + +var capsuleProtocolHeaderValue string + +func init() { + v, err := httpsfv.Marshal(httpsfv.NewItem(true)) + if err != nil { + panic(fmt.Sprintf("failed to marshal capsule protocol header value: %v", err)) + } + capsuleProtocolHeaderValue = v +} + // Request is the parsed CONNECT-UDP request returned from ParseRequest. // Target is the target server that the client requests to connect to. // It can either be DNS name:port or an IP:port. @@ -57,10 +73,15 @@ func ParseRequest(r *http.Request, template *uritemplate.Template) (*Request, er Err: fmt.Errorf("invalid capsule header value: %s", capsuleHeaderValues), } } - if v, ok := item.Value.(int64); !ok || v != 1 { + if v, ok := item.Value.(bool); !ok { + return nil, &RequestParseError{ + HTTPStatus: http.StatusBadRequest, + Err: fmt.Errorf("incorrect capsule header value type: %s", reflect.TypeOf(item.Value)), + } + } else if !v { return nil, &RequestParseError{ HTTPStatus: http.StatusBadRequest, - Err: fmt.Errorf("incorrect capsule header value: %d", v), + Err: fmt.Errorf("incorrect capsule header value: %t", item.Value), } } diff --git a/request_test.go b/request_test.go index 12202af..8e46f4a 100644 --- a/request_test.go +++ b/request_test.go @@ -4,6 +4,7 @@ import ( "net/http" "testing" + "github.com/dunglas/httpsfv" "github.com/stretchr/testify/require" "github.com/yosida95/uritemplate/v3" ) @@ -11,6 +12,13 @@ import ( func TestRequestParsing(t *testing.T) { template := uritemplate.MustNew("https://localhost:1234/masque?h={target_host}&p={target_port}") + t.Run("invalid target port", func(t *testing.T) { + req := newRequest("https://localhost:1234/masque?h=localhost&p=1337") + r, err := ParseRequest(req, template) + require.NoError(t, err) + require.Equal(t, r.Target, "localhost:1337") + }) + t.Run("wrong request method", func(t *testing.T) { req := newRequest("https://localhost:1234/masque") req.Method = http.MethodHead @@ -43,11 +51,21 @@ func TestRequestParsing(t *testing.T) { require.Equal(t, http.StatusBadRequest, err.(*RequestParseError).HTTPStatus) }) - t.Run("invalid Capsule-Protocol header value", func(t *testing.T) { + t.Run("invalid Capsule-Protocol header value type", func(t *testing.T) { req := newRequest("https://localhost:1234/masque") - req.Header.Set("Capsule-Protocol", "2") + req.Header.Set("Capsule-Protocol", "1") _, err := ParseRequest(req, template) - require.EqualError(t, err, "incorrect capsule header value: 2") + require.EqualError(t, err, "incorrect capsule header value type: int64") + require.Equal(t, http.StatusBadRequest, err.(*RequestParseError).HTTPStatus) + }) + + t.Run("invalid Capsule-Protocol header value", func(t *testing.T) { + req := newRequest("https://localhost:1234/masque") + v, err := httpsfv.Marshal(httpsfv.NewItem(false)) + require.NoError(t, err) + req.Header.Set("Capsule-Protocol", v) + _, err = ParseRequest(req, template) + require.EqualError(t, err, "incorrect capsule header value: false") require.Equal(t, http.StatusBadRequest, err.(*RequestParseError).HTTPStatus) })