From 65e6d5447f61382bc9d4e83b59a6c8451e4afac7 Mon Sep 17 00:00:00 2001 From: Luke Rogerson Date: Thu, 21 Mar 2024 12:28:49 +0000 Subject: [PATCH] Order updates websockets --- data/redisrepo/pub_sub.go | 53 +++++++++ data/redisrepo/pub_sub_test.go | 41 +++++++ data/redisrepo/repository.go | 18 ++- data/store/store.go | 4 + go.mod | 2 +- mocks/service.go | 10 ++ mocks/transport.go | 5 + models/events.go | 13 ++ models/order.go | 4 + service/cancel_order.go | 3 + service/cancel_orders_for_user.go | 3 + service/create_order.go | 3 + service/evm_check_pending_txs_utils.go | 2 + service/pub_sub.go | 58 +++++++++ service/pub_sub_test.go | 36 ++++++ service/service.go | 2 + service/taker.go | 3 + transport/middleware/validate_user.go | 4 +- transport/rest/handler.go | 3 + transport/websocket/user_order_handler.go | 60 ++++++++++ .../websocket/user_order_handler_test.go | 112 ++++++++++++++++++ 21 files changed, 431 insertions(+), 8 deletions(-) create mode 100644 data/redisrepo/pub_sub.go create mode 100644 data/redisrepo/pub_sub_test.go create mode 100644 models/events.go create mode 100644 service/pub_sub.go create mode 100644 service/pub_sub_test.go create mode 100644 transport/websocket/user_order_handler.go create mode 100644 transport/websocket/user_order_handler_test.go diff --git a/data/redisrepo/pub_sub.go b/data/redisrepo/pub_sub.go new file mode 100644 index 0000000..01a40b7 --- /dev/null +++ b/data/redisrepo/pub_sub.go @@ -0,0 +1,53 @@ +package redisrepo + +import ( + "context" + "fmt" + + "github.com/orbs-network/order-book/utils/logger" + "github.com/orbs-network/order-book/utils/logger/logctx" +) + +// PublishEvent publishes an event to Redis +func (r *redisRepository) PublishEvent(ctx context.Context, key string, value interface{}) error { + err := r.client.Publish(ctx, key, value).Err() + + if err != nil { + logctx.Error(ctx, "failed to publish redis event", logger.Error(err), logger.String("key", key)) + return fmt.Errorf("failed to publish redis event: %v", err) + } + + logctx.Info(ctx, "published redis event", logger.String("key", key)) + return nil +} + +// SubscribeToEvents subscribes to events on a given Redis channel +func (r *redisRepository) SubscribeToEvents(ctx context.Context, channel string) (chan []byte, error) { + logctx.Info(ctx, "subscribing to channel", logger.String("channel", channel)) + + // Subscribe to the specified channel + pubsub := r.client.Subscribe(ctx, channel) + + // Wait for confirmation that subscription is created + _, err := pubsub.Receive(ctx) + if err != nil { + logctx.Error(ctx, "error on receiving from pubsub", logger.Error(err), logger.String("channel", channel)) + return nil, fmt.Errorf("error on receiving from pubsub: %w", err) + } + + // Create a channel to pass messages to the caller + messages := make(chan []byte) + + // Listen for messages + go func() { + defer pubsub.Close() + ch := pubsub.Channel() + for msg := range ch { + messages <- []byte(msg.Payload) + } + logctx.Info(ctx, "subscription ended", logger.String("channel", channel)) + close(messages) + }() + + return messages, nil +} diff --git a/data/redisrepo/pub_sub_test.go b/data/redisrepo/pub_sub_test.go new file mode 100644 index 0000000..911f7bb --- /dev/null +++ b/data/redisrepo/pub_sub_test.go @@ -0,0 +1,41 @@ +package redisrepo + +import ( + "testing" + + "github.com/go-redis/redismock/v9" + "github.com/stretchr/testify/assert" +) + +var orderJson, _ = order.ToJson() + +func TestRedisRepository_PublishEvent(t *testing.T) { + + t.Run("should publish event", func(t *testing.T) { + db, mock := redismock.NewClientMock() + + repo := &redisRepository{ + client: db, + } + + mock.ExpectPublish(order.Id.String(), orderJson).SetVal(1) + + err := repo.PublishEvent(ctx, order.Id.String(), orderJson) + + assert.NoError(t, err) + }) + + t.Run("should return error when failed to publish event", func(t *testing.T) { + db, mock := redismock.NewClientMock() + + repo := &redisRepository{ + client: db, + } + + mock.ExpectPublish(order.Id.String(), orderJson).SetErr(assert.AnError) + + err := repo.PublishEvent(ctx, order.Id.String(), orderJson) + + assert.ErrorContains(t, err, "failed to publish redis event") + }) +} diff --git a/data/redisrepo/repository.go b/data/redisrepo/repository.go index 5ff6524..ac0ca9b 100644 --- a/data/redisrepo/repository.go +++ b/data/redisrepo/repository.go @@ -7,18 +7,26 @@ import ( ) type redisRepository struct { - client redis.Cmdable + cmdable redis.Cmdable + client *redis.Client txMap map[uint]redis.Pipeliner ixIndex uint } -func NewRedisRepository(client redis.Cmdable) (*redisRepository, error) { - if client == nil { +func NewRedisRepository(cmdable redis.Cmdable) (*redisRepository, error) { + if cmdable == nil { return nil, fmt.Errorf("redis client cannot be nil") } + + client, ok := cmdable.(*redis.Client) + if !ok { + return nil, fmt.Errorf("cmdable is not a *redis.Client") + } + txMap := make(map[uint]redis.Pipeliner) return &redisRepository{ - client: client, - txMap: txMap, + cmdable: cmdable, + client: client, + txMap: txMap, }, nil } diff --git a/data/store/store.go b/data/store/store.go index e401e4b..4159250 100644 --- a/data/store/store.go +++ b/data/store/store.go @@ -72,4 +72,8 @@ type OrderBookStore interface { StoreUserResolvedSwap(ctx context.Context, userId uuid.UUID, swap models.Swap) error // utils EnumSubKeysOf(ctx context.Context, key string) ([]string, error) + + // PubSub + PublishEvent(ctx context.Context, key string, value interface{}) error + SubscribeToEvents(ctx context.Context, channel string) (chan []byte, error) } diff --git a/go.mod b/go.mod index 2c30a9a..46b02a7 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/go-redis/redismock/v9 v9.2.0 github.com/google/uuid v1.3.1 github.com/gorilla/mux v1.8.0 + github.com/gorilla/websocket v1.4.2 github.com/redis/go-redis/v9 v9.2.1 github.com/shopspring/decimal v1.3.1 github.com/stretchr/testify v1.8.4 @@ -46,7 +47,6 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb // indirect - github.com/gorilla/websocket v1.4.2 // indirect github.com/holiman/bloomfilter/v2 v2.0.3 // indirect github.com/holiman/uint256 v1.2.3 // indirect github.com/huin/goupnp v1.3.0 // indirect diff --git a/mocks/service.go b/mocks/service.go index 64b0575..c838a23 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -27,6 +27,8 @@ type MockOrderBookStore struct { Sets map[string]map[string]struct{} // Pending swaps PendingSwaps []models.SwapTx + // PubSub + EventsChan chan []byte } func (m *MockOrderBookStore) StoreOpenOrder(ctx context.Context, order models.Order) error { @@ -188,3 +190,11 @@ func (m *MockOrderBookStore) EnumSubKeysOf(tx context.Context, key string) ([]st func (m *MockOrderBookStore) CancelPendingOrder(ctx context.Context, order models.Order) error { return m.Error } + +func (m *MockOrderBookStore) PublishEvent(ctx context.Context, key string, value interface{}) error { + return m.Error +} + +func (m *MockOrderBookStore) SubscribeToEvents(ctx context.Context, channel string) (chan []byte, error) { + return m.EventsChan, m.Error +} diff --git a/mocks/transport.go b/mocks/transport.go index fdb4dc3..5d936fc 100644 --- a/mocks/transport.go +++ b/mocks/transport.go @@ -19,6 +19,7 @@ type MockOrderBookService struct { Symbols []models.Symbol User *models.User BeginSwapRes models.BeginSwapRes + OrderEvents chan []byte } func (m *MockOrderBookService) GetUserByPublicKey(ctx context.Context, publicKey string) (*models.User, error) { @@ -73,6 +74,10 @@ func (m *MockOrderBookService) CancelOrdersForUser(ctx context.Context, userId u return ids, m.Error } +func (m *MockOrderBookService) SubscribeUserOrders(ctx context.Context, userId uuid.UUID) (chan []byte, error) { + return m.OrderEvents, m.Error +} + func (m *MockOrderBookService) GetQuote(ctx context.Context, symbol models.Symbol, side models.Side, inAmount decimal.Decimal, minOutAmount *decimal.Decimal) (models.QuoteRes, error) { return m.QuoteRes, m.Error } diff --git a/models/events.go b/models/events.go new file mode 100644 index 0000000..93ca237 --- /dev/null +++ b/models/events.go @@ -0,0 +1,13 @@ +// Pub/Sub events + +package models + +import ( + "fmt" + + "github.com/google/uuid" +) + +func CreateUserOrdersEventKey(userId uuid.UUID) string { + return fmt.Sprintf("user_orders:%s", userId) +} diff --git a/models/order.go b/models/order.go index c19589d..9f090cc 100644 --- a/models/order.go +++ b/models/order.go @@ -207,6 +207,10 @@ func (o *Order) MapToOrder(data map[string]string) error { return nil } +func (o *Order) ToJson() ([]byte, error) { + return json.Marshal(o) +} + // GetAvailableSize returns the size that is available to be filled func (o *Order) GetAvailableSize() decimal.Decimal { used := o.SizePending.Add(o.SizeFilled) diff --git a/service/cancel_order.go b/service/cancel_order.go index 35511f6..71f348a 100644 --- a/service/cancel_order.go +++ b/service/cancel_order.go @@ -102,6 +102,9 @@ func (s *Service) CancelOrder(ctx context.Context, input CancelOrderInput) (*uui }) logctx.Debug(ctx, "order cancelled", logger.String("orderId", order.Id.String()), logger.String("userId", order.UserId.String()), logger.String("size", order.Size.String()), logger.String("sizeFilled", order.SizeFilled.String()), logger.String("sizePending", order.SizePending.String())) + + s.publishOrderEvent(ctx, order) + return &order.Id, nil } diff --git a/service/cancel_orders_for_user.go b/service/cancel_orders_for_user.go index cf69ec6..4dd5c74 100644 --- a/service/cancel_orders_for_user.go +++ b/service/cancel_orders_for_user.go @@ -34,6 +34,9 @@ func (s *Service) CancelOrdersForUser(ctx context.Context, userId uuid.UUID, sym if err != nil { logctx.Error(ctx, "could not cancel order", logger.Error(err), logger.String("orderId", uid.String())) } + + s.publishOrderEvent(ctx, &order) + res = append(res, *uid) } } diff --git a/service/create_order.go b/service/create_order.go index f89d87b..f5b1a80 100644 --- a/service/create_order.go +++ b/service/create_order.go @@ -79,5 +79,8 @@ func (s *Service) createNewOrder(ctx context.Context, input CreateOrderInput, us } logctx.Debug(ctx, "new order created", logger.String("ID", order.Id.String()), logger.String("price", order.Price.String()), logger.String("size", order.Size.String())) + + s.publishOrderEvent(ctx, &order) + return order, nil } diff --git a/service/evm_check_pending_txs_utils.go b/service/evm_check_pending_txs_utils.go index b144906..792bb84 100644 --- a/service/evm_check_pending_txs_utils.go +++ b/service/evm_check_pending_txs_utils.go @@ -73,6 +73,8 @@ func (e *EvmClient) ResolveSwap(ctx context.Context, swap models.Swap, isSuccess continue } + e.publishOrderEvent(ctx, &order) + if order.IsFilled() { // add to filled orders if completely filled filledOrders = append(filledOrders, order) diff --git a/service/pub_sub.go b/service/pub_sub.go new file mode 100644 index 0000000..5787dea --- /dev/null +++ b/service/pub_sub.go @@ -0,0 +1,58 @@ +package service + +import ( + "context" + "fmt" + + "github.com/google/uuid" + "github.com/orbs-network/order-book/models" + "github.com/orbs-network/order-book/utils/logger" + "github.com/orbs-network/order-book/utils/logger/logctx" +) + +func (s *Service) SubscribeUserOrders(ctx context.Context, userId uuid.UUID) (chan []byte, error) { + logctx.Debug(ctx, "subscribing to user orders", logger.String("userId", userId.String())) + + eventKey := models.CreateUserOrdersEventKey(userId) + + channel, err := s.orderBookStore.SubscribeToEvents(ctx, fmt.Sprintf("user_orders:%s", userId)) + if err != nil { + logctx.Error(ctx, "failed to subscribe to user orders", logger.String("event", eventKey), logger.Error(err)) + return nil, fmt.Errorf("failed to subscribe to user orders: %w", err) + } + + return channel, nil +} + +func (s *EvmClient) publishOrderEvent(ctx context.Context, order *models.Order) { + key, value, err := createOrderEvent(ctx, order) + if err != nil { + return + } + + if err := s.orderBookStore.PublishEvent(ctx, key, value); err != nil { + logctx.Error(ctx, "failed to publish order event", logger.String("event", key), logger.Error(err)) + } +} + +func (s *Service) publishOrderEvent(ctx context.Context, order *models.Order) { + key, value, err := createOrderEvent(ctx, order) + if err != nil { + return + } + + if err := s.orderBookStore.PublishEvent(ctx, key, value); err != nil { + logctx.Error(ctx, "failed to publish order event", logger.String("event", key), logger.Error(err)) + } +} + +func createOrderEvent(ctx context.Context, order *models.Order) (key string, value []byte, err error) { + value, err = order.ToJson() + if err != nil { + logctx.Error(ctx, "failed to marshal order to json", logger.Error(err)) + } + + key = models.CreateUserOrdersEventKey(order.UserId) + + return key, value, err +} diff --git a/service/pub_sub_test.go b/service/pub_sub_test.go new file mode 100644 index 0000000..0e0a1ee --- /dev/null +++ b/service/pub_sub_test.go @@ -0,0 +1,36 @@ +package service_test + +import ( + "context" + "testing" + + "github.com/orbs-network/order-book/mocks" + "github.com/orbs-network/order-book/service" + "github.com/stretchr/testify/assert" +) + +func TestService_publishOrderEvent(t *testing.T) { + ctx := context.Background() + + t.Run("should subscribe to user order event updates", func(t *testing.T) { + svc, _ := service.New(&mocks.MockOrderBookStore{ + EventsChan: make(chan []byte), + }, &mocks.MockBcClient{}) + + channel, err := svc.SubscribeUserOrders(ctx, mocks.UserId) + + assert.NotNil(t, channel) + assert.NoError(t, err) + }) + + t.Run("should return error when failed to subscribe to user order event updates", func(t *testing.T) { + svc, _ := service.New(&mocks.MockOrderBookStore{ + Error: assert.AnError, + }, &mocks.MockBcClient{}) + + channel, err := svc.SubscribeUserOrders(ctx, mocks.UserId) + + assert.Nil(t, channel) + assert.Error(t, err) + }) +} diff --git a/service/service.go b/service/service.go index 7ef5b92..3e76db0 100644 --- a/service/service.go +++ b/service/service.go @@ -22,6 +22,8 @@ type OrderBookService interface { GetSymbols(ctx context.Context) ([]models.Symbol, error) GetOpenOrdersForUser(ctx context.Context, userId uuid.UUID) (orders []models.Order, totalOrders int, err error) GetFilledOrdersForUser(ctx context.Context, userId uuid.UUID) (orders []models.Order, totalOrders int, err error) + // Subscribe to order updates for a specific user + SubscribeUserOrders(ctx context.Context, userId uuid.UUID) (chan []byte, error) // taker api - INSTEAD GetQuote(ctx context.Context, symbol models.Symbol, side models.Side, inAmount decimal.Decimal, minOutAmount *decimal.Decimal) (models.QuoteRes, error) diff --git a/service/taker.go b/service/taker.go index dfa9a67..22bb080 100644 --- a/service/taker.go +++ b/service/taker.go @@ -69,6 +69,7 @@ func (s *Service) BeginSwap(ctx context.Context, data models.QuoteRes) (models.B logctx.Error(ctx, "Lock order Failed", logger.Error(err)) return models.BeginSwapRes{}, err } + s.publishOrderEvent(ctx, &res.Orders[i]) } // save @@ -128,6 +129,7 @@ func (s *Service) AbortSwap(ctx context.Context, swapId uuid.UUID) error { logctx.Error(ctx, "Unlock Failed", logger.Error(err)) return err } + s.publishOrderEvent(ctx, order) orders = append(orders, *order) } // cancelled orders @@ -215,6 +217,7 @@ func (s *Service) FillSwap(ctx context.Context, swapId uuid.UUID) error { } else { openOrders = append(openOrders, *order) } + s.publishOrderEvent(ctx, order) } } // store partial orders diff --git a/transport/middleware/validate_user.go b/transport/middleware/validate_user.go index 23aad29..630ae37 100644 --- a/transport/middleware/validate_user.go +++ b/transport/middleware/validate_user.go @@ -20,7 +20,7 @@ func ValidateUserMiddleware(getUserByApiKey GetUserByApiKeyFunc) func(http.Handl return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - key, err := bearerToken(r, "X-API-KEY") + key, err := BearerToken(r, "X-API-KEY") if err != nil { logctx.Warn(r.Context(), "incorrect API key format", logger.Error(err)) @@ -49,7 +49,7 @@ func ValidateUserMiddleware(getUserByApiKey GetUserByApiKeyFunc) func(http.Handl } } -func bearerToken(r *http.Request, header string) (string, error) { +func BearerToken(r *http.Request, header string) (string, error) { rawToken := r.Header.Get(header) pieces := strings.SplitN(rawToken, " ", 2) diff --git a/transport/rest/handler.go b/transport/rest/handler.go index e7f1125..5b96e5f 100644 --- a/transport/rest/handler.go +++ b/transport/rest/handler.go @@ -13,6 +13,7 @@ import ( "github.com/orbs-network/order-book/service" "github.com/orbs-network/order-book/transport/middleware" "github.com/orbs-network/order-book/transport/restutils" + "github.com/orbs-network/order-book/transport/websocket" "github.com/orbs-network/order-book/utils/logger" "github.com/orbs-network/order-book/utils/logger/logctx" ) @@ -117,6 +118,8 @@ func (h *Handler) initMakerRoutes(getUserByApiKey middleware.GetUserByApiKeyFunc mmApi.HandleFunc("/order/{orderId}", h.CancelOrderByOrderId).Methods("DELETE") // Cancel all orders for a user mmApi.HandleFunc("/orders", h.CancelOrdersForUser).Methods("DELETE") + // Subscribe to order events (websocket) + mmApi.HandleFunc("/ws/orders", websocket.WebSocketOrderHandler(h.svc, getUserByApiKey)).Methods("GET") } // Liquidity Hub specific routes diff --git a/transport/websocket/user_order_handler.go b/transport/websocket/user_order_handler.go new file mode 100644 index 0000000..0608f36 --- /dev/null +++ b/transport/websocket/user_order_handler.go @@ -0,0 +1,60 @@ +package websocket + +import ( + "net/http" + + "github.com/gorilla/websocket" + "github.com/orbs-network/order-book/service" + "github.com/orbs-network/order-book/transport/middleware" + "github.com/orbs-network/order-book/utils/logger" + "github.com/orbs-network/order-book/utils/logger/logctx" +) + +// WebSocketOrderHandler returns a handler that upgrades the connection to WebSocket and subscribes to order updates for a particular user +// The user is authenticated using the API key in the request +func WebSocketOrderHandler(orderSvc service.OrderBookService, getUserByApiKey middleware.GetUserByApiKeyFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Extract API key from query parameters + apiKey, err := middleware.BearerToken(r, "X-API-KEY") + if err != nil { + logctx.Warn(r.Context(), "incorrect API key format", logger.Error(err)) + http.Error(w, "Invalid API key (ensure the format is 'Bearer YOUR-API-KEY')", http.StatusBadRequest) + return + } + + // Authenticate user + user, err := getUserByApiKey(r.Context(), apiKey) + if err != nil { + logctx.Warn(r.Context(), "user not found by api key") + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // Upgrade to WebSocket + upgrader := websocket.Upgrader{} + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + logctx.Error(r.Context(), "error upgrading to websocket", logger.Error(err)) + http.Error(w, "Error subscribing to orders", http.StatusInternalServerError) + return + } + defer conn.Close() + + // Subscribe to that user's order updates + messageChan, err := orderSvc.SubscribeUserOrders(r.Context(), user.Id) + if err != nil { + logctx.Error(r.Context(), "error subscribing to user orders", logger.Error(err)) + http.Error(w, "Error subscribing to orders", http.StatusInternalServerError) + return + } + + // Read messages from the channel and send to WebSocket + for msg := range messageChan { + if err := conn.WriteMessage(websocket.TextMessage, msg); err != nil { + logctx.Error(r.Context(), "error writing to websocket", logger.Error(err)) + break + } + + } + } +} diff --git a/transport/websocket/user_order_handler_test.go b/transport/websocket/user_order_handler_test.go new file mode 100644 index 0000000..d1fd301 --- /dev/null +++ b/transport/websocket/user_order_handler_test.go @@ -0,0 +1,112 @@ +package websocket + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + + "github.com/orbs-network/order-book/mocks" + "github.com/orbs-network/order-book/models" +) + +func TestWebSocketOrderHandler(t *testing.T) { + + t.Run("Test successful websocket lifecycle", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + WebSocketOrderHandler(&mocks.MockOrderBookService{}, mockGetUserByApiKey)(w, r) + })) + defer server.Close() + + wsURL := "ws" + server.URL[len("http"):] + + dialer := websocket.Dialer{} + headers := http.Header{} + headers.Set("X-API-KEY", "Bearer mock-api-key") + + conn, _, err := dialer.Dial(wsURL, headers) + assert.NoError(t, err) + defer conn.Close() + + assert.NotNil(t, conn, "The WebSocket connection should be established") + + err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + assert.NoError(t, err, "Should be able to close the WebSocket connection") + + conn.Close() + }) + + t.Run("Test invalid API key", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + WebSocketOrderHandler(&mocks.MockOrderBookService{}, mockGetUserByApiKeyError)(w, r) + })) + defer server.Close() + + wsURL := "ws" + server.URL[len("http"):] + + dialer := websocket.Dialer{} + headers := http.Header{} + headers.Set("X-API-KEY", "Bearer invalid-api-key") + + conn, _, err := dialer.Dial(wsURL, headers) + assert.Error(t, err) + assert.Nil(t, conn, "The WebSocket connection should not be established") + }) + + t.Run("Test error upgrading to WebSocket", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + WebSocketOrderHandler(&mocks.MockOrderBookService{}, mockGetUserByApiKey)(w, r) + })) + defer server.Close() + + wsURL := "ws" + server.URL[len("http"):] + + dialer := websocket.Dialer{} + headers := http.Header{} + headers.Set("X-API-KEY", "Bearer mock-api-key") + + conn, _, err := dialer.Dial(wsURL, headers) + assert.NoError(t, err) + defer conn.Close() + + conn.Close() + + _, _, err = conn.ReadMessage() + assert.Error(t, err, "Expect an error after the connection is closed") + }) + + t.Run("Test error subscribing to user orders", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + WebSocketOrderHandler(&mocks.MockOrderBookService{Error: assert.AnError}, mockGetUserByApiKey)(w, r) + })) + defer server.Close() + + wsURL := "ws" + server.URL[len("http"):] + + dialer := websocket.Dialer{} + headers := http.Header{} + headers.Set("X-API-KEY", "Bearer mock-api-key") + + conn, _, err := dialer.Dial(wsURL, headers) + assert.NoError(t, err) + defer conn.Close() + + _, _, err = conn.ReadMessage() + assert.Error(t, err, "Expect an error after the connection is closed") + }) +} + +var mockGetUserByApiKey = func(ctx context.Context, apiKey string) (*models.User, error) { + return &models.User{ + Id: uuid.MustParse("00000000-0000-0000-0000-000000000007"), + ApiKey: "mock-api-key", + }, nil +} + +var mockGetUserByApiKeyError = func(ctx context.Context, apiKey string) (*models.User, error) { + return nil, assert.AnError +}