Skip to content

Commit

Permalink
Implement progressive call results
Browse files Browse the repository at this point in the history
  • Loading branch information
muzzammilshahid committed Sep 19, 2024
1 parent aaf15c0 commit 5399dea
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 11 deletions.
24 changes: 18 additions & 6 deletions dealer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ import (
"github.com/xconnio/wampproto-go/messages"
)

const (
OptionReceiveProgress = "receive_progress"
OptionProgress = "progress"
)

type PendingInvocation struct {
RequestID int64
CallerID int64
Expand Down Expand Up @@ -110,17 +115,19 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message

invocationID := d.idGen.NextID()
d.pendingCalls[invocationID] = &PendingInvocation{
RequestID: call.RequestID(),
CallerID: sessionID,
CalleeID: callee,
RequestID: call.RequestID(),
CallerID: sessionID,
CalleeID: callee,
ReceiveProgress: call.Options()[OptionReceiveProgress].(bool),
}

var invocation *messages.Invocation
if call.PayloadIsBinary() && d.sessions[callee].StaticSerializer() {
invocation = messages.NewInvocationBinary(invocationID, regs.ID, nil, call.Payload(),
call.PayloadSerializer())
} else {
invocation = messages.NewInvocation(invocationID, regs.ID, nil, call.Args(), call.KwArgs())
details := map[string]any{OptionReceiveProgress: call.Options()[OptionReceiveProgress].(bool)}
invocation = messages.NewInvocation(invocationID, regs.ID, details, call.Args(), call.KwArgs())
}

return &MessageWithRecipient{Message: invocation, Recipient: callee}, nil
Expand All @@ -131,13 +138,18 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message
return nil, fmt.Errorf("yield: not pending calls for session %d", sessionID)
}

delete(d.pendingCalls, yield.RequestID())
var details map[string]any
if pending.ReceiveProgress {
details = map[string]any{OptionProgress: yield.Options()[OptionProgress].(bool)}
} else {
delete(d.pendingCalls, yield.RequestID())
}

var result *messages.Result
if yield.PayloadIsBinary() && d.sessions[pending.CallerID].StaticSerializer() {
result = messages.NewResultBinary(pending.RequestID, nil, yield.Payload(), yield.PayloadSerializer())
} else {
result = messages.NewResult(pending.RequestID, nil, yield.Args(), yield.KwArgs())
result = messages.NewResult(pending.RequestID, details, yield.Args(), yield.KwArgs())
}

return &MessageWithRecipient{Message: result, Recipient: pending.CallerID}, nil
Expand Down
2 changes: 1 addition & 1 deletion idgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
const maxID int64 = 1 << 53

func init() {
source := rand.NewSource(uint64(time.Now().UnixNano()))
source := rand.NewSource(uint64(time.Now().UnixNano())) // #nosec
rand.New(source)
}

Expand Down
4 changes: 2 additions & 2 deletions messages/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func AsInt64(i interface{}) (int64, bool) {
case int64:
return v, true
case uint64:
return int64(v), true
return int64(v), true // #nosec
case uint8:
return int64(v), true
case int:
Expand All @@ -302,7 +302,7 @@ func AsInt64(i interface{}) (int64, bool) {
case int32:
return int64(v), true
case uint:
return int64(v), true
return int64(v), true // #nosec
case uint16:
return int64(v), true
case uint32:
Expand Down
10 changes: 8 additions & 2 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ func (w *Session) SendMessage(msg messages.Message) ([]byte, error) {
return data, nil
case messages.MessageTypeYield:
yield := msg.(*messages.Yield)
w.invocationRequests.Delete(yield.RequestID())
if !yield.Options()[OptionProgress].(bool) {
w.invocationRequests.Delete(yield.RequestID())
}

return data, nil
case messages.MessageTypeRegister:
Expand Down Expand Up @@ -125,11 +127,15 @@ func (w *Session) ReceiveMessage(msg messages.Message) (messages.Message, error)
switch msg.Type() {
case messages.MessageTypeResult:
result := msg.(*messages.Result)
_, exists := w.callRequests.LoadAndDelete(result.RequestID())
_, exists := w.callRequests.Load(result.RequestID())
if !exists {
return nil, fmt.Errorf("received RESULT for invalid requestID")
}

if !result.Details()[OptionProgress].(bool) {
w.callRequests.Delete(result.RequestID())
}

return result, nil
case messages.MessageTypeRegistered:
registered := msg.(*messages.Registered)
Expand Down

0 comments on commit 5399dea

Please sign in to comment.