Skip to content

Commit

Permalink
Implement call canceling in dealer
Browse files Browse the repository at this point in the history
  • Loading branch information
muzzammilshahid committed Jan 25, 2025
1 parent d01416f commit c4ee726
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 41 deletions.
75 changes: 66 additions & 9 deletions dealer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,18 @@ import (
"sync"

"github.com/xconnio/wampproto-go/messages"
"github.com/xconnio/wampproto-go/util"
)

const (
OptionReceiveProgress = "receive_progress"
OptionProgress = "progress"
OptionMode = "mode"
OptionReason = "reason"

CancelModeKill = "kill"
CancelModeKillNoWait = "killnowait"
CancelModeSkip = "skip"
)

const (
Expand Down Expand Up @@ -106,7 +113,7 @@ func (d *Dealer) HasProcedure(procedure string) bool {
return exists && len(reg.Registrants) > 0
}

func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*MessageWithRecipient, error) {
func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) ([]*MessageWithRecipient, error) {
d.Lock()
defer d.Unlock()

Expand All @@ -117,7 +124,7 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message
if !exists || len(regs.Registrants) == 0 {
callErr := messages.NewError(messages.MessageTypeCall, call.RequestID(), map[string]any{},
"wamp.error.no_such_procedure", nil, nil)
return &MessageWithRecipient{Message: callErr, Recipient: sessionID}, nil
return []*MessageWithRecipient{{Message: callErr, Recipient: sessionID}}, nil
}

var callee int64
Expand Down Expand Up @@ -156,7 +163,55 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message
invocation = messages.NewInvocation(invocationID, regs.ID, details, call.Args(), call.KwArgs())
}

return &MessageWithRecipient{Message: invocation, Recipient: callee}, nil
return []*MessageWithRecipient{{Message: invocation, Recipient: callee}}, nil
case messages.MessageTypeCancel:
cancel := msg.(*messages.Cancel)
mode := util.ToString(cancel.Options()[OptionMode])
switch mode {
case CancelModeSkip, CancelModeKill, CancelModeKillNoWait:
case "":
mode = CancelModeKillNoWait
default:
errMsg := messages.NewError(messages.MessageTypeCancel, cancel.RequestID(), nil, ErrInvalidArgument,
[]any{fmt.Sprintf("invalid cancel mode: %s", mode)}, nil)
return []*MessageWithRecipient{{Message: errMsg, Recipient: sessionID}}, nil
}

callMap := CallMap{CallerID: sessionID, CallID: cancel.RequestID()}
invocationID, ok := d.invocationIDbyCall[callMap]
if !ok {
return nil, fmt.Errorf("no pending invocation to cancel")
}

pendingCall, ok := d.pendingCalls[invocationID]
if !ok {
return nil, fmt.Errorf("no pending call to cancel")
}

if sessionID != pendingCall.CallerID {
return nil, fmt.Errorf("cancel received from the session who doesn't own the call")
}

var messagesWithRecipient []*MessageWithRecipient
if mode != CancelModeSkip {
messagesWithRecipient = append(messagesWithRecipient, &MessageWithRecipient{
Message: messages.NewInterrupt(invocationID, map[string]any{OptionReason: ErrCanceled, OptionMode: mode}),
Recipient: pendingCall.CalleeID,
})
}

if mode != CancelModeKill {
errMessage := messages.NewError(messages.MessageTypeCall, cancel.RequestID(), nil, ErrCanceled, nil, nil)
messagesWithRecipient = append(messagesWithRecipient, &MessageWithRecipient{
Message: errMessage,
Recipient: pendingCall.CallerID,
})

delete(d.pendingCalls, invocationID)
delete(d.invocationIDbyCall, callMap)
}

return messagesWithRecipient, nil
case messages.MessageTypeYield:
yield := msg.(*messages.Yield)
pending, exists := d.pendingCalls[yield.RequestID()]
Expand All @@ -168,7 +223,9 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message
var details map[string]any
if pending.ReceiveProgress && progress {
details = map[string]any{OptionProgress: progress}
} else {
}

if !pending.ReceiveProgress {
delete(d.pendingCalls, yield.RequestID())
}

Expand All @@ -179,7 +236,7 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message
result = messages.NewResult(pending.RequestID, details, yield.Args(), yield.KwArgs())
}

return &MessageWithRecipient{Message: result, Recipient: pending.CallerID}, nil
return []*MessageWithRecipient{{Message: result, Recipient: pending.CallerID}}, nil
case messages.MessageTypeRegister:
register := msg.(*messages.Register)
_, exists := d.registrationsBySession[sessionID]
Expand All @@ -192,7 +249,7 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message
// TODO: implement shared registrations
err := messages.NewError(messages.MessageTypeRegister, register.RequestID(), map[string]any{},
"wamp.error.procedure_already_exists", nil, nil)
return &MessageWithRecipient{Message: err, Recipient: sessionID}, nil
return []*MessageWithRecipient{{Message: err, Recipient: sessionID}}, nil
} else {
registration = &Registration{
ID: d.idGen.NextID(),
Expand All @@ -205,7 +262,7 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message
d.registrationsBySession[sessionID][registration.ID] = registration

registered := messages.NewRegistered(register.RequestID(), registration.ID)
return &MessageWithRecipient{Message: registered, Recipient: sessionID}, nil
return []*MessageWithRecipient{{Message: registered, Recipient: sessionID}}, nil
case messages.MessageTypeUnregister:
unregister := msg.(*messages.Unregister)
registrations, exists := d.registrationsBySession[sessionID]
Expand All @@ -223,7 +280,7 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message
}

unregistered := messages.NewUnregistered(unregister.RequestID())
return &MessageWithRecipient{Message: unregistered, Recipient: sessionID}, nil
return []*MessageWithRecipient{{Message: unregistered, Recipient: sessionID}}, nil
case messages.MessageTypeError:
wErr := msg.(*messages.Error)
if wErr.MessageType() != messages.MessageTypeInvocation {
Expand All @@ -239,7 +296,7 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message

wErr = messages.NewError(messages.MessageTypeCall, pending.RequestID, wErr.Details(), wErr.URI(),
wErr.Args(), wErr.KwArgs())
return &MessageWithRecipient{Message: wErr, Recipient: pending.CallerID}, nil
return []*MessageWithRecipient{{Message: wErr, Recipient: pending.CallerID}}, nil
default:
return nil, fmt.Errorf("dealer: received unexpected message of type %T", msg)
}
Expand Down
141 changes: 109 additions & 32 deletions dealer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,24 +41,26 @@ func TestDealerRegisterUnregister(t *testing.T) {

t.Run("Register", func(t *testing.T) {
register := messages.NewRegister(1, nil, "foo.bar")
msg, err := dealer.ReceiveMessage(callee.ID(), register)
msgs, err := dealer.ReceiveMessage(callee.ID(), register)
require.NoError(t, err)
require.NotNil(t, msg)
require.Equal(t, msg.Recipient, callee.ID())
require.Equal(t, messages.MessageTypeRegistered, msg.Message.Type())
require.NotNil(t, msgs)
require.Len(t, msgs, 1)
require.Equal(t, msgs[0].Recipient, callee.ID())
require.Equal(t, messages.MessageTypeRegistered, msgs[0].Message.Type())

hasProcedure := dealer.HasProcedure("foo.bar")
require.True(t, hasProcedure)
registerationID = msg.Message.(*messages.Registered).RegistrationID()
registerationID = msgs[0].Message.(*messages.Registered).RegistrationID()

t.Run("DuplicateProcedure", func(t *testing.T) {
register = messages.NewRegister(2, nil, "foo.bar")
msg, err = dealer.ReceiveMessage(callee.ID(), register)
msgs, err = dealer.ReceiveMessage(callee.ID(), register)
require.NoError(t, err)
require.NotNil(t, msg)
require.Equal(t, msg.Recipient, callee.ID())
require.Equal(t, messages.MessageTypeError, msg.Message.Type())
errMsg := msg.Message.(*messages.Error)
require.NotNil(t, msgs)
require.Len(t, msgs, 1)
require.Equal(t, msgs[0].Recipient, callee.ID())
require.Equal(t, messages.MessageTypeError, msgs[0].Message.Type())
errMsg := msgs[0].Message.(*messages.Error)
require.NotNil(t, errMsg)
require.Equal(t, errMsg.URI(), "wamp.error.procedure_already_exists")
})
Expand All @@ -76,26 +78,29 @@ func TestDealerRegisterUnregister(t *testing.T) {
require.NoError(t, err)

call := messages.NewCall(3, map[string]any{}, "foo.bar", []any{"abc"}, nil)
invWithRecipient, err := dealer.ReceiveMessage(caller.ID(), call)
msgs, err := dealer.ReceiveMessage(caller.ID(), call)
require.NoError(t, err)
require.NotNil(t, invWithRecipient)
require.Len(t, msgs, 1)
invWithRecipient := msgs[0]
require.Equal(t, callee.ID(), invWithRecipient.Recipient)
require.Equal(t, messages.MessageTypeInvocation, invWithRecipient.Message.Type())

// receive yield for invocation
invocation := invWithRecipient.Message.(*messages.Invocation)
yield := messages.NewYield(invocation.RequestID(), map[string]any{}, []any{"abc"}, nil)
yieldWithRecipient, err := dealer.ReceiveMessage(caller.ID(), yield)
msgs, err = dealer.ReceiveMessage(caller.ID(), yield)
require.NoError(t, err)
require.NotNil(t, yieldWithRecipient)
require.Len(t, msgs, 1)
yieldWithRecipient := msgs[0]
require.Equal(t, caller.ID(), yieldWithRecipient.Recipient)
require.Equal(t, messages.MessageTypeResult, yieldWithRecipient.Message.Type())

t.Run("NonExistingProcedure", func(t *testing.T) {
invalidCallMessage := messages.NewCall(3, map[string]any{}, "invalid", []any{"abc"}, nil)
errWithRecipient, err := dealer.ReceiveMessage(caller.ID(), invalidCallMessage)
msgs, err = dealer.ReceiveMessage(caller.ID(), invalidCallMessage)
require.NoError(t, err)
require.NotNil(t, errWithRecipient)
require.Len(t, msgs, 1)
errWithRecipient := msgs[0]
require.Equal(t, caller.ID(), errWithRecipient.Recipient)
require.Equal(t, errWithRecipient.Message.Type(), messages.MessageTypeError)
})
Expand All @@ -108,9 +113,10 @@ func TestDealerRegisterUnregister(t *testing.T) {

t.Run("Unregister", func(t *testing.T) {
unregister := messages.NewUnregister(callee.ID(), registerationID)
unregWithRecipient, err := dealer.ReceiveMessage(callee.ID(), unregister)
msgs, err := dealer.ReceiveMessage(callee.ID(), unregister)
require.NoError(t, err)
require.NotNil(t, unregWithRecipient)
require.Len(t, msgs, 1)
unregWithRecipient := msgs[0]
require.Equal(t, callee.ID(), unregWithRecipient.Recipient)
require.Equal(t, messages.MessageTypeUnregistered, unregWithRecipient.Message.Type())

Expand Down Expand Up @@ -140,27 +146,27 @@ func TestProgressiveCallResults(t *testing.T) {
require.NoError(t, err)

call := messages.NewCall(caller.ID(), map[string]any{wampproto.OptionReceiveProgress: true}, "foo.bar", []any{}, nil)
messageWithRecipient, err := dealer.ReceiveMessage(callee.ID(), call)
messagesWithRecipient, err := dealer.ReceiveMessage(callee.ID(), call)
require.NoError(t, err)
require.Equal(t, callee.ID(), messageWithRecipient.Recipient)
invocation := messageWithRecipient.Message.(*messages.Invocation)
require.Equal(t, callee.ID(), messagesWithRecipient[0].Recipient)
invocation := messagesWithRecipient[0].Message.(*messages.Invocation)
require.True(t, invocation.Details()[wampproto.OptionReceiveProgress].(bool))

for i := 0; i < 10; i++ {
yield := messages.NewYield(invocation.RequestID(), map[string]any{wampproto.OptionProgress: true}, []any{}, nil)
messageWithRecipient, err = dealer.ReceiveMessage(callee.ID(), yield)
messagesWithRecipient, err = dealer.ReceiveMessage(callee.ID(), yield)
require.NoError(t, err)
require.Equal(t, callee.ID(), messageWithRecipient.Recipient)
result := messageWithRecipient.Message.(*messages.Result)
require.Equal(t, callee.ID(), messagesWithRecipient[0].Recipient)
result := messagesWithRecipient[0].Message.(*messages.Result)
require.Equal(t, call.RequestID(), result.RequestID())
require.True(t, result.Details()[wampproto.OptionProgress].(bool))
}

yield := messages.NewYield(invocation.RequestID(), map[string]any{}, []any{}, nil)
messageWithRecipient, err = dealer.ReceiveMessage(callee.ID(), yield)
messagesWithRecipient, err = dealer.ReceiveMessage(callee.ID(), yield)
require.NoError(t, err)
require.Equal(t, callee.ID(), messageWithRecipient.Recipient)
result := messageWithRecipient.Message.(*messages.Result)
require.Equal(t, callee.ID(), messagesWithRecipient[0].Recipient)
result := messagesWithRecipient[0].Message.(*messages.Result)
require.Equal(t, call.RequestID(), result.RequestID())
progress, _ := result.Details()[wampproto.OptionReceiveProgress].(bool)
require.False(t, progress)
Expand All @@ -184,9 +190,9 @@ func TestProgressiveCallInvocations(t *testing.T) {
call := messages.NewCall(4, map[string]any{wampproto.OptionProgress: true}, "foo.bar", []any{}, nil)
messageWithRecipient, err := dealer.ReceiveMessage(callee.ID(), call)
require.NoError(t, err)
require.Equal(t, callee.ID(), messageWithRecipient.Recipient)
require.Equal(t, callee.ID(), messageWithRecipient[0].Recipient)

invMessage := messageWithRecipient.Message.(*messages.Invocation)
invMessage := messageWithRecipient[0].Message.(*messages.Invocation)
require.True(t, invMessage.Details()[wampproto.OptionProgress].(bool))

invRequestID := invMessage.RequestID()
Expand All @@ -195,17 +201,88 @@ func TestProgressiveCallInvocations(t *testing.T) {
messageWithRecipient, err = dealer.ReceiveMessage(callee.ID(), call)
require.NoError(t, err)

invMessage = messageWithRecipient.Message.(*messages.Invocation)
invMessage = messageWithRecipient[0].Message.(*messages.Invocation)
require.True(t, invMessage.Details()[wampproto.OptionProgress].(bool))
require.Equal(t, invRequestID, invMessage.RequestID())
}

finalCall := messages.NewCall(4, map[string]any{}, "foo.bar", []any{}, nil)
messageWithRecipient, err = dealer.ReceiveMessage(callee.ID(), finalCall)
require.NoError(t, err)
require.Equal(t, callee.ID(), messageWithRecipient.Recipient)
require.Equal(t, callee.ID(), messageWithRecipient[0].Recipient)

invocation := messageWithRecipient.Message.(*messages.Invocation)
invocation := messageWithRecipient[0].Message.(*messages.Invocation)
inProgress, _ := invocation.Details()[wampproto.OptionProgress].(bool)
require.False(t, inProgress)
}

func TestDealerCancelMessage(t *testing.T) {
dealer := wampproto.NewDealer()

caller := wampproto.NewSessionDetails(1, "realm", "authid", "anonymous", false)
callee := wampproto.NewSessionDetails(2, "realm", "authid", "anonymous", false)

require.NoError(t, dealer.AddSession(caller))
require.NoError(t, dealer.AddSession(callee))

procedure := "foo.bar"

Check failure on line 228 in dealer_test.go

View workflow job for this annotation

GitHub Actions / build

string `foo.bar` has 3 occurrences, make it a constant (goconst)
register := messages.NewRegister(1, nil, procedure)
msgs, err := dealer.ReceiveMessage(callee.ID(), register)
require.NoError(t, err)
require.Len(t, msgs, 1)
require.Equal(t, messages.MessageTypeRegistered, msgs[0].Message.Type())

callAndCancel := func(requestID int64, cancelMode string) []*wampproto.MessageWithRecipient {
call := messages.NewCall(requestID, nil, procedure, nil, nil)
msgs, err = dealer.ReceiveMessage(caller.ID(), call)
require.NoError(t, err)
require.Len(t, msgs, 1)
require.Equal(t, messages.MessageTypeInvocation, msgs[0].Message.Type())

cancel := messages.NewCancel(requestID, map[string]any{wampproto.OptionMode: cancelMode})
msgs, err = dealer.ReceiveMessage(caller.ID(), cancel)
require.NoError(t, err)

return msgs
}

validateErrorMessage := func(msg *wampproto.MessageWithRecipient) {
require.Equal(t, caller.ID(), msg.Recipient)
require.Equal(t, messages.MessageTypeError, msg.Message.Type())
errorMsg := msg.Message.(*messages.Error)
require.Equal(t, wampproto.ErrCanceled, errorMsg.URI())
}

validateInterruptMessage := func(msg *wampproto.MessageWithRecipient) {
require.Equal(t, callee.ID(), msg.Recipient)
require.Equal(t, messages.MessageTypeInterrupt, msg.Message.Type())
interrupt := msg.Message.(*messages.Interrupt)
require.Equal(t, wampproto.ErrCanceled, interrupt.Options()[wampproto.OptionReason])
}

t.Run("CancelModeSkip", func(t *testing.T) {
msgs = callAndCancel(1, wampproto.CancelModeSkip)
require.Len(t, msgs, 1)
validateErrorMessage(msgs[0])
})

t.Run("CancelModeKill", func(t *testing.T) {
msgs = callAndCancel(2, wampproto.CancelModeKill)
require.Len(t, msgs, 1)
validateInterruptMessage(msgs[0])
})

t.Run("CancelModeKillNoWait", func(t *testing.T) {
msgs = callAndCancel(3, wampproto.CancelModeKillNoWait)
require.Len(t, msgs, 2)
validateInterruptMessage(msgs[0])
validateErrorMessage(msgs[1])
})

t.Run("CancelInvalidInvocation", func(t *testing.T) {
cancelInvalid := messages.NewCancel(999, nil)
msgs, err = dealer.ReceiveMessage(caller.ID(), cancelInvalid)
require.Error(t, err)
require.Contains(t, err.Error(), "no pending invocation to cancel")
})
}

0 comments on commit c4ee726

Please sign in to comment.