Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement call canceling in dealer #75

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 66 additions & 9 deletions dealer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,18 @@ import (
"sync"

"github.com/xconnio/wampproto-go/messages"
"github.com/xconnio/wampproto-go/util"
)

const (
OptionReceiveProgress = "receive_progress"
OptionProgress = "progress"
OptionMode = "mode"
OptionReason = "reason"

CancelModeKill = "kill"
CancelModeKillNoWait = "killnowait"
CancelModeSkip = "skip"
)

const (
Expand Down Expand Up @@ -106,7 +113,7 @@ func (d *Dealer) HasProcedure(procedure string) bool {
return exists && len(reg.Registrants) > 0
}

func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*MessageWithRecipient, error) {
func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) ([]*MessageWithRecipient, error) {
d.Lock()
defer d.Unlock()

Expand All @@ -117,7 +124,7 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message
if !exists || len(regs.Registrants) == 0 {
callErr := messages.NewError(messages.MessageTypeCall, call.RequestID(), map[string]any{},
"wamp.error.no_such_procedure", nil, nil)
return &MessageWithRecipient{Message: callErr, Recipient: sessionID}, nil
return []*MessageWithRecipient{{Message: callErr, Recipient: sessionID}}, nil
}

var callee int64
Expand Down Expand Up @@ -156,7 +163,55 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message
invocation = messages.NewInvocation(invocationID, regs.ID, details, call.Args(), call.KwArgs())
}

return &MessageWithRecipient{Message: invocation, Recipient: callee}, nil
return []*MessageWithRecipient{{Message: invocation, Recipient: callee}}, nil
case messages.MessageTypeCancel:
cancel := msg.(*messages.Cancel)
mode := util.ToString(cancel.Options()[OptionMode])
switch mode {
case CancelModeSkip, CancelModeKill, CancelModeKillNoWait:
case "":
mode = CancelModeKillNoWait
default:
errMsg := messages.NewError(messages.MessageTypeCancel, cancel.RequestID(), nil, ErrInvalidArgument,
[]any{fmt.Sprintf("invalid cancel mode: %s", mode)}, nil)
return []*MessageWithRecipient{{Message: errMsg, Recipient: sessionID}}, nil
}

callMap := CallMap{CallerID: sessionID, CallID: cancel.RequestID()}
invocationID, ok := d.invocationIDbyCall[callMap]
if !ok {
return nil, fmt.Errorf("no pending invocation to cancel")
}

pendingCall, ok := d.pendingCalls[invocationID]
if !ok {
return nil, fmt.Errorf("no pending call to cancel")
}

if sessionID != pendingCall.CallerID {
return nil, fmt.Errorf("cancel received from the session who doesn't own the call")
}

var messagesWithRecipient []*MessageWithRecipient
if mode != CancelModeSkip {
messagesWithRecipient = append(messagesWithRecipient, &MessageWithRecipient{
Message: messages.NewInterrupt(invocationID, map[string]any{OptionReason: ErrCanceled, OptionMode: mode}),
Recipient: pendingCall.CalleeID,
})
}

if mode != CancelModeKill {
errMessage := messages.NewError(messages.MessageTypeCall, cancel.RequestID(), nil, ErrCanceled, nil, nil)
messagesWithRecipient = append(messagesWithRecipient, &MessageWithRecipient{
Message: errMessage,
Recipient: pendingCall.CallerID,
})

delete(d.pendingCalls, invocationID)
delete(d.invocationIDbyCall, callMap)
}

return messagesWithRecipient, nil
case messages.MessageTypeYield:
yield := msg.(*messages.Yield)
pending, exists := d.pendingCalls[yield.RequestID()]
Expand All @@ -168,7 +223,9 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message
var details map[string]any
if pending.ReceiveProgress && progress {
details = map[string]any{OptionProgress: progress}
} else {
}

if !pending.ReceiveProgress {
delete(d.pendingCalls, yield.RequestID())
}

Expand All @@ -179,7 +236,7 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message
result = messages.NewResult(pending.RequestID, details, yield.Args(), yield.KwArgs())
}

return &MessageWithRecipient{Message: result, Recipient: pending.CallerID}, nil
return []*MessageWithRecipient{{Message: result, Recipient: pending.CallerID}}, nil
case messages.MessageTypeRegister:
register := msg.(*messages.Register)
_, exists := d.registrationsBySession[sessionID]
Expand All @@ -192,7 +249,7 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message
// TODO: implement shared registrations
err := messages.NewError(messages.MessageTypeRegister, register.RequestID(), map[string]any{},
"wamp.error.procedure_already_exists", nil, nil)
return &MessageWithRecipient{Message: err, Recipient: sessionID}, nil
return []*MessageWithRecipient{{Message: err, Recipient: sessionID}}, nil
} else {
registration = &Registration{
ID: d.idGen.NextID(),
Expand All @@ -205,7 +262,7 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message
d.registrationsBySession[sessionID][registration.ID] = registration

registered := messages.NewRegistered(register.RequestID(), registration.ID)
return &MessageWithRecipient{Message: registered, Recipient: sessionID}, nil
return []*MessageWithRecipient{{Message: registered, Recipient: sessionID}}, nil
case messages.MessageTypeUnregister:
unregister := msg.(*messages.Unregister)
registrations, exists := d.registrationsBySession[sessionID]
Expand All @@ -223,7 +280,7 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message
}

unregistered := messages.NewUnregistered(unregister.RequestID())
return &MessageWithRecipient{Message: unregistered, Recipient: sessionID}, nil
return []*MessageWithRecipient{{Message: unregistered, Recipient: sessionID}}, nil
case messages.MessageTypeError:
wErr := msg.(*messages.Error)
if wErr.MessageType() != messages.MessageTypeInvocation {
Expand All @@ -239,7 +296,7 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message

wErr = messages.NewError(messages.MessageTypeCall, pending.RequestID, wErr.Details(), wErr.URI(),
wErr.Args(), wErr.KwArgs())
return &MessageWithRecipient{Message: wErr, Recipient: pending.CallerID}, nil
return []*MessageWithRecipient{{Message: wErr, Recipient: pending.CallerID}}, nil
default:
return nil, fmt.Errorf("dealer: received unexpected message of type %T", msg)
}
Expand Down
141 changes: 109 additions & 32 deletions dealer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,24 +41,26 @@ func TestDealerRegisterUnregister(t *testing.T) {

t.Run("Register", func(t *testing.T) {
register := messages.NewRegister(1, nil, "foo.bar")
msg, err := dealer.ReceiveMessage(callee.ID(), register)
msgs, err := dealer.ReceiveMessage(callee.ID(), register)
require.NoError(t, err)
require.NotNil(t, msg)
require.Equal(t, msg.Recipient, callee.ID())
require.Equal(t, messages.MessageTypeRegistered, msg.Message.Type())
require.NotNil(t, msgs)
require.Len(t, msgs, 1)
require.Equal(t, msgs[0].Recipient, callee.ID())
require.Equal(t, messages.MessageTypeRegistered, msgs[0].Message.Type())

hasProcedure := dealer.HasProcedure("foo.bar")
require.True(t, hasProcedure)
registerationID = msg.Message.(*messages.Registered).RegistrationID()
registerationID = msgs[0].Message.(*messages.Registered).RegistrationID()

t.Run("DuplicateProcedure", func(t *testing.T) {
register = messages.NewRegister(2, nil, "foo.bar")
msg, err = dealer.ReceiveMessage(callee.ID(), register)
msgs, err = dealer.ReceiveMessage(callee.ID(), register)
require.NoError(t, err)
require.NotNil(t, msg)
require.Equal(t, msg.Recipient, callee.ID())
require.Equal(t, messages.MessageTypeError, msg.Message.Type())
errMsg := msg.Message.(*messages.Error)
require.NotNil(t, msgs)
require.Len(t, msgs, 1)
require.Equal(t, msgs[0].Recipient, callee.ID())
require.Equal(t, messages.MessageTypeError, msgs[0].Message.Type())
errMsg := msgs[0].Message.(*messages.Error)
require.NotNil(t, errMsg)
require.Equal(t, errMsg.URI(), "wamp.error.procedure_already_exists")
})
Expand All @@ -76,26 +78,29 @@ func TestDealerRegisterUnregister(t *testing.T) {
require.NoError(t, err)

call := messages.NewCall(3, map[string]any{}, "foo.bar", []any{"abc"}, nil)
invWithRecipient, err := dealer.ReceiveMessage(caller.ID(), call)
msgs, err := dealer.ReceiveMessage(caller.ID(), call)
require.NoError(t, err)
require.NotNil(t, invWithRecipient)
require.Len(t, msgs, 1)
invWithRecipient := msgs[0]
require.Equal(t, callee.ID(), invWithRecipient.Recipient)
require.Equal(t, messages.MessageTypeInvocation, invWithRecipient.Message.Type())

// receive yield for invocation
invocation := invWithRecipient.Message.(*messages.Invocation)
yield := messages.NewYield(invocation.RequestID(), map[string]any{}, []any{"abc"}, nil)
yieldWithRecipient, err := dealer.ReceiveMessage(caller.ID(), yield)
msgs, err = dealer.ReceiveMessage(caller.ID(), yield)
require.NoError(t, err)
require.NotNil(t, yieldWithRecipient)
require.Len(t, msgs, 1)
yieldWithRecipient := msgs[0]
require.Equal(t, caller.ID(), yieldWithRecipient.Recipient)
require.Equal(t, messages.MessageTypeResult, yieldWithRecipient.Message.Type())

t.Run("NonExistingProcedure", func(t *testing.T) {
invalidCallMessage := messages.NewCall(3, map[string]any{}, "invalid", []any{"abc"}, nil)
errWithRecipient, err := dealer.ReceiveMessage(caller.ID(), invalidCallMessage)
msgs, err = dealer.ReceiveMessage(caller.ID(), invalidCallMessage)
require.NoError(t, err)
require.NotNil(t, errWithRecipient)
require.Len(t, msgs, 1)
errWithRecipient := msgs[0]
require.Equal(t, caller.ID(), errWithRecipient.Recipient)
require.Equal(t, errWithRecipient.Message.Type(), messages.MessageTypeError)
})
Expand All @@ -108,9 +113,10 @@ func TestDealerRegisterUnregister(t *testing.T) {

t.Run("Unregister", func(t *testing.T) {
unregister := messages.NewUnregister(callee.ID(), registerationID)
unregWithRecipient, err := dealer.ReceiveMessage(callee.ID(), unregister)
msgs, err := dealer.ReceiveMessage(callee.ID(), unregister)
require.NoError(t, err)
require.NotNil(t, unregWithRecipient)
require.Len(t, msgs, 1)
unregWithRecipient := msgs[0]
require.Equal(t, callee.ID(), unregWithRecipient.Recipient)
require.Equal(t, messages.MessageTypeUnregistered, unregWithRecipient.Message.Type())

Expand Down Expand Up @@ -140,27 +146,27 @@ func TestProgressiveCallResults(t *testing.T) {
require.NoError(t, err)

call := messages.NewCall(caller.ID(), map[string]any{wampproto.OptionReceiveProgress: true}, "foo.bar", []any{}, nil)
messageWithRecipient, err := dealer.ReceiveMessage(callee.ID(), call)
messagesWithRecipient, err := dealer.ReceiveMessage(callee.ID(), call)
require.NoError(t, err)
require.Equal(t, callee.ID(), messageWithRecipient.Recipient)
invocation := messageWithRecipient.Message.(*messages.Invocation)
require.Equal(t, callee.ID(), messagesWithRecipient[0].Recipient)
invocation := messagesWithRecipient[0].Message.(*messages.Invocation)
require.True(t, invocation.Details()[wampproto.OptionReceiveProgress].(bool))

for i := 0; i < 10; i++ {
yield := messages.NewYield(invocation.RequestID(), map[string]any{wampproto.OptionProgress: true}, []any{}, nil)
messageWithRecipient, err = dealer.ReceiveMessage(callee.ID(), yield)
messagesWithRecipient, err = dealer.ReceiveMessage(callee.ID(), yield)
require.NoError(t, err)
require.Equal(t, callee.ID(), messageWithRecipient.Recipient)
result := messageWithRecipient.Message.(*messages.Result)
require.Equal(t, callee.ID(), messagesWithRecipient[0].Recipient)
result := messagesWithRecipient[0].Message.(*messages.Result)
require.Equal(t, call.RequestID(), result.RequestID())
require.True(t, result.Details()[wampproto.OptionProgress].(bool))
}

yield := messages.NewYield(invocation.RequestID(), map[string]any{}, []any{}, nil)
messageWithRecipient, err = dealer.ReceiveMessage(callee.ID(), yield)
messagesWithRecipient, err = dealer.ReceiveMessage(callee.ID(), yield)
require.NoError(t, err)
require.Equal(t, callee.ID(), messageWithRecipient.Recipient)
result := messageWithRecipient.Message.(*messages.Result)
require.Equal(t, callee.ID(), messagesWithRecipient[0].Recipient)
result := messagesWithRecipient[0].Message.(*messages.Result)
require.Equal(t, call.RequestID(), result.RequestID())
progress, _ := result.Details()[wampproto.OptionReceiveProgress].(bool)
require.False(t, progress)
Expand All @@ -184,9 +190,9 @@ func TestProgressiveCallInvocations(t *testing.T) {
call := messages.NewCall(4, map[string]any{wampproto.OptionProgress: true}, "foo.bar", []any{}, nil)
messageWithRecipient, err := dealer.ReceiveMessage(callee.ID(), call)
require.NoError(t, err)
require.Equal(t, callee.ID(), messageWithRecipient.Recipient)
require.Equal(t, callee.ID(), messageWithRecipient[0].Recipient)

invMessage := messageWithRecipient.Message.(*messages.Invocation)
invMessage := messageWithRecipient[0].Message.(*messages.Invocation)
require.True(t, invMessage.Details()[wampproto.OptionProgress].(bool))

invRequestID := invMessage.RequestID()
Expand All @@ -195,17 +201,88 @@ func TestProgressiveCallInvocations(t *testing.T) {
messageWithRecipient, err = dealer.ReceiveMessage(callee.ID(), call)
require.NoError(t, err)

invMessage = messageWithRecipient.Message.(*messages.Invocation)
invMessage = messageWithRecipient[0].Message.(*messages.Invocation)
require.True(t, invMessage.Details()[wampproto.OptionProgress].(bool))
require.Equal(t, invRequestID, invMessage.RequestID())
}

finalCall := messages.NewCall(4, map[string]any{}, "foo.bar", []any{}, nil)
messageWithRecipient, err = dealer.ReceiveMessage(callee.ID(), finalCall)
require.NoError(t, err)
require.Equal(t, callee.ID(), messageWithRecipient.Recipient)
require.Equal(t, callee.ID(), messageWithRecipient[0].Recipient)

invocation := messageWithRecipient.Message.(*messages.Invocation)
invocation := messageWithRecipient[0].Message.(*messages.Invocation)
inProgress, _ := invocation.Details()[wampproto.OptionProgress].(bool)
require.False(t, inProgress)
}

func TestDealerCancelMessage(t *testing.T) {
dealer := wampproto.NewDealer()

caller := wampproto.NewSessionDetails(1, "realm", "authid", "anonymous", false)
callee := wampproto.NewSessionDetails(2, "realm", "authid", "anonymous", false)

require.NoError(t, dealer.AddSession(caller))
require.NoError(t, dealer.AddSession(callee))

const procedure = "foo.bar"
register := messages.NewRegister(1, nil, procedure)
msgs, err := dealer.ReceiveMessage(callee.ID(), register)
require.NoError(t, err)
require.Len(t, msgs, 1)
require.Equal(t, messages.MessageTypeRegistered, msgs[0].Message.Type())

callAndCancel := func(requestID int64, cancelMode string) []*wampproto.MessageWithRecipient {
call := messages.NewCall(requestID, nil, procedure, nil, nil)
msgs, err = dealer.ReceiveMessage(caller.ID(), call)
require.NoError(t, err)
require.Len(t, msgs, 1)
require.Equal(t, messages.MessageTypeInvocation, msgs[0].Message.Type())

cancel := messages.NewCancel(requestID, map[string]any{wampproto.OptionMode: cancelMode})
msgs, err = dealer.ReceiveMessage(caller.ID(), cancel)
require.NoError(t, err)

return msgs
}

validateErrorMessage := func(msg *wampproto.MessageWithRecipient) {
require.Equal(t, caller.ID(), msg.Recipient)
require.Equal(t, messages.MessageTypeError, msg.Message.Type())
errorMsg := msg.Message.(*messages.Error)
require.Equal(t, wampproto.ErrCanceled, errorMsg.URI())
}

validateInterruptMessage := func(msg *wampproto.MessageWithRecipient) {
require.Equal(t, callee.ID(), msg.Recipient)
require.Equal(t, messages.MessageTypeInterrupt, msg.Message.Type())
interrupt := msg.Message.(*messages.Interrupt)
require.Equal(t, wampproto.ErrCanceled, interrupt.Options()[wampproto.OptionReason])
}

t.Run("CancelModeSkip", func(t *testing.T) {
msgs = callAndCancel(1, wampproto.CancelModeSkip)
require.Len(t, msgs, 1)
validateErrorMessage(msgs[0])
})

t.Run("CancelModeKill", func(t *testing.T) {
msgs = callAndCancel(2, wampproto.CancelModeKill)
require.Len(t, msgs, 1)
validateInterruptMessage(msgs[0])
})

t.Run("CancelModeKillNoWait", func(t *testing.T) {
msgs = callAndCancel(3, wampproto.CancelModeKillNoWait)
require.Len(t, msgs, 2)
validateInterruptMessage(msgs[0])
validateErrorMessage(msgs[1])
})

t.Run("CancelInvalidInvocation", func(t *testing.T) {
cancelInvalid := messages.NewCancel(999, nil)
msgs, err = dealer.ReceiveMessage(caller.ID(), cancelInvalid)
require.Error(t, err)
require.Contains(t, err.Error(), "no pending invocation to cancel")
})
}
Loading