Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap errors with context cancellation codes #659

Merged
merged 7 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,9 @@ issues:
# We want to show examples with http.Get
- linters: [noctx]
path: internal/memhttp/memhttp_test.go
# We need envelope readers and writers to have access to the context for error handling.
- linters: [containedctx]
path: envelope.go
# We need marshallers and unmarshallers to have access to the context for error handling.
- linters: [containedctx]
path: protocol_connect.go
emcfarlane marked this conversation as resolved.
Show resolved Hide resolved
128 changes: 128 additions & 0 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ import (
"compress/flate"
"compress/gzip"
"context"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"math/rand"
"net"
"net/http"
"runtime"
"strings"
Expand Down Expand Up @@ -2276,6 +2278,132 @@ func TestStreamUnexpectedEOF(t *testing.T) {
}
}

// TestClientDisconnect tests that the handler receives a CodeCanceled error when
// the client abruptly disconnects.
func TestClientDisconnect(t *testing.T) {
t.Parallel()
captureTransportConn := func(server *memhttp.Server, clientConn *net.Conn, onError chan struct{}, http2 bool) http.RoundTripper {
if http2 {
transport := server.Transport()
dialContext := transport.DialTLSContext
transport.DialTLSContext = func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
conn, err := dialContext(ctx, network, addr, cfg)
if err != nil {
close(onError)
return nil, err
}
*clientConn = conn // Capture the client connection.
return conn, nil
}
return transport
}
transport := server.TransportHTTP1()
dialContext := transport.DialContext
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err := dialContext(ctx, network, addr)
if err != nil {
close(onError)
return nil, err
}
*clientConn = conn // Capture the client connection.
return conn, nil
}
return transport
}
testTransportClosure := func(t *testing.T, http2 bool) { //nolint:thelper
t.Run("handler_reads", func(t *testing.T) {
var (
handlerReceiveErr error
handlerContextErr error
gotRequest = make(chan struct{})
gotResponse = make(chan struct{})
)
pingServer := &pluggablePingServer{
sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) {
close(gotRequest)
for stream.Receive() {
// Do nothing
}
handlerReceiveErr = stream.Err()
handlerContextErr = ctx.Err()
close(gotResponse)
return connect.NewResponse(&pingv1.SumResponse{}), nil
},
}
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer))
server := memhttptest.NewServer(t, mux)
var clientConn net.Conn
transport := captureTransportConn(server, &clientConn, gotRequest, http2)
serverClient := &http.Client{Transport: transport}
client := pingv1connect.NewPingServiceClient(serverClient, server.URL())
stream := client.Sum(context.Background())
// Send header.
assert.Nil(t, stream.Send(nil))
<-gotRequest
// Client abruptly disconnects.
if !assert.NotNil(t, clientConn) {
return
}
assert.Nil(t, clientConn.Close())
_, err := stream.CloseAndReceive()
assert.NotNil(t, err)
<-gotResponse
assert.NotNil(t, handlerReceiveErr)
assert.Equal(t, connect.CodeOf(handlerReceiveErr), connect.CodeCanceled)
assert.ErrorIs(t, handlerContextErr, context.Canceled)
})
t.Run("handler_writes", func(t *testing.T) {
var (
handlerReceiveErr error
handlerContextErr error
gotRequest = make(chan struct{})
gotResponse = make(chan struct{})
)
pingServer := &pluggablePingServer{
countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error {
close(gotRequest)
var err error
for err == nil {
err = stream.Send(&pingv1.CountUpResponse{})
}
handlerReceiveErr = err
handlerContextErr = ctx.Err()
close(gotResponse)
return nil
},
}
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer))
server := memhttptest.NewServer(t, mux)
var clientConn net.Conn
transport := captureTransportConn(server, &clientConn, gotRequest, http2)
serverClient := &http.Client{Transport: transport}
client := pingv1connect.NewPingServiceClient(serverClient, server.URL())
stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{}))
if !assert.Nil(t, err) {
return
}
<-gotRequest
// Client abruptly disconnects.
if !assert.NotNil(t, clientConn) {
return
}
assert.Nil(t, clientConn.Close())
for stream.Receive() {
// Do nothing
}
assert.NotNil(t, stream.Err())
<-gotResponse
assert.NotNil(t, handlerReceiveErr)
assert.Equal(t, connect.CodeOf(handlerReceiveErr), connect.CodeCanceled)
assert.ErrorIs(t, handlerContextErr, context.Canceled)
})
}
testTransportClosure(t, true)
testTransportClosure(t, false)
}

// TestBlankImportCodeGeneration tests that services.connect.go is generated with
// blank import statements to services.pb.go so that the service's Descriptor is
// available in the global proto registry.
Expand Down
19 changes: 8 additions & 11 deletions envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package connect

import (
"bytes"
"context"
"encoding/binary"
"errors"
"io"
Expand Down Expand Up @@ -117,6 +118,7 @@ func (e *envelope) Len() int {
}

type envelopeWriter struct {
ctx context.Context
Copy link
Contributor Author

@emcfarlane emcfarlane Dec 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Storing the ctx in the envelope reader is a little odd, but would otherwise be needed to be stored on the stream interceptors and then passed down which makes this problem worse. Would be nice to find a better solution though!

sender messageSender
codec Codec
compressMinBytes int
Expand Down Expand Up @@ -209,6 +211,7 @@ func (w *envelopeWriter) marshal(message any) *Error {
func (w *envelopeWriter) write(env *envelope) *Error {
if _, err := w.sender.Send(env); err != nil {
err = wrapIfContextError(err)
err = wrapWithContextError(w.ctx, err)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps this would be better to push the error classification into sender.Send? The sender (the duplexHTTPCall) has a reference to the request, which should also include the context.

Also, under what conditions would we ever call only one or the other of those context-wrap functions? Feels like they want to be a single function -- maybe even a method on duplexHTTPCall (so it can pull out the request context for the comparisons in the latter functions).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We call wrapWithContext more liberally to convert context errors but not to wrap errors with context codes. We will always call wrapIfContext with wrapWithContext but still need wrapIfContext separate so left them split out. A method on duplexHTTPCall would need a corresponding method on the handler side. The marshallers are used on both client/server so can handle it uniformly there.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will always call wrapIfContext with wrapWithContext but still need wrapIfContext separate so left them split out.

Why do they need to be split? Under what cases do we care about the error being a context error but if the context is actually done (or vice versa)? Also, the names are rather confusing. Just reading the above sentence is really not clear as to which is which.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We check for context errors on non IO related errors like the wrapper to the stream handlers. I've merged the call to the wrapIfContextError from wrapIfContextDone.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/emcfarlane/connect-go/blob/c3324072f50f3882a98e4ee846ee7511d3911efa/error.go#L279

Which we don't currently have an easy way to access the context from.

if connectErr, ok := asError(err); ok {
return connectErr
}
Expand All @@ -218,6 +221,7 @@ func (w *envelopeWriter) write(env *envelope) *Error {
}

type envelopeReader struct {
ctx context.Context
reader io.Reader
codec Codec
last envelope
Expand Down Expand Up @@ -305,17 +309,12 @@ func (r *envelopeReader) Read(env *envelope) *Error {
return NewError(CodeUnknown, err)
}
err = wrapIfContextError(err)
err = wrapWithContextError(r.ctx, err)
err = wrapIfMaxBytesError(err, "read 5 byte message prefix")
if connectErr, ok := asError(err); ok {
return connectErr
}
// Something else has gone wrong - the stream didn't end cleanly.
if connectErr, ok := asError(err); ok {
return connectErr
}
if maxBytesErr := asMaxBytesError(err, "read 5 byte message prefix"); maxBytesErr != nil {
// We're reading from an http.MaxBytesHandler, and we've exceeded the read limit.
return maxBytesErr
}
return errorf(
CodeInvalidArgument,
"protocol error: incomplete envelope: %w", err,
Expand All @@ -333,10 +332,6 @@ func (r *envelopeReader) Read(env *envelope) *Error {
// CopyN will return an error if it doesn't read the requested
// number of bytes.
if readN, err := io.CopyN(env.Data, r.reader, size); err != nil {
if maxBytesErr := asMaxBytesError(err, "read %d byte message", size); maxBytesErr != nil {
// We're reading from an http.MaxBytesHandler, and we've exceeded the read limit.
return maxBytesErr
}
if errors.Is(err, io.EOF) {
// We've gotten fewer bytes than we expected, so the stream has ended
// unexpectedly.
Expand All @@ -348,6 +343,8 @@ func (r *envelopeReader) Read(env *envelope) *Error {
)
}
err = wrapIfContextError(err)
err = wrapWithContextError(r.ctx, err)
err = wrapIfMaxBytesError(err, "read %d byte message", size)
if connectErr, ok := asError(err); ok {
return connectErr
}
Expand Down
2 changes: 2 additions & 0 deletions envelope_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package connect

import (
"bytes"
"context"
"io"
"testing"

Expand Down Expand Up @@ -44,6 +45,7 @@ func TestEnvelope(t *testing.T) {
t.Parallel()
env := &envelope{Data: &bytes.Buffer{}}
rdr := envelopeReader{
ctx: context.Background(),
reader: byteByByteReader{
reader: bytes.NewReader(buf.Bytes()),
},
Expand Down
32 changes: 30 additions & 2 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,26 @@ func wrapIfContextError(err error) error {
return err
}

// wrapWithContextError wraps errors with CodeCanceled or CodeDeadlineExceeded
// if the context is done. It leaves already-wrapped errors unchanged.
func wrapWithContextError(ctx context.Context, err error) error {
if err == nil {
return nil
}
if _, ok := asError(err); ok {
return err
}
ctxErr := ctx.Err()
switch {
case errors.Is(ctxErr, context.Canceled):
return NewError(CodeCanceled, err)
case errors.Is(ctxErr, context.DeadlineExceeded):
return NewError(CodeDeadlineExceeded, err)
default:
return err
}
emcfarlane marked this conversation as resolved.
Show resolved Hide resolved
}

// wrapIfLikelyH2CNotConfiguredError adds a wrapping error that has a message
// telling the caller that they likely need to use h2c but are using a raw http.Client{}.
//
Expand Down Expand Up @@ -408,10 +428,18 @@ func wrapIfRSTError(err error) error {
}
}

func asMaxBytesError(err error, tmpl string, args ...any) *Error {
// wrapIfMaxBytesError wraps errors returned reading from a http.MaxBytesHandler
// whose limit has been exceeded.
func wrapIfMaxBytesError(err error, tmpl string, args ...any) error {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Converted wrapIfMaxBytesError to the same style as other error handling for consistency.

if err == nil {
return nil
}
if _, ok := asError(err); ok {
return err
}
var maxBytesErr *http.MaxBytesError
if ok := errors.As(err, &maxBytesErr); !ok {
return nil
return err
}
prefix := fmt.Sprintf(tmpl, args...)
return errorf(CodeResourceExhausted, "%s: exceeded %d byte http.MaxBytesReader limit", prefix, maxBytesErr.Limit)
Expand Down
16 changes: 13 additions & 3 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ func (h *connectHandler) NewConn(
responseWriter http.ResponseWriter,
request *http.Request,
) (handlerConnCloser, bool) {
ctx := request.Context()
query := request.URL.Query()
// We need to parse metadata before entering the interceptor stack; we'll
// send the error to the client later on.
Expand Down Expand Up @@ -254,6 +255,7 @@ func (h *connectHandler) NewConn(
request: request,
responseWriter: responseWriter,
marshaler: connectUnaryMarshaler{
ctx: ctx,
sender: writeSender{writer: responseWriter},
codec: codec,
compressMinBytes: h.CompressMinBytes,
Expand All @@ -264,6 +266,7 @@ func (h *connectHandler) NewConn(
sendMaxBytes: h.SendMaxBytes,
},
unmarshaler: connectUnaryUnmarshaler{
ctx: ctx,
reader: requestBody,
codec: codec,
compressionPool: h.CompressionPools.Get(requestCompression),
Expand All @@ -280,6 +283,7 @@ func (h *connectHandler) NewConn(
responseWriter: responseWriter,
marshaler: connectStreamingMarshaler{
envelopeWriter: envelopeWriter{
ctx: ctx,
sender: writeSender{responseWriter},
codec: codec,
compressMinBytes: h.CompressMinBytes,
Expand All @@ -290,6 +294,7 @@ func (h *connectHandler) NewConn(
},
unmarshaler: connectStreamingUnmarshaler{
envelopeReader: envelopeReader{
ctx: ctx,
reader: requestBody,
codec: codec,
compressionPool: h.CompressionPools.Get(requestCompression),
Expand Down Expand Up @@ -375,6 +380,7 @@ func (c *connectClient) NewConn(
bufferPool: c.BufferPool,
marshaler: connectUnaryRequestMarshaler{
connectUnaryMarshaler: connectUnaryMarshaler{
ctx: ctx,
sender: duplexCall,
codec: c.Codec,
compressMinBytes: c.CompressMinBytes,
Expand All @@ -386,6 +392,7 @@ func (c *connectClient) NewConn(
},
},
unmarshaler: connectUnaryUnmarshaler{
ctx: ctx,
reader: duplexCall,
codec: c.Codec,
bufferPool: c.BufferPool,
Expand Down Expand Up @@ -415,6 +422,7 @@ func (c *connectClient) NewConn(
codec: c.Codec,
marshaler: connectStreamingMarshaler{
envelopeWriter: envelopeWriter{
ctx: ctx,
sender: duplexCall,
codec: c.Codec,
compressMinBytes: c.CompressMinBytes,
Expand All @@ -425,6 +433,7 @@ func (c *connectClient) NewConn(
},
unmarshaler: connectStreamingUnmarshaler{
envelopeReader: envelopeReader{
ctx: ctx,
reader: duplexCall,
codec: c.Codec,
bufferPool: c.BufferPool,
Expand Down Expand Up @@ -892,6 +901,7 @@ func (u *connectStreamingUnmarshaler) EndStreamError() *Error {
}

type connectUnaryMarshaler struct {
ctx context.Context
sender messageSender
codec Codec
compressMinBytes int
Expand Down Expand Up @@ -1057,6 +1067,7 @@ func (m *connectUnaryRequestMarshaler) writeWithGet(url *url.URL) *Error {
}

type connectUnaryUnmarshaler struct {
ctx context.Context
reader io.Reader
codec Codec
compressionPool *compressionPool
Expand Down Expand Up @@ -1084,12 +1095,11 @@ func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]by
bytesRead, err := data.ReadFrom(reader)
if err != nil {
err = wrapIfContextError(err)
err = wrapWithContextError(u.ctx, err)
err = wrapIfMaxBytesError(err, "read first %d bytes of message", bytesRead)
if connectErr, ok := asError(err); ok {
return connectErr
}
if readMaxBytesErr := asMaxBytesError(err, "read first %d bytes of message", bytesRead); readMaxBytesErr != nil {
return readMaxBytesErr
}
return errorf(CodeUnknown, "read message: %w", err)
}
if u.readMaxBytes > 0 && bytesRead > int64(u.readMaxBytes) {
Expand Down
Loading