Skip to content

Commit

Permalink
Improve GET request handling: clients should not include content head…
Browse files Browse the repository at this point in the history
…ers, servers should disallow a body (#644)
  • Loading branch information
jhump authored Dec 1, 2023
1 parent 996f8b9 commit 90ed22b
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 6 deletions.
26 changes: 26 additions & 0 deletions client_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,32 @@ func TestGetNotModified(t *testing.T) {
assert.Equal(t, http.MethodGet, unaryReq.HTTPMethod())
}

func TestGetNoContentHeaders(t *testing.T) {
t.Parallel()

mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(&pingServer{}))
server := memhttptest.NewServer(t, http.HandlerFunc(func(respWriter http.ResponseWriter, req *http.Request) {
if len(req.Header.Values("content-type")) > 0 ||
len(req.Header.Values("content-encoding")) > 0 ||
len(req.Header.Values("content-length")) > 0 {
http.Error(respWriter, "GET request should not include content headers", http.StatusBadRequest)
}
mux.ServeHTTP(respWriter, req)
}))
client := pingv1connect.NewPingServiceClient(
server.Client(),
server.URL(),
connect.WithHTTPGet(),
)
ctx := context.Background()

unaryReq := connect.NewRequest(&pingv1.PingRequest{})
_, err := client.Ping(ctx, unaryReq)
assert.Nil(t, err)
assert.Equal(t, http.MethodGet, unaryReq.HTTPMethod())
}

func TestSpecSchema(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
Expand Down
17 changes: 17 additions & 0 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,23 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re
return
}

if request.Method == http.MethodGet {
// A body must not be present.
hasBody := request.ContentLength > 0
if request.ContentLength < 0 {
// No content-length header.
// Test if body is empty by trying to read a single byte.
var b [1]byte
n, _ := request.Body.Read(b[:])
hasBody = n > 0
}
if hasBody {
responseWriter.WriteHeader(http.StatusUnsupportedMediaType)
return
}
_ = request.Body.Close()
}

// Establish a stream and serve the RPC.
setHeaderCanonical(request.Header, headerContentType, contentType)
setHeaderCanonical(request.Header, headerHost, request.Host)
Expand Down
30 changes: 30 additions & 0 deletions handler_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,36 @@ func TestHandler_ServeHTTP(t *testing.T) {
assert.Equal(t, resp.StatusCode, http.StatusUnsupportedMediaType)
})

t.Run("get_method_body_not_allowed", func(t *testing.T) {
t.Parallel()
const queryStringSuffix = `?encoding=json&message={}`
request, err := http.NewRequestWithContext(
context.Background(),
http.MethodGet,
server.URL()+pingProcedure+queryStringSuffix,
strings.NewReader("!"), // non-empty body
)
assert.Nil(t, err)
resp, err := client.Do(request)
assert.Nil(t, err)
defer resp.Body.Close()
assert.Equal(t, resp.StatusCode, http.StatusUnsupportedMediaType)

// Same thing, but this time w/ a content-length header
request, err = http.NewRequestWithContext(
context.Background(),
http.MethodGet,
server.URL()+pingProcedure+queryStringSuffix,
strings.NewReader("!"), // non-empty body
)
assert.Nil(t, err)
request.Header.Set("content-length", "1")
resp, err = client.Do(request)
assert.Nil(t, err)
defer resp.Body.Close()
assert.Equal(t, resp.StatusCode, http.StatusUnsupportedMediaType)
})

t.Run("idempotent_get_method", func(t *testing.T) {
t.Parallel()
request, err := http.NewRequestWithContext(
Expand Down
10 changes: 6 additions & 4 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ const (
)

const (
headerContentType = "Content-Type"
headerHost = "Host"
headerUserAgent = "User-Agent"
headerTrailer = "Trailer"
headerContentType = "Content-Type"
headerContentEncoding = "Content-Encoding"
headerContentLength = "Content-Length"
headerHost = "Host"
headerUserAgent = "User-Agent"
headerTrailer = "Trailer"

discardLimit = 1024 * 1024 * 4 // 4MiB
)
Expand Down
7 changes: 5 additions & 2 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ func (u *connectStreamingUnmarshaler) Unmarshal(message any) *Error {
for name, value := range end.Trailer {
canonical := http.CanonicalHeaderKey(name)
if name != canonical {
delete(end.Trailer, name)
delHeaderCanonical(end.Trailer, name)
end.Trailer[canonical] = append(end.Trailer[canonical], value...)
}
}
Expand Down Expand Up @@ -1045,7 +1045,10 @@ func (m *connectUnaryRequestMarshaler) buildGetURL(data []byte, compressed bool)
}

func (m *connectUnaryRequestMarshaler) writeWithGet(url *url.URL) *Error {
delete(m.header, connectHeaderProtocolVersion)
delHeaderCanonical(m.header, connectHeaderProtocolVersion)
delHeaderCanonical(m.header, headerContentType)
delHeaderCanonical(m.header, headerContentEncoding)
delHeaderCanonical(m.header, headerContentLength)
m.duplexCall.SetMethod(http.MethodGet)
*m.duplexCall.URL() = *url
return nil
Expand Down

0 comments on commit 90ed22b

Please sign in to comment.