From 5ca3e1b5627b96cb4d18896bc10b8bf285f30b56 Mon Sep 17 00:00:00 2001 From: Muzzammil Shahid Date: Thu, 19 Sep 2024 16:45:51 +0500 Subject: [PATCH 1/4] Implement progressive call results --- dealer.go | 26 ++++++++++++++++++++------ idgen.go | 2 +- messages/validator.go | 4 ++-- session.go | 12 ++++++++++-- 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/dealer.go b/dealer.go index 121a1fd..9a6b07a 100644 --- a/dealer.go +++ b/dealer.go @@ -7,6 +7,11 @@ import ( "github.com/xconnio/wampproto-go/messages" ) +const ( + OptionReceiveProgress = "receive_progress" + OptionProgress = "progress" +) + type PendingInvocation struct { RequestID int64 CallerID int64 @@ -107,12 +112,14 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message callee = session break } + receiveProgress, _ := call.Options()[OptionReceiveProgress].(bool) invocationID := d.idGen.NextID() d.pendingCalls[invocationID] = &PendingInvocation{ - RequestID: call.RequestID(), - CallerID: sessionID, - CalleeID: callee, + RequestID: call.RequestID(), + CallerID: sessionID, + CalleeID: callee, + ReceiveProgress: receiveProgress, } var invocation *messages.Invocation @@ -120,7 +127,8 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message invocation = messages.NewInvocationBinary(invocationID, regs.ID, nil, call.Payload(), call.PayloadSerializer()) } else { - invocation = messages.NewInvocation(invocationID, regs.ID, nil, call.Args(), call.KwArgs()) + details := map[string]any{OptionReceiveProgress: receiveProgress} + invocation = messages.NewInvocation(invocationID, regs.ID, details, call.Args(), call.KwArgs()) } return &MessageWithRecipient{Message: invocation, Recipient: callee}, nil @@ -131,13 +139,19 @@ func (d *Dealer) ReceiveMessage(sessionID int64, msg messages.Message) (*Message return nil, fmt.Errorf("yield: not pending calls for session %d", sessionID) } - delete(d.pendingCalls, yield.RequestID()) + progress, _ := yield.Options()[OptionProgress].(bool) + var details map[string]any + if pending.ReceiveProgress && progress { + details = map[string]any{OptionProgress: progress} + } else { + delete(d.pendingCalls, yield.RequestID()) + } var result *messages.Result if yield.PayloadIsBinary() && d.sessions[pending.CallerID].StaticSerializer() { result = messages.NewResultBinary(pending.RequestID, nil, yield.Payload(), yield.PayloadSerializer()) } else { - result = messages.NewResult(pending.RequestID, nil, yield.Args(), yield.KwArgs()) + result = messages.NewResult(pending.RequestID, details, yield.Args(), yield.KwArgs()) } return &MessageWithRecipient{Message: result, Recipient: pending.CallerID}, nil diff --git a/idgen.go b/idgen.go index 5fa159e..6a697e4 100644 --- a/idgen.go +++ b/idgen.go @@ -10,7 +10,7 @@ import ( const maxID int64 = 1 << 53 func init() { - source := rand.NewSource(uint64(time.Now().UnixNano())) + source := rand.NewSource(uint64(time.Now().UnixNano())) // #nosec rand.New(source) } diff --git a/messages/validator.go b/messages/validator.go index 5a9c96f..63877da 100644 --- a/messages/validator.go +++ b/messages/validator.go @@ -292,7 +292,7 @@ func AsInt64(i interface{}) (int64, bool) { case int64: return v, true case uint64: - return int64(v), true + return int64(v), true // #nosec case uint8: return int64(v), true case int: @@ -302,7 +302,7 @@ func AsInt64(i interface{}) (int64, bool) { case int32: return int64(v), true case uint: - return int64(v), true + return int64(v), true // #nosec case uint16: return int64(v), true case uint32: diff --git a/session.go b/session.go index d59ca77..66fefd1 100644 --- a/session.go +++ b/session.go @@ -60,7 +60,10 @@ func (w *Session) SendMessage(msg messages.Message) ([]byte, error) { return data, nil case messages.MessageTypeYield: yield := msg.(*messages.Yield) - w.invocationRequests.Delete(yield.RequestID()) + progress, _ := yield.Options()[OptionProgress].(bool) + if !progress { + w.invocationRequests.Delete(yield.RequestID()) + } return data, nil case messages.MessageTypeRegister: @@ -125,11 +128,16 @@ func (w *Session) ReceiveMessage(msg messages.Message) (messages.Message, error) switch msg.Type() { case messages.MessageTypeResult: result := msg.(*messages.Result) - _, exists := w.callRequests.LoadAndDelete(result.RequestID()) + _, exists := w.callRequests.Load(result.RequestID()) if !exists { return nil, fmt.Errorf("received RESULT for invalid requestID") } + progress, _ := result.Details()[OptionProgress].(bool) + if !progress { + w.callRequests.Delete(result.RequestID()) + } + return result, nil case messages.MessageTypeRegistered: registered := msg.(*messages.Registered) From a8b84556d78b8a5d9e11ac361460ecd357256948 Mon Sep 17 00:00:00 2001 From: Muzzammil Shahid Date: Thu, 19 Sep 2024 16:45:58 +0500 Subject: [PATCH 2/4] Add callee features --- joiner.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/joiner.go b/joiner.go index a5ebf9e..7c6e1df 100644 --- a/joiner.go +++ b/joiner.go @@ -14,7 +14,10 @@ var ClientRoles = map[string]any{ //nolint:gochecknoglobals "features": map[string]any{}, }, "callee": map[string]any{ - "features": map[string]any{}, + "features": map[string]any{ + "progressive_call_results": true, + "call_canceling": true, + }, }, "publisher": map[string]any{ "features": map[string]any{}, From 0cbf66d14e39254d878a596e6b9a3e72e4503d35 Mon Sep 17 00:00:00 2001 From: Muzzammil Shahid Date: Fri, 20 Sep 2024 14:00:11 +0500 Subject: [PATCH 3/4] Add test for progressive call results --- dealer_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/dealer_test.go b/dealer_test.go index 95041ea..0dd7122 100644 --- a/dealer_test.go +++ b/dealer_test.go @@ -123,3 +123,45 @@ func TestDealerRegisterUnregister(t *testing.T) { }) }) } + +func TestProgressiveCallResults(t *testing.T) { + dealer := wampproto.NewDealer() + + callee := wampproto.NewSessionDetails(1, "realm", "authid", "anonymous", false) + caller := wampproto.NewSessionDetails(2, "realm", "authid", "anonymous", false) + + err := dealer.AddSession(callee) + require.NoError(t, err) + err = dealer.AddSession(caller) + require.NoError(t, err) + + register := messages.NewRegister(1, nil, "foo.bar") + _, err = dealer.ReceiveMessage(callee.ID(), register) + require.NoError(t, err) + + call := messages.NewCall(caller.ID(), map[string]any{wampproto.OptionReceiveProgress: true}, "foo.bar", []any{}, nil) + messageWithRecipient, err := dealer.ReceiveMessage(callee.ID(), call) + require.NoError(t, err) + require.Equal(t, callee.ID(), messageWithRecipient.Recipient) + invocation := messageWithRecipient.Message.(*messages.Invocation) + require.True(t, invocation.Details()[wampproto.OptionReceiveProgress].(bool)) + + for i := 0; i < 10; i++ { + yield := messages.NewYield(invocation.RequestID(), map[string]any{wampproto.OptionProgress: true}, []any{}, nil) + messageWithRecipient, err = dealer.ReceiveMessage(callee.ID(), yield) + require.NoError(t, err) + require.Equal(t, callee.ID(), messageWithRecipient.Recipient) + result := messageWithRecipient.Message.(*messages.Result) + require.Equal(t, call.RequestID(), result.RequestID()) + require.True(t, result.Details()[wampproto.OptionProgress].(bool)) + } + + yield := messages.NewYield(invocation.RequestID(), map[string]any{}, []any{}, nil) + messageWithRecipient, err = dealer.ReceiveMessage(callee.ID(), yield) + require.NoError(t, err) + require.Equal(t, callee.ID(), messageWithRecipient.Recipient) + result := messageWithRecipient.Message.(*messages.Result) + require.Equal(t, call.RequestID(), result.RequestID()) + progress, _ := result.Details()[wampproto.OptionReceiveProgress].(bool) + require.False(t, progress) +} From 2bd4e84d2d5e9a37cff9473b703db01a15b97f80 Mon Sep 17 00:00:00 2001 From: Muzzammil Shahid Date: Fri, 20 Sep 2024 14:11:16 +0500 Subject: [PATCH 4/4] Fix validator test to handle varying error message order --- messages/validator_test.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/messages/validator_test.go b/messages/validator_test.go index 971d42c..ea4879a 100644 --- a/messages/validator_test.go +++ b/messages/validator_test.go @@ -405,8 +405,12 @@ func TestValidateMessage(t *testing.T) { wampMsg := []any{1, "io.xconn.test", map[string]any{}, "invalidType", "extra"} _, err := messages.ValidateMessage(wampMsg, spec) - require.EqualError(t, err, `item at index 3 must be of type []any but was string -item at index 4 must be of type map[string]any but was string`) + require.Contains(t, []string{ + `item at index 3 must be of type []any but was string +item at index 4 must be of type map[string]any but was string`, + `item at index 4 must be of type map[string]any but was string +item at index 3 must be of type []any but was string`, + }, err.Error()) }) }