From f0e1dd7b79421faa83946e194507b95eb95df999 Mon Sep 17 00:00:00 2001 From: Gabriel Paradiso Date: Tue, 25 Feb 2025 11:06:21 +0100 Subject: [PATCH] fix: syncronize RegisterTrigger client and server using a first ack/err message (#1048) --- .../core/services/capability/capabilities.go | 32 ++++++++++++++++- .../services/capability/capabilities_test.go | 35 ++++++++++++++++--- 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/pkg/loop/internal/core/services/capability/capabilities.go b/pkg/loop/internal/core/services/capability/capabilities.go index f1e55c576..6ef721730 100644 --- a/pkg/loop/internal/core/services/capability/capabilities.go +++ b/pkg/loop/internal/core/services/capability/capabilities.go @@ -215,7 +215,26 @@ func (t *triggerExecutableServer) RegisterTrigger(request *capabilitiespb.Trigge } responseCh, err := t.impl.RegisterTrigger(server.Context(), req) if err != nil { - return fmt.Errorf("error registering trigger: %w", err) + // the first message sent to the client will be an ack or error message, this is done in order to syncronize the client and server and avoid + // errors to unregister not found triggers. If the error is not nil, we send an error message to the client and return the error + msg := &capabilitiespb.TriggerResponseMessage{ + Message: &capabilitiespb.TriggerResponseMessage_Response{ + Response: &capabilitiespb.TriggerResponse{ + Error: err.Error(), + }, + }, + } + return server.Send(msg) + } + + // Send ACK response to client + msg := &capabilitiespb.TriggerResponseMessage{ + Message: &capabilitiespb.TriggerResponseMessage_Ack{ + Ack: &emptypb.Empty{}, + }, + } + if err = server.Send(msg); err != nil { + return fmt.Errorf("failed sending ACK response for trigger registration %s: %w", request, err) } defer func() { @@ -270,6 +289,17 @@ func (t *triggerExecutableClient) RegisterTrigger(ctx context.Context, req capab return nil, fmt.Errorf("error registering trigger: %w", err) } + // In order to ensure the registration is successful, we need to wait for the first message from the server. + // This will be an ack or error message. If the error is not nil, we return an error. + ackMsg, err := responseStream.Recv() + if err != nil { + return nil, fmt.Errorf("failed to receive registering trigger ack message: %w", err) + } + + if ackMsg.GetAck() == nil { + return nil, errors.New(fmt.Sprintf("failed registering trigger: %s", ackMsg.GetResponse().GetError())) + } + return forwardTriggerResponsesToChannel(ctx, t.Logger, req, responseStream.Recv) } diff --git a/pkg/loop/internal/core/services/capability/capabilities_test.go b/pkg/loop/internal/core/services/capability/capabilities_test.go index 509e04ac4..6c1bd3c75 100644 --- a/pkg/loop/internal/core/services/capability/capabilities_test.go +++ b/pkg/loop/internal/core/services/capability/capabilities_test.go @@ -23,10 +23,11 @@ import ( type mockTrigger struct { capabilities.BaseCapability - callback chan capabilities.TriggerResponse - triggerActive bool - unregisterCalls chan bool - registerCalls chan bool + callback chan capabilities.TriggerResponse + triggerActive bool + unregisterCalls chan bool + registerCalls chan bool + failedToRegisterErr *string mu sync.Mutex } @@ -39,6 +40,10 @@ func (m *mockTrigger) RegisterTrigger(ctx context.Context, request capabilities. return nil, errors.New("already registered") } + if m.failedToRegisterErr != nil { + return nil, errors.New(*m.failedToRegisterErr) + } + m.triggerActive = true m.registerCalls <- true @@ -170,6 +175,28 @@ func newCapabilityPlugin(t *testing.T, capability capabilities.BaseCapability) ( return regClient.(capabilities.BaseCapability), client, server, nil } +func Test_RegisterTrigger(t *testing.T) { + testContext := tests.Context(t) + t.Run("async RegisterTrigger implementation returns error to server", func(t *testing.T) { + ctx, cancel := context.WithCancel(testContext) + defer cancel() + + errMsg := "boom" + mtr := mustMockTrigger(t) + mtr.failedToRegisterErr = &errMsg + + tr, _, _, err := newCapabilityPlugin(t, mtr) + require.NoError(t, err) + + ctr := tr.(capabilities.TriggerCapability) + + _, err = ctr.RegisterTrigger( + ctx, + capabilities.TriggerRegistrationRequest{}) + require.ErrorContains(t, err, fmt.Sprintf("failed registering trigger: %s", errMsg)) + }) +} + func Test_Capabilities(t *testing.T) { testContext := tests.Context(t)