forked from livekit/psrpc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmulti_rpc.go
112 lines (92 loc) · 2.57 KB
/
multi_rpc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package psrpc
import (
"context"
"time"
"google.golang.org/protobuf/proto"
"github.com/livekit/psrpc/internal"
)
// multiRPCInterceptorRoot exposes Send, Recv and Close on top of a multiRPC by
// forwarding each call to the embedded multiRPC's unexported implementation.
// NOTE(review): presumably this is the innermost MultiRPCInterceptor that the
// interceptor chain wraps — confirm against the code that builds the chain.
type multiRPCInterceptorRoot[ResponseType proto.Message] struct {
	*multiRPC[ResponseType]
}

// Send forwards to multiRPC.send, which serializes msg, registers a response
// channel for the request and publishes it.
func (r *multiRPCInterceptorRoot[ResponseType]) Send(ctx context.Context, msg proto.Message, opts ...RequestOption) error {
	return r.multiRPC.send(ctx, msg, opts...)
}

// Recv forwards one received response (or its error) to multiRPC.recv, which
// delivers it on the result channel.
func (r *multiRPCInterceptorRoot[ResponseType]) Recv(msg proto.Message, err error) {
	r.multiRPC.recv(msg, err)
}

// Close forwards to multiRPC.close, which unregisters the request from the
// client and closes the result channel.
func (r *multiRPCInterceptorRoot[ResponseType]) Close() {
	r.multiRPC.close()
}
// multiRPC is the client-side state of one request that may receive multiple
// responses. Responses are collected until the request timeout and delivered
// on resChan, which is closed when the request is torn down.
type multiRPC[ResponseType proto.Message] struct {
	// c is the owning client. send and close use its mutex, responseChannels
	// map, bus, options and response hooks.
	c *RPCClient
	// resChan receives one *Response per handled reply and is closed by close.
	resChan chan<- *Response[ResponseType]
	// interceptor is the chain that handleResponses drives via Recv and Close.
	interceptor MultiRPCInterceptor
	// requestID identifies the request on the wire and keys responseChannels.
	requestID string
	// info carries the RPC method/topic used to pick the publish channel; it is
	// also handed to every response hook.
	info RPCInfo
}
// send serializes msg into a multi-response request and publishes it.
//
// Before publishing, it registers a response channel for m.requestID on the
// client and starts handleResponses in a new goroutine to consume it.
//
// A serialization failure is returned as a MalformedRequest error before any
// state is registered. A publish failure is returned as an Internal error.
// NOTE(review): on publish failure the registration and collector goroutine
// are not torn down here; they persist until handleResponses exits.
func (m *multiRPC[ResponseType]) send(ctx context.Context, msg proto.Message, opts ...RequestOption) (err error) {
	o := getRequestOpts(m.c.clientOpts, opts...)

	rawPayload, anyPayload, serr := serializePayload(msg)
	if serr != nil {
		return NewError(MalformedRequest, serr)
	}

	sentAt := time.Now()
	expiresAt := sentAt.Add(o.timeout)
	req := &internal.Request{
		RequestId:  m.requestID,
		ClientId:   m.c.id,
		SentAt:     sentAt.UnixNano(),
		Expiry:     expiresAt.UnixNano(),
		Multi:      true,
		Request:    anyPayload,
		RawRequest: rawPayload,
		Metadata:   OutgoingContextMetadata(ctx),
	}

	// Register the response channel before publishing. Presumably the client
	// routes incoming responses through responseChannels, so replies that
	// arrive right after publish already have a destination.
	responses := make(chan *internal.Response, m.c.channelSize)
	m.c.mu.Lock()
	m.c.responseChannels[m.requestID] = responses
	m.c.mu.Unlock()

	go m.handleResponses(ctx, o, msg, responses)

	channel := getRPCChannel(m.c.serviceName, m.info.Method, m.info.Topic)
	if perr := m.c.bus.Publish(ctx, channel, req); perr != nil {
		return NewError(Internal, perr)
	}
	return nil
}
// handleResponses runs in its own goroutine (started by send) and processes the
// responses routed to irChan for this request.
//
// Each response is turned into a (value, error) pair:
//   - a non-empty res.Error becomes the error built by newErrorFromResponse;
//   - otherwise the payload is deserialized into ResponseType, and a decoding
//     failure is reported as a MalformedResponse error.
//
// Every registered response hook observes the pair before it is forwarded
// through the interceptor chain via Recv.
//
// Collection stops when the request timeout (o.timeout) elapses or ctx is
// cancelled, whichever comes first. In both cases the interceptor chain is
// closed exactly once before returning.
//
// NOTE(review): if Recv blocks (e.g. the consumer stops draining the result
// channel), neither the timeout nor cancellation is observed until it returns.
func (m *multiRPC[ResponseType]) handleResponses(ctx context.Context, o reqOpts, msg proto.Message, irChan chan *internal.Response) {
	timer := time.NewTimer(o.timeout)
	// Release the timer when exiting via cancellation before it has fired.
	defer timer.Stop()

	for {
		select {
		case res := <-irChan:
			var v ResponseType
			var err error
			if res.Error != "" {
				err = newErrorFromResponse(res.Code, res.Error)
			} else {
				v, err = deserializePayload[ResponseType](res.RawResponse, res.Response)
				if err != nil {
					err = NewError(MalformedResponse, err)
				}
			}

			// response hooks
			for _, hook := range m.c.responseHooks {
				hook(ctx, msg, m.info, v, err)
			}
			m.interceptor.Recv(v, err)

		case <-timer.C:
			m.interceptor.Close()
			return

		case <-ctx.Done():
			// The caller abandoned the request. Close the chain now rather than
			// keeping this goroutine (and, via the root Close, the
			// response-channel registration) alive until the timeout fires.
			m.interceptor.Close()
			return
		}
	}
}
// recv wraps one received message and its error in a Response and delivers it
// on resChan. It is invoked through multiRPCInterceptorRoot.Recv.
//
// When msg is nil or not a ResponseType, Result is left as ResponseType's zero
// value; the error is passed through unchanged. The channel send blocks while
// resChan has no free buffer space and no reader.
func (m *multiRPC[ResponseType]) recv(msg proto.Message, err error) {
	result, _ := msg.(ResponseType) // comma-ok: fall back to the zero value instead of panicking
	m.resChan <- &Response[ResponseType]{
		Result: result,
		Err:    err,
	}
}
// close tears down the request. It first removes this request's entry from
// the client's responseChannels map (under the client mutex), then closes
// resChan so the consumer sees the end of the result stream.
//
// It must be called at most once: closing resChan a second time panics.
func (m *multiRPC[ResponseType]) close() {
	func() {
		m.c.mu.Lock()
		defer m.c.mu.Unlock()
		delete(m.c.responseChannels, m.requestID)
	}()

	close(m.resChan)
}