// interceptors.go — from a fork of livekit/psrpc.
package psrpc
import (
"context"
"runtime/debug"
"google.golang.org/protobuf/proto"
)
// ServerInterceptor wraps the service implementation. It receives the
// request, the RPCInfo for the call, and handler, which invokes the next
// stage of the chain (the next interceptor, or the implementation itself).
// It returns the response and error to hand back up the chain.
type ServerInterceptor func(ctx context.Context, req proto.Message, info RPCInfo, handler Handler) (proto.Message, error)

// Handler invokes the next stage of a server call with the given request:
// either the next ServerInterceptor in the chain or the service
// implementation.
type Handler func(context.Context, proto.Message) (proto.Message, error)
// ClientRequestHook is called as soon as a request is made.
type ClientRequestHook func(ctx context.Context, req proto.Message, info RPCInfo)

// ClientResponseHook is called just before a response is returned. It
// receives the request together with the response/error pair. For
// multi-requests it is called on every response, and it blocks while
// executing.
type ClientResponseHook func(ctx context.Context, req proto.Message, info RPCInfo, resp proto.Message, err error)
// RPCInterceptor is one layer of a client-side RPC call. It takes the
// request and its options and returns the response or an error, so a layer
// can wrap the next one transparently.
type RPCInterceptor func(ctx context.Context, req proto.Message, opts ...RequestOption) (proto.Message, error)

// RPCInterceptorFactory builds the RPCInterceptor layer for the call
// described by info, wrapping next. Factories can be composed with
// chainClientInterceptors, where the first factory produces the outermost
// layer.
type RPCInterceptorFactory func(info RPCInfo, next RPCInterceptor) RPCInterceptor

// MultiRPCInterceptor is one layer of a client-side multi-request call.
// NOTE(review): the per-method descriptions below are inferred from the
// method names and signatures; confirm against the client implementation.
type MultiRPCInterceptor interface {
	// Send passes the outgoing request msg, with its options, to the layer.
	Send(ctx context.Context, msg proto.Message, opts ...RequestOption) error
	// Recv hands the layer a received response msg or its error.
	Recv(msg proto.Message, err error)
	// Close tells the layer the call is finished.
	Close()
}

// MultiRPCInterceptorFactory builds the MultiRPCInterceptor layer for the
// call described by info, wrapping next.
type MultiRPCInterceptorFactory func(info RPCInfo, next MultiRPCInterceptor) MultiRPCInterceptor

// StreamInterceptor is one layer of a stream.
// NOTE(review): the per-method descriptions below are inferred from the
// method names and signatures; confirm against the stream implementation.
type StreamInterceptor interface {
	// Recv passes an incoming message through the layer.
	Recv(msg proto.Message) error
	// Send passes an outgoing message, with its options, through the layer.
	Send(msg proto.Message, opts ...StreamOption) error
	// Close closes the stream through the layer; cause is the reason given.
	Close(cause error) error
}

// StreamInterceptorFactory builds the StreamInterceptor layer for the stream
// described by info, wrapping next.
type StreamInterceptorFactory func(info RPCInfo, next StreamInterceptor) StreamInterceptor
// RPCInfo describes the call being intercepted. It is passed to every
// ServerInterceptor, client hook, and interceptor factory.
//
// Field order is part of the type; keep it stable for positional literals.
type RPCInfo struct {
	// Service identifies the service the call belongs to.
	Service string
	// Method identifies the method being called.
	Method string
	// Topic holds the topic segments of the call, if any.
	// NOTE(review): semantics inferred from the field name — confirm.
	Topic []string
	// Multi marks a multi-request call, i.e. one that can yield several
	// responses. NOTE(review): inferred from the name — confirm.
	Multi bool
}
// WithServerRecovery returns a ServerInterceptor that recovers from a panic
// raised by handler. The panic is converted into an Internal error that
// carries the panic value and the stack trace, and the response is nil.
//
// It should always be the last interceptor. With chainServerInterceptors'
// ordering, that makes it the innermost layer, wrapping the handler directly.
func WithServerRecovery() ServerInterceptor {
	return func(ctx context.Context, req proto.Message, _ RPCInfo, handler Handler) (resp proto.Message, err error) {
		defer func() {
			if r := recover(); r != nil {
				// Report the panic value as well: the stack shows where the
				// panic happened but not what was passed to panic().
				resp = nil
				err = NewErrorf(Internal, "Caught server panic: %v. Stack trace:\n%s", r, string(debug.Stack()))
			}
		}()

		return handler(ctx, req)
	}
}
// chainServerInterceptors combines interceptors into a single
// ServerInterceptor. It returns nil when there are none and the interceptor
// itself when there is exactly one.
//
// interceptors[0] is the outermost layer. It runs first, and the handler it
// receives invokes interceptors[1], and so on. The last interceptor receives
// the real handler.
//
// Every level gets its own handler closure instead of sharing a mutable
// index. This lets an interceptor call its handler more than once (e.g. to
// retry), or from several goroutines, and each call still runs the complete
// remainder of the chain in order. A shared index would skip the inner
// interceptors on repeated calls and race on concurrent ones.
func chainServerInterceptors(interceptors []ServerInterceptor) ServerInterceptor {
	switch n := len(interceptors); n {
	case 0:
		return nil
	case 1:
		return interceptors[0]
	default:
		return func(ctx context.Context, req proto.Message, info RPCInfo, handler Handler) (proto.Message, error) {
			// handlerAt(i) returns a Handler that runs interceptors[i:]
			// followed by handler; handlerAt(n) is handler itself. Deeper
			// handlers are built fresh on each call, so they share no state.
			var handlerAt func(i int) Handler
			handlerAt = func(i int) Handler {
				if i == n {
					return handler
				}
				return func(ctx context.Context, req proto.Message) (proto.Message, error) {
					return interceptors[i](ctx, req, info, handlerAt(i+1))
				}
			}
			return interceptors[0](ctx, req, info, handlerAt(1))
		}
	}
}
// chainClientInterceptors layers the interceptors built by factories around
// interceptor. factories[0] produces the outermost layer and the final
// factory the innermost one, directly around interceptor. Every factory
// receives the same info. With no factories, interceptor is returned as is.
func chainClientInterceptors[InterceptorType any, FactoryType ~func(RPCInfo, InterceptorType) InterceptorType](factories []FactoryType, info RPCInfo, interceptor InterceptorType) InterceptorType {
	wrapped := interceptor
	// Build from the inside out: the last factory wraps first.
	for step := range factories {
		factory := factories[len(factories)-1-step]
		wrapped = factory(info, wrapped)
	}
	return wrapped
}