diff --git a/dealer.go b/dealer.go index 9a6b07a..8f0cb8f 100644 --- a/dealer.go +++ b/dealer.go @@ -27,11 +27,17 @@ type Registration struct { InvocationPolicy string } +type CallMap struct { + CallerID int64 + CallID int64 +} + type Dealer struct { sessions map[int64]*SessionDetails registrationsByProcedure map[string]*Registration registrationsBySession map[int64]map[int64]*Registration pendingCalls map[int64]*PendingInvocation + invocationIDbyCall map[CallMap]int64 idGen *SessionScopeIDGenerator sync.Mutex @@ -43,6 +49,7 @@ func NewDealer() *Dealer { registrationsByProcedure: make(map[string]*Registration), registrationsBySession: make(map[int64]map[int64]*Registration), pendingCalls: make(map[int64]*PendingInvocation), + invocationIDbyCall: make(map[CallMap]int64), idGen: &SessionScopeIDGenerator{}, } } @@ -113,13 +120,19 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message break } receiveProgress, _ := call.Options()[OptionReceiveProgress].(bool) - - invocationID := d.idGen.NextID() - d.pendingCalls[invocationID] = &PendingInvocation{ - RequestID: call.RequestID(), - CallerID: sessionID, - CalleeID: callee, - ReceiveProgress: receiveProgress, + progress, _ := call.Options()[OptionProgress].(bool) + + invocationID, ok := d.invocationIDbyCall[CallMap{CallerID: sessionID, CallID: call.RequestID()}] + if !ok || !progress { + invocationID = d.idGen.NextID() + d.pendingCalls[invocationID] = &PendingInvocation{ + RequestID: call.RequestID(), + CallerID: sessionID, + CalleeID: callee, + ReceiveProgress: receiveProgress, + Progress: progress, + } + d.invocationIDbyCall[CallMap{CallerID: sessionID, CallID: call.RequestID()}] = invocationID } var invocation *messages.Invocation @@ -127,7 +140,13 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message invocation = messages.NewInvocationBinary(invocationID, regs.ID, nil, call.Payload(), call.PayloadSerializer()) } else { - details := map[string]any{OptionReceiveProgress: receiveProgress} + details := map[string]any{} + if receiveProgress { + details[OptionReceiveProgress] = receiveProgress + } + if progress { + details[OptionProgress] = progress + } invocation = messages.NewInvocation(invocationID, regs.ID, details, call.Args(), call.KwArgs()) }