diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index fe3325b576a..7f064211eb2 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -77,12 +77,12 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "sync" "time" "golang.org/x/time/rate" - "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/rs/zerolog" "golang.org/x/sync/errgroup" @@ -129,7 +129,7 @@ type Controller struct { // issues such as sending on a closed channel while maintaining proper cleanup. multiplexedStream chan interface{} - dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider] + dataProviders *concurrentmap.Map[SubscriptionID, dp.DataProvider] dataProviderFactory dp.DataProviderFactory dataProvidersGroup *sync.WaitGroup limiter *rate.Limiter @@ -146,7 +146,7 @@ func NewWebSocketController( config: config, conn: conn, multiplexedStream: make(chan interface{}), - dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](), + dataProviders: concurrentmap.New[SubscriptionID, dp.DataProvider](), dataProviderFactory: dataProviderFactory, dataProvidersGroup: &sync.WaitGroup{}, limiter: rate.NewLimiter(rate.Limit(config.MaxResponsesPerSecond), 1), @@ -246,7 +246,7 @@ func (c *Controller) keepalive(ctx context.Context) error { // If no messages are sent within InactivityTimeout and no active data providers exist, // the connection will be closed. func (c *Controller) writeMessages(ctx context.Context) error { - inactivityTicker := time.NewTicker(c.config.InactivityTimeout / 10) + inactivityTicker := time.NewTicker(c.inactivityTickerPeriod()) defer inactivityTicker.Stop() lastMessageSentAt := time.Now() @@ -301,6 +301,10 @@ func (c *Controller) writeMessages(ctx context.Context) error { } } +func (c *Controller) inactivityTickerPeriod() time.Duration { + return c.config.InactivityTimeout / 10 +} + // readMessages continuously reads messages from a client WebSocket connection, // validates each message, and processes it based on the message type. func (c *Controller) readMessages(ctx context.Context) error { @@ -314,7 +318,8 @@ func (c *Controller) readMessages(ctx context.Context) error { c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidMessage, "error reading message", "", "", "")) + wrapErrorMessage(http.StatusBadRequest, "error reading message", "", ""), + ) continue } @@ -323,7 +328,8 @@ func (c *Controller) readMessages(ctx context.Context) error { c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidMessage, "error parsing message", "", "", "")) + wrapErrorMessage(http.StatusBadRequest, "error parsing message", "", ""), + ) continue } } @@ -366,24 +372,34 @@ func (c *Controller) handleMessage(ctx context.Context, message json.RawMessage) } func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) { + subscriptionID, err := c.parseOrCreateSubscriptionID(msg.SubscriptionID) + if err != nil { + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(http.StatusBadRequest, "error parsing subscription id", + models.SubscribeAction, msg.SubscriptionID), + ) + return + } + // register new provider - provider, err := c.dataProviderFactory.NewDataProvider(ctx, msg.Topic, msg.Arguments, c.multiplexedStream) + provider, err := c.dataProviderFactory.NewDataProvider(ctx, subscriptionID.String(), msg.Topic, msg.Arguments, c.multiplexedStream) if err != nil { c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidArgument, "error creating data provider", msg.ClientMessageID, models.SubscribeAction, ""), + wrapErrorMessage(http.StatusBadRequest, "error creating data provider", + models.SubscribeAction, subscriptionID.String()), ) return } - c.dataProviders.Add(provider.ID(), provider) + c.dataProviders.Add(subscriptionID, provider) // write OK response to client responseOk := models.SubscribeMessageResponse{ BaseMessageResponse: models.BaseMessageResponse{ - ClientMessageID: msg.ClientMessageID, - Success: true, - SubscriptionID: provider.ID().String(), + SubscriptionID: subscriptionID.String(), }, } c.writeResponse(ctx, responseOk) @@ -396,72 +412,63 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe c.writeErrorResponse( ctx, err, - wrapErrorMessage(SubscriptionError, "subscription finished with error", "", "", ""), + wrapErrorMessage(http.StatusInternalServerError, "internal error", + models.SubscribeAction, subscriptionID.String()), ) } c.dataProvidersGroup.Done() - c.dataProviders.Remove(provider.ID()) + c.dataProviders.Remove(subscriptionID) }() } func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.UnsubscribeMessageRequest) { - id, err := uuid.Parse(msg.SubscriptionID) + subscriptionID, err := ParseClientSubscriptionID(msg.SubscriptionID) if err != nil { c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidArgument, "error parsing subscription ID", msg.ClientMessageID, models.UnsubscribeAction, msg.SubscriptionID), + wrapErrorMessage(http.StatusBadRequest, "error parsing subscription id", + models.UnsubscribeAction, msg.SubscriptionID), ) return } - provider, ok := c.dataProviders.Get(id) + provider, ok := c.dataProviders.Get(subscriptionID) if !ok { c.writeErrorResponse( ctx, err, - wrapErrorMessage(NotFound, "subscription not found", msg.ClientMessageID, models.UnsubscribeAction, msg.SubscriptionID), + wrapErrorMessage(http.StatusNotFound, "subscription not found", + models.UnsubscribeAction, subscriptionID.String()), ) return } provider.Close() - c.dataProviders.Remove(id) + c.dataProviders.Remove(subscriptionID) responseOk := models.UnsubscribeMessageResponse{ BaseMessageResponse: models.BaseMessageResponse{ - ClientMessageID: msg.ClientMessageID, - Success: true, - SubscriptionID: msg.SubscriptionID, + SubscriptionID: subscriptionID.String(), }, } c.writeResponse(ctx, responseOk) } -func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.ListSubscriptionsMessageRequest) { +func (c *Controller) handleListSubscriptions(ctx context.Context, _ models.ListSubscriptionsMessageRequest) { var subs []*models.SubscriptionEntry - err := c.dataProviders.ForEach(func(id uuid.UUID, provider dp.DataProvider) error { + _ = c.dataProviders.ForEach(func(id SubscriptionID, provider dp.DataProvider) error { subs = append(subs, &models.SubscriptionEntry{ - ID: id.String(), - Topic: provider.Topic(), + SubscriptionID: id.String(), + Topic: provider.Topic(), }) return nil }) - if err != nil { - c.writeErrorResponse( - ctx, - err, - wrapErrorMessage(NotFound, "error listing subscriptions", msg.ClientMessageID, models.ListSubscriptionsAction, ""), - ) - return - } - responseOk := models.ListSubscriptionsMessageResponse{ - Success: true, - ClientMessageID: msg.ClientMessageID, - Subscriptions: subs, + Subscriptions: subs, + Action: models.ListSubscriptionsAction, } c.writeResponse(ctx, responseOk) } @@ -472,13 +479,10 @@ func (c *Controller) shutdownConnection() { c.logger.Debug().Err(err).Msg("error closing connection") } - err = c.dataProviders.ForEach(func(_ uuid.UUID, provider dp.DataProvider) error { + _ = c.dataProviders.ForEach(func(_ SubscriptionID, provider dp.DataProvider) error { provider.Close() return nil }) - if err != nil { - c.logger.Debug().Err(err).Msg("error closing data provider") - } c.dataProviders.Clear() c.dataProvidersGroup.Wait() @@ -498,15 +502,26 @@ func (c *Controller) writeResponse(ctx context.Context, response interface{}) { } } -func wrapErrorMessage(code Code, message string, msgId string, action string, subscriptionID string) models.BaseMessageResponse { +func wrapErrorMessage(code int, message string, action string, subscriptionID string) models.BaseMessageResponse { return models.BaseMessageResponse{ - ClientMessageID: msgId, - Success: false, - SubscriptionID: subscriptionID, + SubscriptionID: subscriptionID, Error: models.ErrorMessage{ - Code: int(code), + Code: code, Message: message, - Action: action, }, + Action: action, } } + +func (c *Controller) parseOrCreateSubscriptionID(id string) (SubscriptionID, error) { + newId, err := NewSubscriptionID(id) + if err != nil { + return SubscriptionID{}, err + } + + if c.dataProviders.Has(newId) { + return SubscriptionID{}, fmt.Errorf("subscription ID is already in use: %s", newId) + } + + return newId, nil +} diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 4b3795f61b7..a8870f4a171 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "testing" "time" @@ -51,14 +52,11 @@ func (s *WsControllerSuite) TestSubscribeRequest() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(dataProvider, nil). Once() - id := uuid.New() done := make(chan struct{}) - - dataProvider.On("ID").Return(id) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() dataProvider. @@ -71,8 +69,8 @@ func (s *WsControllerSuite) TestSubscribeRequest() { request := models.SubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - ClientMessageID: uuid.New().String(), - Action: models.SubscribeAction, + SubscriptionID: "dummy-id", + Action: models.SubscribeAction, }, Topic: dp.BlocksTopic, Arguments: nil, @@ -98,9 +96,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { response, ok := msg.(models.SubscribeMessageResponse) require.True(t, ok) - require.True(t, response.Success) - require.Equal(t, request.ClientMessageID, response.ClientMessageID) - require.Equal(t, id.String(), response.SubscriptionID) + require.Equal(t, request.SubscriptionID, response.SubscriptionID) return websocket.ErrCloseSent }) @@ -113,7 +109,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { dataProvider.AssertExpectations(t) }) - s.T().Run("Parse and validate error", func(t *testing.T) { + s.T().Run("Validate message error", func(t *testing.T) { t.Parallel() conn, dataProviderFactory, _ := newControllerMocks(t) @@ -148,9 +144,9 @@ func (s *WsControllerSuite) TestSubscribeRequest() { response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) - require.False(t, response.Success) require.NotEmpty(t, response.Error) - require.Equal(t, int(InvalidMessage), response.Error.Code) + require.Equal(t, http.StatusBadRequest, response.Error.Code) + require.Equal(t, "", response.Action) return websocket.ErrCloseSent }) @@ -169,12 +165,13 @@ func (s *WsControllerSuite) TestSubscribeRequest() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil, fmt.Errorf("error creating data provider")). Once() done := make(chan struct{}) - s.expectSubscribeRequest(t, conn) + subscriptionID := "dummy-id" + s.expectSubscribeRequest(t, conn, subscriptionID) conn. On("WriteJSON", mock.Anything). @@ -183,9 +180,9 @@ func (s *WsControllerSuite) TestSubscribeRequest() { response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) - require.False(t, response.Success) require.NotEmpty(t, response.Error) - require.Equal(t, int(InvalidArgument), response.Error.Code) + require.Equal(t, http.StatusBadRequest, response.Error.Code) + require.Equal(t, models.SubscribeAction, response.Action) return websocket.ErrCloseSent }) @@ -204,7 +201,6 @@ func (s *WsControllerSuite) TestSubscribeRequest() { conn, dataProviderFactory, dataProvider := newControllerMocks(t) controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) - dataProvider.On("ID").Return(uuid.New()) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() dataProvider. @@ -214,13 +210,14 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Once() dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(dataProvider, nil). Once() done := make(chan struct{}) - msgID := s.expectSubscribeRequest(t, conn) - s.expectSubscribeResponse(t, conn, msgID) + subscriptionID := "dummy-id" + s.expectSubscribeRequest(t, conn, subscriptionID) + s.expectSubscribeResponse(t, conn, subscriptionID) conn. On("WriteJSON", mock.Anything). @@ -229,9 +226,9 @@ func (s *WsControllerSuite) TestSubscribeRequest() { response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) - require.False(t, response.Success) require.NotEmpty(t, response.Error) - require.Equal(t, int(SubscriptionError), response.Error.Code) + require.Equal(t, http.StatusInternalServerError, response.Error.Code) + require.Equal(t, models.SubscribeAction, response.Action) return websocket.ErrCloseSent }) @@ -254,14 +251,11 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(dataProvider, nil). Once() - id := uuid.New() done := make(chan struct{}) - - dataProvider.On("ID").Return(id) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() dataProvider. @@ -272,15 +266,15 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(nil). Once() - msgID := s.expectSubscribeRequest(t, conn) - s.expectSubscribeResponse(t, conn, msgID) + subscriptionID := "dummy-id" + s.expectSubscribeRequest(t, conn, subscriptionID) + s.expectSubscribeResponse(t, conn, subscriptionID) request := models.UnsubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - ClientMessageID: uuid.New().String(), - Action: models.UnsubscribeAction, + SubscriptionID: subscriptionID, + Action: models.UnsubscribeAction, }, - SubscriptionID: id.String(), } requestJson, err := json.Marshal(request) require.NoError(t, err) @@ -302,9 +296,8 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { response, ok := msg.(models.UnsubscribeMessageResponse) require.True(t, ok) - require.True(t, response.Success) require.Empty(t, response.Error) - require.Equal(t, request.ClientMessageID, response.ClientMessageID) + require.Equal(t, request.SubscriptionID, response.SubscriptionID) require.Equal(t, request.SubscriptionID, response.SubscriptionID) return websocket.ErrCloseSent @@ -327,14 +320,11 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(dataProvider, nil). Once() - id := uuid.New() done := make(chan struct{}) - - dataProvider.On("ID").Return(id) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() dataProvider. @@ -345,15 +335,15 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(nil). Once() - msgID := s.expectSubscribeRequest(t, conn) - s.expectSubscribeResponse(t, conn, msgID) + subscriptionID := "dummy-id" + s.expectSubscribeRequest(t, conn, subscriptionID) + s.expectSubscribeResponse(t, conn, subscriptionID) request := models.UnsubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - ClientMessageID: uuid.New().String(), - Action: models.UnsubscribeAction, + SubscriptionID: uuid.New().String() + " .42", // invalid subscription ID + Action: models.UnsubscribeAction, }, - SubscriptionID: "invalid-uuid", } requestJson, err := json.Marshal(request) require.NoError(t, err) @@ -375,10 +365,10 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) - require.False(t, response.Success) require.NotEmpty(t, response.Error) - require.Equal(t, request.ClientMessageID, response.ClientMessageID) - require.Equal(t, int(InvalidArgument), response.Error.Code) + require.Equal(t, request.SubscriptionID, response.SubscriptionID) + require.Equal(t, http.StatusBadRequest, response.Error.Code) + require.Equal(t, models.UnsubscribeAction, response.Action) return websocket.ErrCloseSent }). @@ -400,14 +390,11 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(dataProvider, nil). Once() - id := uuid.New() done := make(chan struct{}) - - dataProvider.On("ID").Return(id) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() dataProvider. @@ -418,15 +405,15 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(nil). Once() - msgID := s.expectSubscribeRequest(t, conn) - s.expectSubscribeResponse(t, conn, msgID) + subscriptionID := "dummy-id" + s.expectSubscribeRequest(t, conn, subscriptionID) + s.expectSubscribeResponse(t, conn, subscriptionID) request := models.UnsubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - ClientMessageID: uuid.New().String(), - Action: models.UnsubscribeAction, + SubscriptionID: "unknown-sub-id", + Action: models.UnsubscribeAction, }, - SubscriptionID: uuid.New().String(), } requestJson, err := json.Marshal(request) require.NoError(t, err) @@ -448,11 +435,12 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) - require.False(t, response.Success) + require.Equal(t, request.SubscriptionID, response.SubscriptionID) + require.NotEmpty(t, response.Error) + require.Equal(t, http.StatusNotFound, response.Error.Code) - require.Equal(t, request.ClientMessageID, response.ClientMessageID) - require.Equal(t, int(NotFound), response.Error.Code) + require.Equal(t, models.UnsubscribeAction, response.Action) return websocket.ErrCloseSent }). @@ -475,15 +463,13 @@ func (s *WsControllerSuite) TestListSubscriptions() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(dataProvider, nil). Once() done := make(chan struct{}) - id := uuid.New() topic := dp.BlocksTopic - dataProvider.On("ID").Return(id) dataProvider.On("Topic").Return(topic) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() @@ -495,13 +481,14 @@ func (s *WsControllerSuite) TestListSubscriptions() { Return(nil). Once() - msgID := s.expectSubscribeRequest(t, conn) - s.expectSubscribeResponse(t, conn, msgID) + subscriptionID := "dummy-id" + s.expectSubscribeRequest(t, conn, subscriptionID) + s.expectSubscribeResponse(t, conn, subscriptionID) request := models.ListSubscriptionsMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - ClientMessageID: uuid.New().String(), - Action: models.ListSubscriptionsAction, + SubscriptionID: "", + Action: models.ListSubscriptionsAction, }, } requestJson, err := json.Marshal(request) @@ -524,12 +511,10 @@ func (s *WsControllerSuite) TestListSubscriptions() { response, ok := msg.(models.ListSubscriptionsMessageResponse) require.True(t, ok) - require.True(t, response.Success) - require.Empty(t, response.Error) - require.Equal(t, request.ClientMessageID, response.ClientMessageID) require.Equal(t, 1, len(response.Subscriptions)) - require.Equal(t, id.String(), response.Subscriptions[0].ID) + require.Equal(t, subscriptionID, response.Subscriptions[0].SubscriptionID) require.Equal(t, topic, response.Subscriptions[0].Topic) + require.Equal(t, models.ListSubscriptionsAction, response.Action) return websocket.ErrCloseSent }). @@ -554,12 +539,10 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(dataProvider, nil). Once() - id := uuid.New() - dataProvider.On("ID").Return(id) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() @@ -574,8 +557,9 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Once() done := make(chan struct{}) - msgID := s.expectSubscribeRequest(t, conn) - s.expectSubscribeResponse(t, conn, msgID) + subscriptionID := "dummy-id" + s.expectSubscribeRequest(t, conn, subscriptionID) + s.expectSubscribeResponse(t, conn, subscriptionID) // Expect a valid block to be passed to WriteJSON. // If we got to this point, the controller executed all its logic properly @@ -609,12 +593,10 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(dataProvider, nil). Once() - id := uuid.New() - dataProvider.On("ID").Return(id) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() @@ -631,8 +613,9 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Once() done := make(chan struct{}) - msgID := s.expectSubscribeRequest(t, conn) - s.expectSubscribeResponse(t, conn, msgID) + subscriptionID := "dummy-id" + s.expectSubscribeRequest(t, conn, subscriptionID) + s.expectSubscribeResponse(t, conn, subscriptionID) i := 0 actualBlocks := make([]*flow.Block, len(expectedBlocks)) @@ -814,12 +797,10 @@ func (s *WsControllerSuite) TestControllerShutdown() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(dataProvider, nil). Once() - id := uuid.New() - dataProvider.On("ID").Return(id) // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() @@ -832,8 +813,9 @@ func (s *WsControllerSuite) TestControllerShutdown() { Once() done := make(chan struct{}) - msgID := s.expectSubscribeRequest(t, conn) - s.expectSubscribeResponse(t, conn, msgID) + subscriptionID := "dummy-id" + s.expectSubscribeRequest(t, conn, subscriptionID) + s.expectSubscribeResponse(t, conn, subscriptionID) conn. On("WriteJSON", mock.Anything). @@ -893,15 +875,14 @@ func (s *WsControllerSuite) TestControllerShutdown() { conn. On("ReadJSON", mock.Anything). Return(func(interface{}) error { - // waiting more than InactivityTimeout to make sure that read message routine busy and do not return - // an error before than inactivity tracker initiate shut down - <-time.After(wsConfig.InactivityTimeout) + // make sure the reader routine sleeps for more time than InactivityTimeout + inactivity ticker period. + // meanwhile, the writer routine must shut down the controller. + <-time.After(wsConfig.InactivityTimeout + controller.inactivityTickerPeriod()*2) return websocket.ErrCloseSent }). Once() controller.HandleConnection(context.Background()) - time.Sleep(wsConfig.InactivityTimeout) conn.AssertExpectations(t) }) @@ -993,11 +974,11 @@ func newControllerMocks(t *testing.T) (*connmock.WebsocketConnection, *dpmock.Da } // expectSubscribeRequest mocks the client's subscription request. -func (s *WsControllerSuite) expectSubscribeRequest(t *testing.T, conn *connmock.WebsocketConnection) string { +func (s *WsControllerSuite) expectSubscribeRequest(t *testing.T, conn *connmock.WebsocketConnection, subscriptionID string) { request := models.SubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - ClientMessageID: uuid.New().String(), - Action: models.SubscribeAction, + SubscriptionID: subscriptionID, + Action: models.SubscribeAction, }, Topic: dp.BlocksTopic, } @@ -1014,19 +995,16 @@ func (s *WsControllerSuite) expectSubscribeRequest(t *testing.T, conn *connmock. }). Return(nil). Once() - - return request.ClientMessageID } // expectSubscribeResponse mocks the subscription response sent to the client. -func (s *WsControllerSuite) expectSubscribeResponse(t *testing.T, conn *connmock.WebsocketConnection, msgId string) { +func (s *WsControllerSuite) expectSubscribeResponse(t *testing.T, conn *connmock.WebsocketConnection, subscriptionID string) { conn. On("WriteJSON", mock.Anything). Run(func(args mock.Arguments) { response, ok := args.Get(0).(models.SubscribeMessageResponse) require.True(t, ok) - require.Equal(t, msgId, response.ClientMessageID) - require.Equal(t, true, response.Success) + require.Equal(t, subscriptionID, response.SubscriptionID) }). Return(nil). Once() diff --git a/engine/access/rest/websockets/data_providers/account_statuses_provider.go b/engine/access/rest/websockets/data_providers/account_statuses_provider.go index f8b9e5dd4c5..9d799f220ee 100644 --- a/engine/access/rest/websockets/data_providers/account_statuses_provider.go +++ b/engine/access/rest/websockets/data_providers/account_statuses_provider.go @@ -41,6 +41,7 @@ func NewAccountStatusesDataProvider( ctx context.Context, logger zerolog.Logger, stateStreamApi state_stream.API, + subscriptionID string, topic string, arguments models.Arguments, send chan<- interface{}, @@ -63,6 +64,7 @@ func NewAccountStatusesDataProvider( subCtx, cancel := context.WithCancel(ctx) p.baseDataProvider = newBaseDataProvider( + subscriptionID, topic, cancel, send, diff --git a/engine/access/rest/websockets/data_providers/account_statuses_provider_test.go b/engine/access/rest/websockets/data_providers/account_statuses_provider_test.go index 157c50e7deb..7ff14ea8597 100644 --- a/engine/access/rest/websockets/data_providers/account_statuses_provider_test.go +++ b/engine/access/rest/websockets/data_providers/account_statuses_provider_test.go @@ -4,6 +4,7 @@ import ( "context" "strconv" "testing" + "time" "github.com/rs/zerolog" "github.com/stretchr/testify/mock" @@ -176,6 +177,7 @@ func (s *AccountStatusesProviderSuite) TestAccountStatusesDataProvider_InvalidAr ctx, s.log, s.api, + "dummy-id", topic, test.arguments, send, @@ -203,7 +205,7 @@ func (s *AccountStatusesProviderSuite) TestMessageIndexAccountStatusesProviderRe // Create a mock subscription and mock the channel sub := ssmock.NewSubscription(s.T()) sub.On("Channel").Return((<-chan interface{})(accountStatusesChan)) - sub.On("Err").Return(nil) + sub.On("Err").Return(nil).Once() s.api.On("SubscribeAccountStatusesFromStartBlockID", mock.Anything, mock.Anything, mock.Anything).Return(sub) @@ -217,6 +219,7 @@ func (s *AccountStatusesProviderSuite) TestMessageIndexAccountStatusesProviderRe ctx, s.log, s.api, + "dummy-id", topic, arguments, send, @@ -231,7 +234,9 @@ func (s *AccountStatusesProviderSuite) TestMessageIndexAccountStatusesProviderRe defer provider.Close() // Run the provider in a separate goroutine to simulate subscription processing + done := make(chan struct{}) go func() { + defer close(done) err = provider.Run() s.Require().NoError(err) }() @@ -254,6 +259,9 @@ func (s *AccountStatusesProviderSuite) TestMessageIndexAccountStatusesProviderRe responses = append(responses, accountStatusesRes) } + // Wait for the provider goroutine to finish + unittest.RequireCloseBefore(s.T(), done, time.Second, "provider failed to stop") + // Verifying that indices are starting from 0 s.Require().Equal(uint64(0), responses[0].MessageIndex, "Expected MessageIndex to start with 0") diff --git a/engine/access/rest/websockets/data_providers/base_provider.go b/engine/access/rest/websockets/data_providers/base_provider.go index 0ee040cd4ac..27b757dbf74 100644 --- a/engine/access/rest/websockets/data_providers/base_provider.go +++ b/engine/access/rest/websockets/data_providers/base_provider.go @@ -3,39 +3,38 @@ package data_providers import ( "context" - "github.com/google/uuid" - "github.com/onflow/flow-go/engine/access/subscription" ) // baseDataProvider holds common objects for the provider type baseDataProvider struct { - id uuid.UUID - topic string - cancel context.CancelFunc - send chan<- interface{} - subscription subscription.Subscription + subscriptionID string + topic string + cancel context.CancelFunc + send chan<- interface{} + subscription subscription.Subscription } // newBaseDataProvider creates a new instance of baseDataProvider. func newBaseDataProvider( + subscriptionID string, topic string, cancel context.CancelFunc, send chan<- interface{}, subscription subscription.Subscription, ) *baseDataProvider { return &baseDataProvider{ - id: uuid.New(), - topic: topic, - cancel: cancel, - send: send, - subscription: subscription, + subscriptionID: subscriptionID, + topic: topic, + cancel: cancel, + send: send, + subscription: subscription, } } -// ID returns the unique identifier of the data provider. -func (b *baseDataProvider) ID() uuid.UUID { - return b.id +// ID returns the subscription ID associated with current data provider +func (b *baseDataProvider) ID() string { + return b.subscriptionID } // Topic returns the topic associated with the data provider. diff --git a/engine/access/rest/websockets/data_providers/block_digests_provider.go b/engine/access/rest/websockets/data_providers/block_digests_provider.go index 7798fc6579c..12d46daf03f 100644 --- a/engine/access/rest/websockets/data_providers/block_digests_provider.go +++ b/engine/access/rest/websockets/data_providers/block_digests_provider.go @@ -28,6 +28,7 @@ func NewBlockDigestsDataProvider( ctx context.Context, logger zerolog.Logger, api access.API, + subscriptionID string, topic string, arguments models.Arguments, send chan<- interface{}, @@ -45,6 +46,7 @@ func NewBlockDigestsDataProvider( subCtx, cancel := context.WithCancel(ctx) p.baseDataProvider = newBaseDataProvider( + subscriptionID, topic, cancel, send, diff --git a/engine/access/rest/websockets/data_providers/block_digests_provider_test.go b/engine/access/rest/websockets/data_providers/block_digests_provider_test.go index 57576fe60a0..395e57dfa04 100644 --- a/engine/access/rest/websockets/data_providers/block_digests_provider_test.go +++ b/engine/access/rest/websockets/data_providers/block_digests_provider_test.go @@ -43,7 +43,7 @@ func (s *BlockDigestsProviderSuite) TestBlockDigestsDataProvider_InvalidArgument for _, test := range s.invalidArgumentsTestCases() { s.Run(test.name, func() { - provider, err := NewBlockDigestsDataProvider(ctx, s.log, s.api, topic, test.arguments, send) + provider, err := NewBlockDigestsDataProvider(ctx, s.log, s.api, "dummy-id", topic, test.arguments, send) s.Require().Nil(provider) s.Require().Error(err) s.Require().Contains(err.Error(), test.expectedErrorMsg) diff --git a/engine/access/rest/websockets/data_providers/block_headers_provider.go b/engine/access/rest/websockets/data_providers/block_headers_provider.go index 4ace250554a..d6b39d17082 100644 --- a/engine/access/rest/websockets/data_providers/block_headers_provider.go +++ b/engine/access/rest/websockets/data_providers/block_headers_provider.go @@ -29,6 +29,7 @@ func NewBlockHeadersDataProvider( ctx context.Context, logger zerolog.Logger, api access.API, + subscriptionID string, topic string, arguments models.Arguments, send chan<- interface{}, @@ -46,6 +47,7 @@ func NewBlockHeadersDataProvider( subCtx, cancel := context.WithCancel(ctx) p.baseDataProvider = newBaseDataProvider( + subscriptionID, topic, cancel, send, diff --git a/engine/access/rest/websockets/data_providers/block_headers_provider_test.go b/engine/access/rest/websockets/data_providers/block_headers_provider_test.go index 9f71ae124f2..8834d21d498 100644 --- a/engine/access/rest/websockets/data_providers/block_headers_provider_test.go +++ b/engine/access/rest/websockets/data_providers/block_headers_provider_test.go @@ -44,7 +44,7 @@ func (s *BlockHeadersProviderSuite) TestBlockHeadersDataProvider_InvalidArgument for _, test := range s.invalidArgumentsTestCases() { s.Run(test.name, func() { - provider, err := NewBlockHeadersDataProvider(ctx, s.log, s.api, topic, test.arguments, send) + provider, err := NewBlockHeadersDataProvider(ctx, s.log, s.api, "dummy-id", topic, test.arguments, send) s.Require().Nil(provider) s.Require().Error(err) s.Require().Contains(err.Error(), test.expectedErrorMsg) diff --git a/engine/access/rest/websockets/data_providers/blocks_provider.go b/engine/access/rest/websockets/data_providers/blocks_provider.go index 3de10f2cd95..00f9e6120dc 100644 --- a/engine/access/rest/websockets/data_providers/blocks_provider.go +++ b/engine/access/rest/websockets/data_providers/blocks_provider.go @@ -40,6 +40,7 @@ func NewBlocksDataProvider( ctx context.Context, logger zerolog.Logger, api access.API, + subscriptionID string, linkGenerator commonmodels.LinkGenerator, topic string, arguments models.Arguments, @@ -60,6 +61,7 @@ func NewBlocksDataProvider( subCtx, cancel := context.WithCancel(ctx) p.baseDataProvider = newBaseDataProvider( + subscriptionID, topic, cancel, send, diff --git a/engine/access/rest/websockets/data_providers/blocks_provider_test.go b/engine/access/rest/websockets/data_providers/blocks_provider_test.go index e73dc99df1d..2754572dff7 100644 --- a/engine/access/rest/websockets/data_providers/blocks_provider_test.go +++ b/engine/access/rest/websockets/data_providers/blocks_provider_test.go @@ -131,7 +131,7 @@ func (s *BlocksProviderSuite) TestBlocksDataProvider_InvalidArguments() { for _, test := range s.invalidArgumentsTestCases() { s.Run(test.name, func() { - provider, err := NewBlocksDataProvider(ctx, s.log, s.api, nil, BlocksTopic, test.arguments, send) + provider, err := NewBlocksDataProvider(ctx, s.log, s.api, "dummy-id", nil, BlocksTopic, test.arguments, send) s.Require().Nil(provider) s.Require().Error(err) s.Require().Contains(err.Error(), test.expectedErrorMsg) diff --git a/engine/access/rest/websockets/data_providers/data_provider.go b/engine/access/rest/websockets/data_providers/data_provider.go index ab48ebeb9f9..acaa857ead2 100644 --- a/engine/access/rest/websockets/data_providers/data_provider.go +++ b/engine/access/rest/websockets/data_providers/data_provider.go @@ -1,14 +1,10 @@ package data_providers -import ( - "github.com/google/uuid" -) - // The DataProvider is the interface abstracts of the actual data provider used by the WebSocketCollector. -// It provides methods for retrieving the provider's unique ID, topic, and a methods to close and run the provider. +// It provides methods for retrieving the provider's unique SubscriptionID, topic, and a methods to close and run the provider. type DataProvider interface { // ID returns the unique identifier of the data provider. - ID() uuid.UUID + ID() string // Topic returns the topic associated with the data provider. Topic() string // Close terminates the data provider. diff --git a/engine/access/rest/websockets/data_providers/events_provider.go b/engine/access/rest/websockets/data_providers/events_provider.go index f412d35373a..22979ef4d16 100644 --- a/engine/access/rest/websockets/data_providers/events_provider.go +++ b/engine/access/rest/websockets/data_providers/events_provider.go @@ -40,6 +40,7 @@ func NewEventsDataProvider( ctx context.Context, logger zerolog.Logger, stateStreamApi state_stream.API, + subscriptionID string, topic string, arguments models.Arguments, send chan<- interface{}, @@ -62,6 +63,7 @@ func NewEventsDataProvider( subCtx, cancel := context.WithCancel(ctx) p.baseDataProvider = newBaseDataProvider( + subscriptionID, topic, cancel, send, @@ -103,7 +105,6 @@ func (p *EventsDataProvider) handleResponse() func(eventsResponse *backend.Event var response models.EventResponse response.Build(eventsResponse, index) - p.send <- &response return nil diff --git a/engine/access/rest/websockets/data_providers/events_provider_test.go b/engine/access/rest/websockets/data_providers/events_provider_test.go index ab9f6110820..4fbe4908ca8 100644 --- a/engine/access/rest/websockets/data_providers/events_provider_test.go +++ b/engine/access/rest/websockets/data_providers/events_provider_test.go @@ -5,6 +5,7 @@ import ( "fmt" "strconv" "testing" + "time" "github.com/rs/zerolog" "github.com/stretchr/testify/mock" @@ -202,6 +203,7 @@ func (s *EventsProviderSuite) TestEventsDataProvider_InvalidArguments() { ctx, s.log, s.api, + "dummy-id", topic, test.arguments, send, @@ -229,7 +231,7 @@ func (s *EventsProviderSuite) TestMessageIndexEventProviderResponse_HappyPath() // Create a mock subscription and mock the channel sub := ssmock.NewSubscription(s.T()) sub.On("Channel").Return((<-chan interface{})(eventChan)) - sub.On("Err").Return(nil) + sub.On("Err").Return(nil).Once() s.api.On("SubscribeEventsFromStartBlockID", mock.Anything, mock.Anything, mock.Anything).Return(sub) @@ -243,6 +245,7 @@ func (s *EventsProviderSuite) TestMessageIndexEventProviderResponse_HappyPath() ctx, s.log, s.api, + "dummy-id", topic, arguments, send, @@ -258,7 +261,9 @@ func (s *EventsProviderSuite) TestMessageIndexEventProviderResponse_HappyPath() defer provider.Close() // Run the provider in a separate goroutine to simulate subscription processing + done := make(chan struct{}) go func() { + defer close(done) err = provider.Run() s.Require().NoError(err) }() @@ -283,6 +288,9 @@ func (s *EventsProviderSuite) TestMessageIndexEventProviderResponse_HappyPath() responses = append(responses, eventRes) } + // Wait for the provider goroutine to finish + unittest.RequireCloseBefore(s.T(), done, time.Second, "provider failed to stop") + // Verifying that indices are starting from 1 s.Require().Equal(uint64(0), responses[0].MessageIndex, "Expected MessageIndex to start with 0") diff --git a/engine/access/rest/websockets/data_providers/factory.go b/engine/access/rest/websockets/data_providers/factory.go index cb70b8bcb79..02d1a1320dd 100644 --- a/engine/access/rest/websockets/data_providers/factory.go +++ b/engine/access/rest/websockets/data_providers/factory.go @@ -33,7 +33,13 @@ type DataProviderFactory interface { // and configuration parameters. // // No errors are expected during normal operations. - NewDataProvider(ctx context.Context, topic string, args models.Arguments, ch chan<- interface{}) (DataProvider, error) + NewDataProvider( + ctx context.Context, + subscriptionID string, + topic string, + args models.Arguments, + ch chan<- interface{}, + ) (DataProvider, error) } var _ DataProviderFactory = (*DataProviderFactoryImpl)(nil) @@ -91,27 +97,22 @@ func NewDataProviderFactory( // - ch: Channel to which the data provider sends data. // // No errors are expected during normal operations. -func (s *DataProviderFactoryImpl) NewDataProvider( - ctx context.Context, - topic string, - arguments models.Arguments, - ch chan<- interface{}, -) (DataProvider, error) { +func (s *DataProviderFactoryImpl) NewDataProvider(ctx context.Context, subscriptionID string, topic string, arguments models.Arguments, ch chan<- interface{}) (DataProvider, error) { switch topic { case BlocksTopic: - return NewBlocksDataProvider(ctx, s.logger, s.accessApi, s.linkGenerator, topic, arguments, ch) + return NewBlocksDataProvider(ctx, s.logger, s.accessApi, subscriptionID, s.linkGenerator, topic, arguments, ch) case BlockHeadersTopic: - return NewBlockHeadersDataProvider(ctx, s.logger, s.accessApi, topic, arguments, ch) + return NewBlockHeadersDataProvider(ctx, s.logger, s.accessApi, subscriptionID, topic, arguments, ch) case BlockDigestsTopic: - return NewBlockDigestsDataProvider(ctx, s.logger, s.accessApi, topic, arguments, ch) + return NewBlockDigestsDataProvider(ctx, s.logger, s.accessApi, subscriptionID, topic, arguments, ch) case EventsTopic: - return NewEventsDataProvider(ctx, s.logger, s.stateStreamApi, topic, arguments, ch, s.chain, s.eventFilterConfig, s.heartbeatInterval) + return NewEventsDataProvider(ctx, s.logger, s.stateStreamApi, subscriptionID, topic, arguments, ch, s.chain, s.eventFilterConfig, s.heartbeatInterval) case AccountStatusesTopic: - return NewAccountStatusesDataProvider(ctx, s.logger, s.stateStreamApi, topic, arguments, ch, s.chain, s.eventFilterConfig, s.heartbeatInterval) + return NewAccountStatusesDataProvider(ctx, s.logger, s.stateStreamApi, subscriptionID, topic, arguments, ch, s.chain, s.eventFilterConfig, s.heartbeatInterval) case TransactionStatusesTopic: - return NewTransactionStatusesDataProvider(ctx, s.logger, s.accessApi, s.linkGenerator, topic, arguments, ch) + return NewTransactionStatusesDataProvider(ctx, s.logger, s.accessApi, subscriptionID, s.linkGenerator, topic, arguments, ch) case SendAndGetTransactionStatusesTopic: - return NewSendAndGetTransactionStatusesDataProvider(ctx, s.logger, s.accessApi, s.linkGenerator, topic, arguments, ch) + return NewSendAndGetTransactionStatusesDataProvider(ctx, s.logger, s.accessApi, subscriptionID, s.linkGenerator, topic, arguments, ch) default: return nil, fmt.Errorf("unsupported topic \"%s\"", topic) } diff --git a/engine/access/rest/websockets/data_providers/factory_test.go b/engine/access/rest/websockets/data_providers/factory_test.go index 3e1de651650..c89a532fddb 100644 --- a/engine/access/rest/websockets/data_providers/factory_test.go +++ b/engine/access/rest/websockets/data_providers/factory_test.go @@ -160,7 +160,7 @@ func (s *DataProviderFactorySuite) TestSupportedTopics() { s.T().Parallel() test.setupSubscription() - provider, err := s.factory.NewDataProvider(s.ctx, test.topic, test.arguments, s.ch) + provider, err := s.factory.NewDataProvider(s.ctx, "dummy-id", test.topic, test.arguments, s.ch) s.Require().NotNil(provider, "Expected provider for topic %s", test.topic) s.Require().NoError(err, "Expected no error for topic %s", test.topic) s.Require().Equal(test.topic, provider.Topic()) @@ -182,7 +182,7 @@ func (s *DataProviderFactorySuite) TestUnsupportedTopics() { } for _, topic := range unsupportedTopics { - provider, err := s.factory.NewDataProvider(s.ctx, topic, nil, s.ch) + provider, err := s.factory.NewDataProvider(s.ctx, "dummy-id", topic, nil, s.ch) s.Require().Nil(provider, "Expected no provider for unsupported topic %s", topic) s.Require().Error(err, "Expected error for unsupported topic %s", topic) s.Require().EqualError(err, fmt.Sprintf("unsupported topic \"%s\"", topic)) diff --git a/engine/access/rest/websockets/data_providers/mock/data_provider.go b/engine/access/rest/websockets/data_providers/mock/data_provider.go index 48debb23ae3..478f1625ad5 100644 --- a/engine/access/rest/websockets/data_providers/mock/data_provider.go +++ b/engine/access/rest/websockets/data_providers/mock/data_provider.go @@ -2,10 +2,7 @@ package mock -import ( - uuid "github.com/google/uuid" - mock "github.com/stretchr/testify/mock" -) +import mock "github.com/stretchr/testify/mock" // DataProvider is an autogenerated mock type for the DataProvider type type DataProvider struct { @@ -18,20 +15,18 @@ func (_m *DataProvider) Close() { } // ID provides a mock function with given fields: -func (_m *DataProvider) ID() uuid.UUID { +func (_m *DataProvider) ID() string { ret := _m.Called() if len(ret) == 0 { panic("no return value specified for ID") } - var r0 uuid.UUID - if rf, ok := ret.Get(0).(func() uuid.UUID); ok { + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { r0 = rf() } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(uuid.UUID) - } + r0 = ret.Get(0).(string) } return r0 diff --git a/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go b/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go index af49cb4e687..c18fcc5e56a 100644 --- a/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go +++ b/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go @@ -16,9 +16,9 @@ type DataProviderFactory struct { mock.Mock } -// NewDataProvider provides a mock function with given fields: ctx, topic, args, ch -func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, topic string, args models.Arguments, ch chan<- interface{}) (data_providers.DataProvider, error) { - ret := _m.Called(ctx, topic, args, ch) +// NewDataProvider provides a mock function with given fields: ctx, subscriptionID, topic, args, ch +func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, subscriptionID string, topic string, args models.Arguments, ch chan<- interface{}) (data_providers.DataProvider, error) { + ret := _m.Called(ctx, subscriptionID, topic, args, ch) if len(ret) == 0 { panic("no return value specified for NewDataProvider") @@ -26,19 +26,19 @@ func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, topic string var r0 data_providers.DataProvider var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, models.Arguments, chan<- interface{}) (data_providers.DataProvider, error)); ok { - return rf(ctx, topic, args, ch) + if rf, ok := ret.Get(0).(func(context.Context, string, string, models.Arguments, chan<- interface{}) (data_providers.DataProvider, error)); ok { + return rf(ctx, subscriptionID, topic, args, ch) } - if rf, ok := ret.Get(0).(func(context.Context, string, models.Arguments, chan<- interface{}) data_providers.DataProvider); ok { - r0 = rf(ctx, topic, args, ch) + if rf, ok := ret.Get(0).(func(context.Context, string, string, models.Arguments, chan<- interface{}) data_providers.DataProvider); ok { + r0 = rf(ctx, subscriptionID, topic, args, ch) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(data_providers.DataProvider) } } - if rf, ok := ret.Get(1).(func(context.Context, string, models.Arguments, chan<- interface{}) error); ok { - r1 = rf(ctx, topic, args, ch) + if rf, ok := ret.Get(1).(func(context.Context, string, string, models.Arguments, chan<- interface{}) error); ok { + r1 = rf(ctx, subscriptionID, topic, args, ch) } else { r1 = ret.Error(1) } diff --git a/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider.go b/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider.go index c923be87d22..cf1e54919ec 100644 --- a/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider.go +++ b/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider.go @@ -39,6 +39,7 @@ func NewSendAndGetTransactionStatusesDataProvider( ctx context.Context, logger zerolog.Logger, api access.API, + subscriptionID string, linkGenerator commonmodels.LinkGenerator, topic string, arguments models.Arguments, @@ -59,6 +60,7 @@ func NewSendAndGetTransactionStatusesDataProvider( subCtx, cancel := context.WithCancel(ctx) p.baseDataProvider = newBaseDataProvider( + subscriptionID, topic, cancel, send, diff --git a/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider_test.go b/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider_test.go index 7fad80de987..0907f70c674 100644 --- a/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider_test.go +++ b/engine/access/rest/websockets/data_providers/send_and_get_transaction_statuses_provider_test.go @@ -135,6 +135,7 @@ func (s *SendTransactionStatusesProviderSuite) TestSendTransactionStatusesDataPr ctx, s.log, s.api, + "dummy-id", s.linkGenerator, topic, test.arguments, diff --git a/engine/access/rest/websockets/data_providers/transaction_statuses_provider.go b/engine/access/rest/websockets/data_providers/transaction_statuses_provider.go index 86779fe5784..4bdf6a53359 100644 --- a/engine/access/rest/websockets/data_providers/transaction_statuses_provider.go +++ b/engine/access/rest/websockets/data_providers/transaction_statuses_provider.go @@ -42,6 +42,7 @@ func NewTransactionStatusesDataProvider( ctx context.Context, logger zerolog.Logger, api access.API, + subscriptionID string, linkGenerator commonmodels.LinkGenerator, topic string, arguments models.Arguments, @@ -62,6 +63,7 @@ func NewTransactionStatusesDataProvider( subCtx, cancel := context.WithCancel(ctx) p.baseDataProvider = newBaseDataProvider( + subscriptionID, topic, cancel, send, diff --git a/engine/access/rest/websockets/data_providers/transaction_statuses_provider_test.go b/engine/access/rest/websockets/data_providers/transaction_statuses_provider_test.go index 72101f769f7..d28f6ae671c 100644 --- a/engine/access/rest/websockets/data_providers/transaction_statuses_provider_test.go +++ b/engine/access/rest/websockets/data_providers/transaction_statuses_provider_test.go @@ -187,6 +187,7 @@ func (s *TransactionStatusesProviderSuite) TestTransactionStatusesDataProvider_I ctx, s.log, s.api, + "dummy-id", s.linkGenerator, topic, test.arguments, @@ -281,6 +282,7 @@ func (s *TransactionStatusesProviderSuite) TestMessageIndexTransactionStatusesPr ctx, s.log, s.api, + "dummy-id", s.linkGenerator, topic, arguments, diff --git a/engine/access/rest/websockets/data_providers/utittest.go b/engine/access/rest/websockets/data_providers/unit_test.go similarity index 91% rename from engine/access/rest/websockets/data_providers/utittest.go rename to engine/access/rest/websockets/data_providers/unit_test.go index 94d9534798f..cbc75393db9 100644 --- a/engine/access/rest/websockets/data_providers/utittest.go +++ b/engine/access/rest/websockets/data_providers/unit_test.go @@ -63,7 +63,7 @@ func testHappyPath( test.setupBackend(sub) // Create the data provider instance - provider, err := factory.NewDataProvider(ctx, topic, test.arguments, send) + provider, err := factory.NewDataProvider(ctx, "dummy-id", topic, test.arguments, send) require.NotNil(t, provider) require.NoError(t, err) @@ -72,7 +72,9 @@ func testHappyPath( defer provider.Close() // Run the provider in a separate goroutine + done := make(chan struct{}) go func() { + defer close(done) err = provider.Run() require.NoError(t, err) }() @@ -83,6 +85,9 @@ func testHappyPath( sendData(dataChan) }() + // Wait for the provider goroutine to finish + unittest.RequireCloseBefore(t, done, time.Second, "provider failed to stop") + // Collect responses for i, expected := range test.expectedResponses { unittest.RequireReturnsBefore(t, func() { diff --git a/engine/access/rest/websockets/error_codes.go b/engine/access/rest/websockets/error_codes.go deleted file mode 100644 index fd206bed0b3..00000000000 --- a/engine/access/rest/websockets/error_codes.go +++ /dev/null @@ -1,10 +0,0 @@ -package websockets - -type Code int - -const ( - InvalidMessage Code = iota - InvalidArgument - NotFound - SubscriptionError -) diff --git a/engine/access/rest/websockets/models/account_models.go b/engine/access/rest/websockets/models/account_models.go index 7ad7243a96c..bad5e155721 100644 --- a/engine/access/rest/websockets/models/account_models.go +++ b/engine/access/rest/websockets/models/account_models.go @@ -2,7 +2,7 @@ package models // AccountStatusesResponse is the response message for 'events' topic. type AccountStatusesResponse struct { - BlockID string `json:"blockID"` + BlockID string `json:"block_id"` Height string `json:"height"` AccountEvents AccountEvents `json:"account_events"` MessageIndex uint64 `json:"message_index"` diff --git a/engine/access/rest/websockets/models/base_message.go b/engine/access/rest/websockets/models/base_message.go index cdcd72eb1ed..09c10d3ef8c 100644 --- a/engine/access/rest/websockets/models/base_message.go +++ b/engine/access/rest/websockets/models/base_message.go @@ -8,14 +8,20 @@ const ( // BaseMessageRequest represents a base structure for incoming messages. type BaseMessageRequest struct { - Action string `json:"action"` // subscribe, unsubscribe or list_subscriptions - ClientMessageID string `json:"message_id"` // ClientMessageID is a uuid generated by client to identify request/response uniquely + // SubscriptionID is UUID generated by either client or server to uniquely identify subscription. + // It is empty for 'list_subscription' action + SubscriptionID string `json:"subscription_id,omitempty"` + Action string `json:"action"` // Action is an action to perform (e.g. 'subscribe' to some data) } // BaseMessageResponse represents a base structure for outgoing messages. type BaseMessageResponse struct { - SubscriptionID string `json:"subscription_id"` - ClientMessageID string `json:"message_id,omitempty"` // ClientMessageID may be empty in case we send msg by ourselves (e.g. error occurred) - Success bool `json:"success"` - Error ErrorMessage `json:"error,omitempty"` + SubscriptionID string `json:"subscription_id"` // SubscriptionID might be empty in case of error response + Error ErrorMessage `json:"error,omitempty"` // Error might be empty in case of OK response + Action string `json:"action"` +} + +type ErrorMessage struct { + Code int `json:"code"` // Code is an error code that categorizes an error + Message string `json:"message"` } diff --git a/engine/access/rest/websockets/models/error_message.go b/engine/access/rest/websockets/models/error_message.go deleted file mode 100644 index d5c0670926f..00000000000 --- a/engine/access/rest/websockets/models/error_message.go +++ /dev/null @@ -1,7 +0,0 @@ -package models - -type ErrorMessage struct { - Code int `json:"code"` - Message string `json:"message"` - Action string `json:"action,omitempty"` -} diff --git a/engine/access/rest/websockets/models/list_subscriptions.go b/engine/access/rest/websockets/models/list_subscriptions.go index 4893a34b09d..49c8edf5b96 100644 --- a/engine/access/rest/websockets/models/list_subscriptions.go +++ b/engine/access/rest/websockets/models/list_subscriptions.go @@ -8,8 +8,7 @@ type ListSubscriptionsMessageRequest struct { // ListSubscriptionsMessageResponse is the structure used to respond to list_subscriptions requests. // It contains a list of active subscriptions for the current WebSocket connection. type ListSubscriptionsMessageResponse struct { - ClientMessageID string `json:"message_id"` - Success bool `json:"success"` - Error ErrorMessage `json:"error,omitempty"` - Subscriptions []*SubscriptionEntry `json:"subscriptions,omitempty"` + // Subscription list might be empty in case of no active subscriptions + Subscriptions []*SubscriptionEntry `json:"subscriptions"` + Action string `json:"action"` } diff --git a/engine/access/rest/websockets/models/subscribe_message.go b/engine/access/rest/websockets/models/subscribe_message.go index 532e4c6a987..b4bd9e871da 100644 --- a/engine/access/rest/websockets/models/subscribe_message.go +++ b/engine/access/rest/websockets/models/subscribe_message.go @@ -6,7 +6,7 @@ type Arguments map[string]interface{} type SubscribeMessageRequest struct { BaseMessageRequest Topic string `json:"topic"` // Topic to subscribe to - Arguments Arguments `json:"arguments"` // Additional arguments for subscription + Arguments Arguments `json:"arguments"` // Arguments are the arguments for the subscribed topic } // SubscribeMessageResponse represents the response to a subscription request. diff --git a/engine/access/rest/websockets/models/subscription_entry.go b/engine/access/rest/websockets/models/subscription_entry.go index d3f2b352bb7..9a60ab1a0d9 100644 --- a/engine/access/rest/websockets/models/subscription_entry.go +++ b/engine/access/rest/websockets/models/subscription_entry.go @@ -2,6 +2,7 @@ package models // SubscriptionEntry represents an active subscription entry. type SubscriptionEntry struct { - Topic string `json:"topic,omitempty"` // Topic of the subscription - ID string `json:"id,omitempty"` // Unique subscription ID + SubscriptionID string `json:"subscription_id"` // ID is a client generated UUID for subscription + Topic string `json:"topic"` // Topic of the subscription + Arguments Arguments `json:"arguments"` } diff --git a/engine/access/rest/websockets/models/unsubscribe_message.go b/engine/access/rest/websockets/models/unsubscribe_message.go index 1402189a601..f72e6cb5c7b 100644 --- a/engine/access/rest/websockets/models/unsubscribe_message.go +++ b/engine/access/rest/websockets/models/unsubscribe_message.go @@ -2,8 +2,8 @@ package models // UnsubscribeMessageRequest represents a request to unsubscribe from a topic. type UnsubscribeMessageRequest struct { + // Note: subscription_id is mandatory for this request BaseMessageRequest - SubscriptionID string `json:"id"` } // UnsubscribeMessageResponse represents the response to an unsubscription request. diff --git a/engine/access/rest/websockets/subscription_id.go b/engine/access/rest/websockets/subscription_id.go new file mode 100644 index 00000000000..09ffa5f7d5e --- /dev/null +++ b/engine/access/rest/websockets/subscription_id.go @@ -0,0 +1,54 @@ +package websockets + +import ( + "fmt" + + "github.com/google/uuid" +) + +const maxLen = 20 + +// SubscriptionID represents a subscription identifier used in websockets. +// The ID can either be provided by the client or generated by the server. +// - If provided by the client, it must adhere to specific restrictions. +// - If generated by the server, it is created as a UUID. +type SubscriptionID struct { + id string +} + +// NewSubscriptionID creates a new SubscriptionID based on the provided input. +// - If the input `id` is empty, a new UUID is generated and returned. +// - If the input `id` is non-empty, it is validated and returned if no errors. +func NewSubscriptionID(id string) (SubscriptionID, error) { + if len(id) == 0 { + return SubscriptionID{ + id: uuid.New().String(), + }, nil + } + + newID, err := ParseClientSubscriptionID(id) + if err != nil { + return SubscriptionID{}, err + } + + return newID, nil +} + +func ParseClientSubscriptionID(id string) (SubscriptionID, error) { + if len(id) == 0 { + return SubscriptionID{}, fmt.Errorf("subscription ID provided by the client must not be empty") + } + + if len(id) > maxLen { + return SubscriptionID{}, fmt.Errorf("subscription ID provided by the client must not exceed %d characters", maxLen) + } + + return SubscriptionID{ + id: id, + }, nil +} + +// String returns the string representation of the SubscriptionID. +func (id SubscriptionID) String() string { + return id.id +} diff --git a/engine/access/rest/websockets/subscription_id_test.go b/engine/access/rest/websockets/subscription_id_test.go new file mode 100644 index 00000000000..bbe5c3f7f9a --- /dev/null +++ b/engine/access/rest/websockets/subscription_id_test.go @@ -0,0 +1,60 @@ +package websockets + +import ( + "fmt" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestNewSubscriptionID(t *testing.T) { + t.Run("should generate new ID when input ID is empty", func(t *testing.T) { + subscriptionID, err := NewSubscriptionID("") + + assert.NoError(t, err) + assert.NotEmpty(t, subscriptionID.id) + assert.NoError(t, uuid.Validate(subscriptionID.id), "Generated ID should be a valid UUID") + }) + + t.Run("should return valid SubscriptionID when input ID is valid", func(t *testing.T) { + validID := "subscription/blocks" + subscriptionID, err := NewSubscriptionID(validID) + + assert.NoError(t, err) + assert.Equal(t, validID, subscriptionID.id) + }) + + t.Run("should return an error for invalid input in ParseClientSubscriptionID", func(t *testing.T) { + longID := fmt.Sprintf("%s%s", "id-", make([]byte, maxLen+1)) + _, err := NewSubscriptionID(longID) + + assert.Error(t, err) + assert.EqualError(t, err, fmt.Sprintf("subscription ID provided by the client must not exceed %d characters", maxLen)) + }) +} + +func TestParseClientSubscriptionID(t *testing.T) { + t.Run("should return error if input ID is empty", func(t *testing.T) { + _, err := ParseClientSubscriptionID("") + + assert.Error(t, err) + assert.EqualError(t, err, "subscription ID provided by the client must not be empty") + }) + + t.Run("should return error if input ID exceeds max length", func(t *testing.T) { + longID := fmt.Sprintf("%s%s", "id-", make([]byte, maxLen+1)) + _, err := ParseClientSubscriptionID(longID) + + assert.Error(t, err) + assert.EqualError(t, err, fmt.Sprintf("subscription ID provided by the client must not exceed %d characters", maxLen)) + }) + + t.Run("should return valid SubscriptionID for valid input", func(t *testing.T) { + validID := "subscription/blocks" + subscription, err := ParseClientSubscriptionID(validID) + + assert.NoError(t, err) + assert.Equal(t, validID, subscription.id) + }) +}