Skip to content

Commit

Permalink
Fix ErrorWriter to be codec agnostic (#701)
Browse files Browse the repository at this point in the history
This PR changes the ErrorWriter to be more lenient with classifying
protocols. Errors codecs are agnostic to the codec used. Therefore we
avoid checking the codec in classifying the request. IsSupported will
return true for an unknown codec which allows the server to encode a
better error message to the client. If not supported a 415 error
response could be used to match gRPC server like handling. If not
supported and trying to write an error the ErrorWriter will default to
connects unary encoding (consistent with current behaviour).

Fixes #689
  • Loading branch information
emcfarlane authored Mar 12, 2024
1 parent e3f35a6 commit 872a6fd
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 72 deletions.
4 changes: 2 additions & 2 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2423,7 +2423,7 @@ func TestClientDisconnect(t *testing.T) {
assert.NotNil(t, err)
<-gotResponse
assert.NotNil(t, handlerReceiveErr)
assert.Equal(t, connect.CodeOf(handlerReceiveErr), connect.CodeCanceled)
assert.Equal(t, connect.CodeOf(handlerReceiveErr), connect.CodeCanceled, assert.Sprintf("got %v", handlerReceiveErr))
assert.ErrorIs(t, handlerContextErr, context.Canceled)
})
t.Run("handler_writes", func(t *testing.T) {
Expand All @@ -2434,7 +2434,7 @@ func TestClientDisconnect(t *testing.T) {
gotResponse = make(chan struct{})
)
pingServer := &pluggablePingServer{
countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error {
countUp: func(ctx context.Context, _ *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error {
close(gotRequest)
var err error
for err == nil {
Expand Down
84 changes: 26 additions & 58 deletions error_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,86 +41,54 @@ const (
type ErrorWriter struct {
bufferPool *bufferPool
protobuf Codec
grpcContentTypes map[string]struct{}
grpcWebContentTypes map[string]struct{}
unaryConnectContentTypes map[string]struct{}
streamingConnectContentTypes map[string]struct{}
requireConnectProtocolHeader bool
}

// NewErrorWriter constructs an ErrorWriter. To properly recognize supported
// RPC Content-Types in net/http middleware, you must pass the same
// HandlerOptions to NewErrorWriter and any wrapped Connect handlers.
// NewErrorWriter constructs an ErrorWriter. Handler options may be passed to
// configure the error writer behaviour to match the handlers.
// [WithRequiredConnectProtocolHeader] will assert that Connect protocol
// requests include the version header allowing the error writer to correctly
// classify the request.
// Options supplied via [WithConditionalHandlerOptions] are ignored.
func NewErrorWriter(opts ...HandlerOption) *ErrorWriter {
config := newHandlerConfig("", StreamTypeUnary, opts)
writer := &ErrorWriter{
codecs := newReadOnlyCodecs(config.Codecs)
return &ErrorWriter{
bufferPool: config.BufferPool,
protobuf: newReadOnlyCodecs(config.Codecs).Protobuf(),
grpcContentTypes: make(map[string]struct{}),
grpcWebContentTypes: make(map[string]struct{}),
unaryConnectContentTypes: make(map[string]struct{}),
streamingConnectContentTypes: make(map[string]struct{}),
protobuf: codecs.Protobuf(),
requireConnectProtocolHeader: config.RequireConnectProtocolHeader,
}
for name := range config.Codecs {
unary := connectContentTypeFromCodecName(StreamTypeUnary, name)
writer.unaryConnectContentTypes[unary] = struct{}{}
streaming := connectContentTypeFromCodecName(StreamTypeBidi, name)
writer.streamingConnectContentTypes[streaming] = struct{}{}
}
if config.HandleGRPC {
writer.grpcContentTypes[grpcContentTypeDefault] = struct{}{}
for name := range config.Codecs {
ct := grpcContentTypeFromCodecName(false /* web */, name)
writer.grpcContentTypes[ct] = struct{}{}
}
}
if config.HandleGRPCWeb {
writer.grpcWebContentTypes[grpcWebContentTypeDefault] = struct{}{}
for name := range config.Codecs {
ct := grpcContentTypeFromCodecName(true /* web */, name)
writer.grpcWebContentTypes[ct] = struct{}{}
}
}
return writer
}

func (w *ErrorWriter) classifyRequest(request *http.Request) protocolType {
ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType))
if _, ok := w.unaryConnectContentTypes[ctype]; ok {
if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil {
return unknownProtocol
}
return connectUnaryProtocol
}
if _, ok := w.streamingConnectContentTypes[ctype]; ok {
isPost := request.Method == http.MethodPost
isGet := request.Method == http.MethodGet
switch {
case isPost && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)):
return grpcProtocol
case isPost && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)):
return grpcWebProtocol
case isPost && strings.HasPrefix(ctype, connectStreamingContentTypePrefix):
// Streaming ignores the requireConnectProtocolHeader option as the
// Content-Type is enough to determine the protocol.
if err := connectCheckProtocolVersion(request, false /* required */); err != nil {
return unknownProtocol
}
return connectStreamProtocol
}
if _, ok := w.grpcContentTypes[ctype]; ok {
return grpcProtocol
}
if _, ok := w.grpcWebContentTypes[ctype]; ok {
return grpcWebProtocol
}
// Check for Connect-Protocol-Version header or connect protocol query
// parameter to support connect GET requests.
if request.Method == http.MethodGet {
connectVersion := getHeaderCanonical(request.Header, connectProtocolVersion)
if connectVersion == connectProtocolVersion {
return connectUnaryProtocol
case isPost && strings.HasPrefix(ctype, connectUnaryContentTypePrefix):
if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil {
return unknownProtocol
}
connectVersion = request.URL.Query().Get(connectUnaryConnectQueryParameter)
if connectVersion == connectUnaryConnectQueryValue {
return connectUnaryProtocol
return connectUnaryProtocol
case isGet:
if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil {
return unknownProtocol
}
return connectUnaryProtocol
default:
return unknownProtocol
}
return unknownProtocol
}

// IsSupported checks whether a request is using one of the ErrorWriter's
Expand Down
62 changes: 60 additions & 2 deletions error_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@ import (

func TestErrorWriter(t *testing.T) {
t.Parallel()

t.Run("RequireConnectProtocolHeader", func(t *testing.T) {
t.Parallel()
writer := NewErrorWriter(WithRequireConnectProtocolHeader())

t.Run("Unary", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "http://localhost", nil)
req.Header.Set("Content-Type", connectUnaryContentTypePrefix+codecNameJSON)
Expand All @@ -52,4 +50,64 @@ func TestErrorWriter(t *testing.T) {
assert.True(t, writer.IsSupported(req))
})
})
t.Run("Protocols", func(t *testing.T) {
t.Parallel()
writer := NewErrorWriter() // All supported by default
t.Run("ConnectUnary", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "http://localhost", nil)
req.Header.Set("Content-Type", connectUnaryContentTypePrefix+codecNameJSON)
assert.True(t, writer.IsSupported(req))
})
t.Run("ConnectUnaryGET", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
assert.True(t, writer.IsSupported(req))
})
t.Run("ConnectStream", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "http://localhost", nil)
req.Header.Set("Content-Type", connectStreamingContentTypePrefix+codecNameJSON)
assert.True(t, writer.IsSupported(req))
})
t.Run("GRPC", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "http://localhost", nil)
req.Header.Set("Content-Type", grpcContentTypeDefault)
assert.True(t, writer.IsSupported(req))
req.Header.Set("Content-Type", grpcContentTypePrefix+"json")
assert.True(t, writer.IsSupported(req))
})
t.Run("GRPCWeb", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "http://localhost", nil)
req.Header.Set("Content-Type", grpcWebContentTypeDefault)
assert.True(t, writer.IsSupported(req))
req.Header.Set("Content-Type", grpcWebContentTypePrefix+"json")
assert.True(t, writer.IsSupported(req))
})
})
t.Run("UnknownCodec", func(t *testing.T) {
// An Unknown codec should return supported as the protocol is known and
// the error codec is agnostic to the codec used. The server can respond
// with a protocol error for the unknown codec.
t.Parallel()
writer := NewErrorWriter()
unknownCodec := "invalid"
t.Run("ConnectUnary", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "http://localhost", nil)
req.Header.Set("Content-Type", connectUnaryContentTypePrefix+unknownCodec)
assert.True(t, writer.IsSupported(req))
})
t.Run("ConnectStream", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "http://localhost", nil)
req.Header.Set("Content-Type", connectStreamingContentTypePrefix+unknownCodec)
assert.True(t, writer.IsSupported(req))
})
t.Run("GRPC", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "http://localhost", nil)
req.Header.Set("Content-Type", grpcContentTypePrefix+unknownCodec)
assert.True(t, writer.IsSupported(req))
})
t.Run("GRPCWeb", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "http://localhost", nil)
req.Header.Set("Content-Type", grpcWebContentTypePrefix+unknownCodec)
assert.True(t, writer.IsSupported(req))
})
})
}
14 changes: 4 additions & 10 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,6 @@ type handlerConfig struct {
Procedure string
Schema any
Initializer maybeInitializer
HandleGRPC bool
HandleGRPCWeb bool
RequireConnectProtocolHeader bool
IdempotencyLevel IdempotencyLevel
BufferPool *bufferPool
Expand All @@ -290,8 +288,6 @@ func newHandlerConfig(procedure string, streamType StreamType, options []Handler
Procedure: protoPath,
CompressionPools: make(map[string]*compressionPool),
Codecs: make(map[string]Codec),
HandleGRPC: true,
HandleGRPCWeb: true,
BufferPool: newBufferPool(),
StreamType: streamType,
}
Expand All @@ -314,12 +310,10 @@ func (c *handlerConfig) newSpec() Spec {
}

func (c *handlerConfig) newProtocolHandlers() []protocolHandler {
protocols := []protocol{&protocolConnect{}}
if c.HandleGRPC {
protocols = append(protocols, &protocolGRPC{web: false})
}
if c.HandleGRPCWeb {
protocols = append(protocols, &protocolGRPC{web: true})
protocols := []protocol{
&protocolConnect{},
&protocolGRPC{web: false},
&protocolGRPC{web: true},
}
handlers := make([]protocolHandler, 0, len(protocols))
codecs := newReadOnlyCodecs(c.Codecs)
Expand Down

0 comments on commit 872a6fd

Please sign in to comment.