From 90ed22b20e0996ce439e768eda07a1864d6595e3 Mon Sep 17 00:00:00 2001 From: Joshua Humphries <2035234+jhump@users.noreply.github.com> Date: Fri, 1 Dec 2023 08:10:59 -0500 Subject: [PATCH] Improve GET request handling: clients should not include content headers, servers should disallow a body (#644) --- client_ext_test.go | 26 ++++++++++++++++++++++++++ handler.go | 17 +++++++++++++++++ handler_ext_test.go | 30 ++++++++++++++++++++++++++++++ protocol.go | 10 ++++++---- protocol_connect.go | 7 +++++-- 5 files changed, 84 insertions(+), 6 deletions(-) diff --git a/client_ext_test.go b/client_ext_test.go index 5bb90cd3..bb009a8b 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -190,6 +190,32 @@ func TestGetNotModified(t *testing.T) { assert.Equal(t, http.MethodGet, unaryReq.HTTPMethod()) } +func TestGetNoContentHeaders(t *testing.T) { + t.Parallel() + + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler(&pingServer{})) + server := memhttptest.NewServer(t, http.HandlerFunc(func(respWriter http.ResponseWriter, req *http.Request) { + if len(req.Header.Values("content-type")) > 0 || + len(req.Header.Values("content-encoding")) > 0 || + len(req.Header.Values("content-length")) > 0 { + http.Error(respWriter, "GET request should not include content headers", http.StatusBadRequest) + } + mux.ServeHTTP(respWriter, req) + })) + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + connect.WithHTTPGet(), + ) + ctx := context.Background() + + unaryReq := connect.NewRequest(&pingv1.PingRequest{}) + _, err := client.Ping(ctx, unaryReq) + assert.Nil(t, err) + assert.Equal(t, http.MethodGet, unaryReq.HTTPMethod()) +} + func TestSpecSchema(t *testing.T) { t.Parallel() mux := http.NewServeMux() diff --git a/handler.go b/handler.go index 05906fe2..0a4ce659 100644 --- a/handler.go +++ b/handler.go @@ -222,6 +222,23 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re return } + if request.Method == http.MethodGet { + // A body must not be present. + hasBody := request.ContentLength > 0 + if request.ContentLength < 0 { + // No content-length header. + // Test if body is empty by trying to read a single byte. + var b [1]byte + n, _ := request.Body.Read(b[:]) + hasBody = n > 0 + } + if hasBody { + responseWriter.WriteHeader(http.StatusUnsupportedMediaType) + return + } + _ = request.Body.Close() + } + // Establish a stream and serve the RPC. setHeaderCanonical(request.Header, headerContentType, contentType) setHeaderCanonical(request.Header, headerHost, request.Host) diff --git a/handler_ext_test.go b/handler_ext_test.go index 7cab9727..6d2d225c 100644 --- a/handler_ext_test.go +++ b/handler_ext_test.go @@ -80,6 +80,36 @@ func TestHandler_ServeHTTP(t *testing.T) { assert.Equal(t, resp.StatusCode, http.StatusUnsupportedMediaType) }) + t.Run("get_method_body_not_allowed", func(t *testing.T) { + t.Parallel() + const queryStringSuffix = `?encoding=json&message={}` + request, err := http.NewRequestWithContext( + context.Background(), + http.MethodGet, + server.URL()+pingProcedure+queryStringSuffix, + strings.NewReader("!"), // non-empty body + ) + assert.Nil(t, err) + resp, err := client.Do(request) + assert.Nil(t, err) + defer resp.Body.Close() + assert.Equal(t, resp.StatusCode, http.StatusUnsupportedMediaType) + + // Same thing, but this time w/ a content-length header + request, err = http.NewRequestWithContext( + context.Background(), + http.MethodGet, + server.URL()+pingProcedure+queryStringSuffix, + strings.NewReader("!"), // non-empty body + ) + assert.Nil(t, err) + request.Header.Set("content-length", "1") + resp, err = client.Do(request) + assert.Nil(t, err) + defer resp.Body.Close() + assert.Equal(t, resp.StatusCode, http.StatusUnsupportedMediaType) + }) + t.Run("idempotent_get_method", func(t *testing.T) { t.Parallel() request, err := http.NewRequestWithContext( diff --git a/protocol.go b/protocol.go index cc6455cd..e85565ab 100644 --- a/protocol.go +++ b/protocol.go @@ -35,10 +35,12 @@ const ( ) const ( - headerContentType = "Content-Type" - headerHost = "Host" - headerUserAgent = "User-Agent" - headerTrailer = "Trailer" + headerContentType = "Content-Type" + headerContentEncoding = "Content-Encoding" + headerContentLength = "Content-Length" + headerHost = "Host" + headerUserAgent = "User-Agent" + headerTrailer = "Trailer" discardLimit = 1024 * 1024 * 4 // 4MiB ) diff --git a/protocol_connect.go b/protocol_connect.go index fe25dcbb..299cb830 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -874,7 +874,7 @@ func (u *connectStreamingUnmarshaler) Unmarshal(message any) *Error { for name, value := range end.Trailer { canonical := http.CanonicalHeaderKey(name) if name != canonical { - delete(end.Trailer, name) + delHeaderCanonical(end.Trailer, name) end.Trailer[canonical] = append(end.Trailer[canonical], value...) } } @@ -1045,7 +1045,10 @@ func (m *connectUnaryRequestMarshaler) buildGetURL(data []byte, compressed bool) } func (m *connectUnaryRequestMarshaler) writeWithGet(url *url.URL) *Error { - delete(m.header, connectHeaderProtocolVersion) + delHeaderCanonical(m.header, connectHeaderProtocolVersion) + delHeaderCanonical(m.header, headerContentType) + delHeaderCanonical(m.header, headerContentEncoding) + delHeaderCanonical(m.header, headerContentLength) m.duplexCall.SetMethod(http.MethodGet) *m.duplexCall.URL() = *url return nil