diff --git a/option.go b/option.go index 703f8a58..056bc0f7 100644 --- a/option.go +++ b/option.go @@ -138,9 +138,9 @@ func WithHandlerOptions(options ...HandlerOption) HandlerOption { // WithRecover adds an interceptor that recovers from panics. The supplied // function receives the context, [Spec], request headers, and the recovered -// value (which may be nil). It must return an error to send back to the -// client. It may also log the panic, emit metrics, or execute other -// error-handling logic. Handler functions must be safe to call concurrently. +// value. It must return an error to send back to the client. It may also log +// the panic, emit metrics, or execute other error-handling logic. Handler +// functions must be safe to call concurrently. // // To preserve compatibility with [net/http]'s semantics, this interceptor // doesn't handle panics with [http.ErrAbortHandler]. @@ -150,8 +150,17 @@ func WithHandlerOptions(options ...HandlerOption) HandlerOption { // usually necessary to prevent crashes. Instead, it helps servers collect // RPC-specific data during panics and send a more detailed error to // clients. +// +// Deprecated: Use RecoverInterceptor to create an interceptor and +// WithInterceptors to register it via HandlerOption. func WithRecover(handle func(context.Context, Spec, http.Header, any) error) HandlerOption { - return WithInterceptors(&recoverHandlerInterceptor{handle: handle}) + return WithInterceptors(RecoverInterceptor(func(ctx context.Context, req AnyRequest, panicValue any) error { + //nolint:errorlint,goerr113 // net/http checks for ErrAbortHandler with ==, so we should too + if panicValue == http.ErrAbortHandler { + panic(panicValue) //nolint:forbidigo + } + return handle(ctx, req.Spec(), req.Header(), panicValue) + })) } // WithRequireConnectProtocolHeader configures the Handler to require requests diff --git a/recover.go b/recover.go index 4f4c5b30..941c4ba7 100644 --- a/recover.go +++ b/recover.go @@ -17,48 +17,292 @@ package connect import ( "context" "net/http" + "sync/atomic" ) -// recoverHandlerInterceptor lets handlers trap panics, perform side effects -// (like emitting logs or metrics), and present a friendlier error message to -// clients. -type recoverHandlerInterceptor struct { - Interceptor +// RecoverInterceptor is an interceptor that recovers from panics. The +// supplied function receives the context and request details. +// +// For streaming RPCs, req.Any() may return nil. It will always be nil +// for client-streaming or bidi-streaming RPCs, since there could be +// zero or even multiple request messages for such RPCs. For +// server-streaming RPCs, it will be nil if the panic occurred before +// the request message was received, which can happen if a panic occurs +// in an interceptor before the RPC handler method is invoked. +// +// Similarly, for streaming RPCs, req.Header() may return nil. This +// could happen in clients when the panic that is recovered occurs +// before the stream is actually created and before request headers are +// even allocated. +// +// Applications will generally want to add this interceptor first, which +// means it will actually be the last to handle any results from the +// RPC handler. This allows for recovering from the panics not only in +// the handler but also in any other interceptors. +// +// The recovered value will never be nil. If panic was called with a nil +// value, the recovered value will be a *[runtime.PanicNilError]. It must +// return an error to send back to the client. If it returns nil, an +// *Error with a code of CodeInternal will ne synthesized. The function +// may also log the panic, emit metrics, or execute other error-handling +// logic. The function must be safe to call concurrently. +// +// By default, handlers don't recover from panics. Because the standard +// library's [http.Server] recovers from panics by default, this option +// isn't usually necessary to prevent crashes. Instead, it helps servers +// collect RPC-specific data during panics and send a more detailed error +// to clients. +// +// Unlike [WithRecover], this interceptor does not do anything special with +// [http.ErrAbortHandler], so the handle function may be called with that as +// the panic value. +// +// Also unlike [WithRecover], which can only be used with handlers, this +// interceptor can be used with clients, to recover from any panics caused +// by bugs in the interceptor chain. For streaming RPCs, this will recover +// from panics that happen in calls to send or receive messages on the +// stream or to close the stream. +func RecoverInterceptor(handle func(ctx context.Context, req AnyRequest, panicValue any) error) Interceptor { + return &recoverHandlerInterceptor{handle: handle} +} - handle func(context.Context, Spec, http.Header, any) error +type recoverHandlerInterceptor struct { + handle func(context.Context, AnyRequest, any) error } func (i *recoverHandlerInterceptor) WrapUnary(next UnaryFunc) UnaryFunc { return func(ctx context.Context, req AnyRequest) (_ AnyResponse, retErr error) { - if req.Spec().IsClient { - return next(ctx, req) - } defer func() { if r := recover(); r != nil { - // net/http checks for ErrAbortHandler with ==, so we should too. - if r == http.ErrAbortHandler { //nolint:errorlint,goerr113 - panic(r) //nolint:forbidigo + retErr = i.handle(ctx, req, r) + if retErr == nil { + retErr = errorf(CodeInternal, "handler panicked; but recover handler returned non-nil error") } - retErr = i.handle(ctx, req.Spec(), req.Header(), r) } }() - res, err := next(ctx, req) - return res, err + return next(ctx, req) } } func (i *recoverHandlerInterceptor) WrapStreamingHandler(next StreamingHandlerFunc) StreamingHandlerFunc { return func(ctx context.Context, conn StreamingHandlerConn) (retErr error) { + var streamConn *recoverStreamingHandlerConn + if conn.Spec().StreamType == StreamTypeServer { + // There will be exactly one request. So we try to capture it + // so we can provide it to the recover handle func. + streamConn = &recoverStreamingHandlerConn{StreamingHandlerConn: conn} + conn = streamConn + } + defer func() { - if r := recover(); r != nil { - // net/http checks for ErrAbortHandler with ==, so we should too. - if r == http.ErrAbortHandler { //nolint:errorlint,goerr113 - panic(r) //nolint:forbidigo + if panicVal := recover(); panicVal != nil { + var msg any + if streamConn != nil { + if msgPtr := streamConn.req.Load(); msgPtr != nil { + msg = *msgPtr + } + } + retErr = i.handle(ctx, &recoverStreamRequest{conn, msg}, panicVal) + if retErr == nil { + retErr = errorf(CodeInternal, "handler panicked; but recover handler returned non-nil error") + } + } + }() + return next(ctx, conn) + } +} + +func (i *recoverHandlerInterceptor) WrapStreamingClient(next StreamingClientFunc) StreamingClientFunc { + return func(ctx context.Context, spec Spec) (conn StreamingClientConn) { + defer func() { + if panicVal := recover(); panicVal != nil { + err := i.handle(ctx, emptyRequest(spec), panicVal) + if err == nil { + err = errorf(CodeInternal, "call panicked; but recover handler returned non-nil error") } - retErr = i.handle(ctx, conn.Spec(), conn.RequestHeader(), r) + conn = &errStreamingClientConn{spec, err} } }() - err := next(ctx, conn) - return err + conn = next(ctx, spec) + return &recoverStreamingClientConn{ + StreamingClientConn: conn, + ctx: ctx, + handle: i.handle, + } + } +} + +type recoverStreamRequest struct { + StreamingHandlerConn + msg any +} + +func (r *recoverStreamRequest) Any() any { + return r.msg +} + +func (r *recoverStreamRequest) Header() http.Header { + return r.RequestHeader() +} + +func (r *recoverStreamRequest) HTTPMethod() string { + return http.MethodPost // streams always use POST +} + +func (r *recoverStreamRequest) internalOnly() { +} + +func (r *recoverStreamRequest) setRequestMethod(_ string) { + // only invoked internally for unary RPCs; safe to ignore +} + +type recoverStreamingHandlerConn struct { + StreamingHandlerConn + req atomic.Pointer[any] +} + +func (r *recoverStreamingHandlerConn) Receive(msg any) error { + err := r.StreamingHandlerConn.Receive(msg) + if err == nil { + // Note: The framework instantiates msg, passes it to + // this method, and then returns it to the application. + // It is possible that the application could mutate the + // value, so what we provide to the recover handler would + // then differ from the message actually received. But + // this is no different than if the RPC handler mutated + // the request message for a unary RPC and interceptors + // later examined it via Request.Any. So we tolerate the + // possibility for server-stream requests, too. + r.req.Store(&msg) } + return err +} + +type emptyRequest Spec + +func (e emptyRequest) Any() any { + return nil +} + +func (e emptyRequest) Spec() Spec { + return Spec(e) +} + +func (e emptyRequest) Peer() Peer { + return Peer{} +} + +func (e emptyRequest) Header() http.Header { + return nil +} + +func (e emptyRequest) HTTPMethod() string { + return http.MethodPost +} + +func (e emptyRequest) internalOnly() { +} + +func (e emptyRequest) setRequestMethod(_ string) { + // only invoked internally for unary RPCs; safe to ignore +} + +type errStreamingClientConn struct { + spec Spec + err error +} + +func (e *errStreamingClientConn) Spec() Spec { + return e.spec +} + +func (e *errStreamingClientConn) Peer() Peer { + return Peer{} +} + +func (e *errStreamingClientConn) Send(_ any) error { + return e.err +} + +func (e *errStreamingClientConn) RequestHeader() http.Header { + // Clients can add headers before calling Send, so this must be mutable/non-nil. + return http.Header{} // TODO: memoize so we never allocate more than one? +} + +func (e *errStreamingClientConn) CloseRequest() error { + return e.err +} + +func (e *errStreamingClientConn) Receive(_ any) error { + return e.err +} + +func (e *errStreamingClientConn) ResponseHeader() http.Header { + return nil +} + +func (e *errStreamingClientConn) ResponseTrailer() http.Header { + return nil +} + +func (e *errStreamingClientConn) CloseResponse() error { + return e.err +} + +type recoverStreamingClientConn struct { + StreamingClientConn + + //nolint:containedctx // must memoize the stream context to pass to recover handler + ctx context.Context + handle func(context.Context, AnyRequest, any) error + req atomic.Pointer[any] +} + +func (r *recoverStreamingClientConn) Send(msg any) error { + if r.Spec().StreamType == StreamTypeServer { + // Capture the request message for server-streaming RPCs. + r.req.Store(&msg) + } + return r.invoke(func() error { + return r.StreamingClientConn.Send(msg) + }) +} + +func (r *recoverStreamingClientConn) RequestHeader() http.Header { + if header := r.StreamingClientConn.RequestHeader(); header != nil { + return header + } + // Clients can add headers before calling Send, so this must be mutable/non-nil. + // We do this not to recover from a panic but in the hopes of preventing panics in the caller. + return http.Header{} // TODO: memoize so we never allocate more than one? +} + +func (r *recoverStreamingClientConn) CloseRequest() error { + return r.invoke(r.StreamingClientConn.CloseRequest) +} + +func (r *recoverStreamingClientConn) Receive(msg any) error { + return r.invoke(func() error { + return r.StreamingClientConn.Receive(msg) + }) +} + +func (r *recoverStreamingClientConn) CloseResponse() error { + return r.invoke(r.StreamingClientConn.CloseResponse) +} + +func (r *recoverStreamingClientConn) invoke(action func() error) (retErr error) { + defer func() { + if panicVal := recover(); panicVal != nil { + var msg any + if msgPtr := r.req.Load(); msgPtr != nil { + msg = *msgPtr + } + retErr = r.handle(r.ctx, &recoverStreamRequest{r, msg}, panicVal) + if retErr == nil { + retErr = errorf(CodeInternal, "call panicked; but recover handler returned non-nil error") + } + } + }() + return action() } diff --git a/recover_ext_test.go b/recover_ext_test.go index a09f461c..53a3f784 100644 --- a/recover_ext_test.go +++ b/recover_ext_test.go @@ -18,9 +18,10 @@ import ( "context" "fmt" "net/http" + "sync/atomic" "testing" - connect "connectrpc.com/connect" + "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" @@ -34,12 +35,24 @@ type panicPingServer struct { } func (s *panicPingServer) Ping( - context.Context, - *connect.Request[pingv1.PingRequest], + _ context.Context, + _ *connect.Request[pingv1.PingRequest], ) (*connect.Response[pingv1.PingResponse], error) { panic(s.panicWith) //nolint:forbidigo } +func (s *panicPingServer) Sum( + _ context.Context, + stream *connect.ClientStream[pingv1.SumRequest], +) (*connect.Response[pingv1.SumResponse], error) { + if !stream.Receive() { + if err := stream.Err(); err != nil { + return nil, err + } + } + panic(s.panicWith) //nolint:forbidigo +} + func (s *panicPingServer) CountUp( _ context.Context, _ *connect.Request[pingv1.CountUpRequest], @@ -51,6 +64,20 @@ func (s *panicPingServer) CountUp( panic(s.panicWith) //nolint:forbidigo } +func (s *panicPingServer) CumSum( + _ context.Context, + stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse], +) error { + req, err := stream.Receive() + if err != nil { + return err + } + if err := stream.Send(&pingv1.CumSumResponse{Sum: req.Number}); err != nil { + return err + } + panic(s.panicWith) //nolint:forbidigo +} + func TestWithRecover(t *testing.T) { t.Parallel() handle := func(_ context.Context, _ connect.Spec, _ http.Header, r any) error { @@ -69,7 +96,7 @@ func TestWithRecover(t *testing.T) { } drainStream := func(stream *connect.ServerStreamForClient[pingv1.CountUpResponse]) error { t.Helper() - defer stream.Close() + defer assertNoError(t, stream.Close) assert.True(t, stream.Receive()) // expect one response msg assert.False(t, stream.Receive()) // expect panic before second response msg return stream.Err() @@ -103,3 +130,109 @@ func TestWithRecover(t *testing.T) { assert.Nil(t, err) assertNotHandled(drainStream(stream)) } + +//nolint:tparallel // we can't run sub-tests in parallel due to server's statefulness +func TestRecoverInterceptor(t *testing.T) { + t.Parallel() + var check atomic.Value + handle := func(ctx context.Context, req connect.AnyRequest, panicVal any) error { + fn, ok := check.Load().(func(context.Context, connect.AnyRequest, any)) + if ok { + fn(ctx, req, panicVal) + } + return connect.NewError(connect.CodeFailedPrecondition, fmt.Errorf("panic: %v", panicVal)) + } + assertHandled := func(err error) { + t.Helper() + assert.NotNil(t, err) + assert.Equal(t, connect.CodeOf(err), connect.CodeFailedPrecondition) + } + drainServerStream := func(stream *connect.ServerStreamForClient[pingv1.CountUpResponse]) error { + t.Helper() + defer assertNoError(t, stream.Close) + assert.True(t, stream.Receive()) // expect one response msg + assert.False(t, stream.Receive()) // expect panic before second response msg + return stream.Err() + } + drainBidiStream := func(stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) error { + t.Helper() + assertNoError(t, stream.CloseRequest) + defer assertNoError(t, stream.CloseResponse) + resp, err := stream.Receive() + // expect one response msg + assert.NotNil(t, resp) + assert.Nil(t, err) + // expect panic before second response msg + resp, err = stream.Receive() + assert.Nil(t, resp) + return err + } + pinger := &panicPingServer{} + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler(pinger, + connect.WithInterceptors(connect.RecoverInterceptor(handle)))) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + ) + + // Unary and server-stream RPCs return the request message from req.Any() + check.Store(func(ctx context.Context, req connect.AnyRequest, panicVal any) { + assert.NotNil(t, req.Any()) + }) + + t.Run("unary", func(t *testing.T) { //nolint:paralleltest + for _, panicWith := range []any{42, nil, http.ErrAbortHandler} { + pinger.panicWith = panicWith + + _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) + assertHandled(err) + } + }) + + t.Run("server-stream", func(t *testing.T) { //nolint:paralleltest + for _, panicWith := range []any{42, nil, http.ErrAbortHandler} { + pinger.panicWith = panicWith + + stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{})) + assert.Nil(t, err) + assertHandled(drainServerStream(stream)) + } + }) + + // But client-stream and bidi-stream RPCs return nil from req.Any() + check.Store(func(ctx context.Context, req connect.AnyRequest, panicVal any) { + assert.Nil(t, req.Any()) + }) + + t.Run("client-stream", func(t *testing.T) { //nolint:paralleltest + for _, panicWith := range []any{42, nil, http.ErrAbortHandler} { + pinger.panicWith = panicWith + + stream := client.Sum(context.Background()) + err := stream.Send(&pingv1.SumRequest{Number: 123}) + assert.Nil(t, err) + resp, err := stream.CloseAndReceive() + assert.Nil(t, resp) + assertHandled(err) + } + }) + + t.Run("bidi-stream", func(t *testing.T) { //nolint:paralleltest + for _, panicWith := range []any{42, nil, http.ErrAbortHandler} { + pinger.panicWith = panicWith + + stream := client.CumSum(context.Background()) + err := stream.Send(&pingv1.CumSumRequest{Number: 123}) + assert.Nil(t, err) + assertHandled(drainBidiStream(stream)) + } + }) +} + +func assertNoError(t *testing.T, fn func() error) { + t.Helper() + err := fn() + assert.Nil(t, err) +}