diff --git a/internal/http/checkin/checkin.go b/internal/http/checkin/checkin.go index 983ca4dc..e33b02c9 100644 --- a/internal/http/checkin/checkin.go +++ b/internal/http/checkin/checkin.go @@ -3,8 +3,8 @@ package checkin import ( "context" "net/http" - "time" + "github.com/StevenWeathers/thunderdome-planning-poker/internal/wshub" "github.com/StevenWeathers/thunderdome-planning-poker/thunderdome" "github.com/uptrace/opentelemetry-go-extra/otelzap" ) @@ -26,18 +26,6 @@ type Config struct { WebsocketSubdomain string } -func (c *Config) WriteWait() time.Duration { - return time.Duration(c.WriteWaitSec) * time.Second -} - -func (c *Config) PingPeriod() time.Duration { - return time.Duration(c.PingPeriodSec) * time.Second -} - -func (c *Config) PongWait() time.Duration { - return time.Duration(c.PongWaitSec) * time.Second -} - type CheckinDataSvc interface { CheckinList(ctx context.Context, TeamId string, Date string, TimeZone string) ([]*thunderdome.TeamCheckin, error) CheckinCreate(ctx context.Context, TeamId string, UserId string, Yesterday string, Today string, Blockers string, Discuss string, GoalsMet bool) error @@ -68,11 +56,11 @@ type Service struct { logger *otelzap.Logger validateSessionCookie func(w http.ResponseWriter, r *http.Request) (string, error) validateUserCookie func(w http.ResponseWriter, r *http.Request) (string, error) - eventHandlers map[string]func(context.Context, string, string, string) ([]byte, error, bool) UserService UserDataSvc AuthService AuthDataSvc CheckinService CheckinDataSvc TeamService TeamDataSvc + hub *wshub.Hub } // New returns a new retro with websocket hub/client and event handlers @@ -95,16 +83,26 @@ func New( TeamService: teamService, } - c.eventHandlers = map[string]func(context.Context, string, string, string) ([]byte, error, bool){ + c.hub = wshub.NewHub(logger, wshub.Config{ + AppDomain: config.AppDomain, + WebsocketSubdomain: config.WebsocketSubdomain, + WriteWaitSec: config.WriteWaitSec, + PongWaitSec: config.PongWaitSec, + PingPeriodSec: config.PingPeriodSec, + }, map[string]func(context.Context, string, string, string) ([]byte, error, bool){ "checkin_create": c.CheckinCreate, "checkin_update": c.CheckinUpdate, "checkin_delete": c.CheckinDelete, "comment_create": c.CommentCreate, "comment_update": c.CommentUpdate, "comment_delete": c.CommentDelete, - } + }, + map[string]struct{}{}, + nil, + nil, + ) - go h.run() + go c.hub.Run() return c } diff --git a/internal/http/checkin/client.go b/internal/http/checkin/client.go index b43af302..80f50c59 100644 --- a/internal/http/checkin/client.go +++ b/internal/http/checkin/client.go @@ -2,294 +2,98 @@ package checkin import ( "context" - "encoding/json" - "fmt" "net/http" - "net/url" - "time" - "unicode/utf8" + "github.com/StevenWeathers/thunderdome-planning-poker/internal/wshub" "github.com/StevenWeathers/thunderdome-planning-poker/thunderdome" "go.uber.org/zap" - "github.com/gorilla/mux" "github.com/gorilla/websocket" ) -const ( - // Maximum message size allowed from peer. - maxMessageSize int64 = 1024 * 1024 -) - -// connection is a middleman between the websocket connection and the hub. -type connection struct { - config *Config - // The websocket connection. - ws *websocket.Conn - - // Buffered channel of outbound messages. - send chan []byte -} - -// readPump pumps messages from the websocket connection to the hub. -func (sub subscription) readPump(b *Service, ctx context.Context) { - var forceClosed bool - c := sub.conn - UserID := sub.UserID - TeamID := sub.arena - - defer func() { - h.unregister <- sub - if forceClosed { - cm := websocket.FormatCloseMessage(4002, "abandoned") - if err := c.ws.WriteControl(websocket.CloseMessage, cm, time.Now().Add(sub.config.WriteWait())); err != nil { - b.logger.Ctx(ctx).Error("abandon error", zap.Error(err), - zap.String("team_id", TeamID), zap.String("session_user_id", UserID)) - } - } - if err := c.ws.Close(); err != nil { - b.logger.Ctx(ctx).Error("close error", zap.Error(err), - zap.String("team_id", TeamID), zap.String("session_user_id", UserID)) - } - }() - c.ws.SetReadLimit(maxMessageSize) - _ = c.ws.SetReadDeadline(time.Now().Add(sub.config.PongWait())) - c.ws.SetPongHandler(func(string) error { - _ = c.ws.SetReadDeadline(time.Now().Add(sub.config.PongWait())) - return nil - }) - - for { - var badEvent bool - var eventErr error - _, msg, err := c.ws.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - b.logger.Ctx(ctx).Error("unexpected close error", zap.Error(err), - zap.String("team_id", TeamID), zap.String("session_user_id", UserID)) - } - break - } - - keyVal := make(map[string]string) - err = json.Unmarshal(msg, &keyVal) - if err != nil { - badEvent = true - b.logger.Error("unexpected retro event json error", zap.Error(err), - zap.String("team_id", TeamID), zap.String("session_user_id", UserID)) - } - - eventType := keyVal["type"] - eventValue := keyVal["value"] - - // find event handler and execute otherwise invalid event - if _, ok := b.eventHandlers[eventType]; ok && !badEvent { - msg, eventErr, forceClosed = b.eventHandlers[eventType](ctx, TeamID, UserID, eventValue) - if eventErr != nil { - badEvent = true - - // don't log forceClosed events e.g. Abandon - if !forceClosed { - b.logger.Ctx(ctx).Error("unexpected close error", zap.Error(eventErr), - zap.String("team_id", TeamID), zap.String("session_user_id", UserID), - zap.String("checkin_event_type", eventType)) - } - } - } - - if !badEvent { - m := message{msg, sub.arena} - h.broadcast <- m - } - - if forceClosed { - break - } - } -} - -// write a message with the given message type and payload. -func (c *connection) write(mt int, payload []byte) error { - _ = c.ws.SetWriteDeadline(time.Now().Add(c.config.WriteWait())) - return c.ws.WriteMessage(mt, payload) -} - -// writePump pumps messages from the hub to the websocket connection. -func (sub *subscription) writePump() { - c := sub.conn - ticker := time.NewTicker(sub.config.PingPeriod()) - defer func() { - ticker.Stop() - _ = c.ws.Close() - }() - for { - select { - case message, ok := <-c.send: - if !ok { - _ = c.write(websocket.CloseMessage, []byte{}) - return - } - if err := c.write(websocket.TextMessage, message); err != nil { - return - } - case <-ticker.C: - if err := c.write(websocket.PingMessage, nil); err != nil { - return - } - } - } -} - -func (b *Service) createWebsocketUpgrader() websocket.Upgrader { - return websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return checkOrigin(r, b.config.AppDomain, b.config.WebsocketSubdomain) - }, - } -} - -func checkOrigin(r *http.Request, appDomain string, subDomain string) bool { - origin := r.Header.Get("Origin") - if len(origin) == 0 { - return true - } - originUrl, err := url.Parse(origin) - if err != nil { - return false - } - appDomainCheck := equalASCIIFold(originUrl.Host, appDomain) - subDomainCheck := equalASCIIFold(originUrl.Host, fmt.Sprintf("%s.%s", subDomain, appDomain)) - hostCheck := equalASCIIFold(originUrl.Host, r.Host) - - return appDomainCheck || subDomainCheck || hostCheck -} - -// equalASCIIFold returns true if s is equal to t with ASCII case folding as -// defined in RFC 4790. -// Taken from Gorilla Websocket, https://github.com/gorilla/websocket/blob/main/util.go -func equalASCIIFold(s, t string) bool { - for s != "" && t != "" { - sr, size := utf8.DecodeRuneInString(s) - s = s[size:] - tr, size := utf8.DecodeRuneInString(t) - t = t[size:] - if sr == tr { - continue - } - if 'A' <= sr && sr <= 'Z' { - sr = sr + 'a' - 'A' - } - if 'A' <= tr && tr <= 'Z' { - tr = tr + 'a' - 'A' - } - if sr != tr { - return false - } - } - return s == t -} - -// handleSocketUnauthorized sets the format close message and closes the websocket -func (b *Service) handleSocketClose(ctx context.Context, ws *websocket.Conn, closeCode int, text string) { - cm := websocket.FormatCloseMessage(closeCode, text) - if err := ws.WriteMessage(websocket.CloseMessage, cm); err != nil { - b.logger.Ctx(ctx).Error("unauthorized close error", zap.Error(err)) - } - if err := ws.Close(); err != nil { - b.logger.Ctx(ctx).Error("close error", zap.Error(err)) - } -} - // ServeWs handles websocket requests from the peer. func (b *Service) ServeWs() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - teamID := vars["teamId"] + return b.hub.WebSocketHandler("teamId", func(w http.ResponseWriter, r *http.Request, c *wshub.Connection, roomID string) *wshub.AuthError { ctx := r.Context() var User *thunderdome.User - // upgrade to WebSocket connection - var upgrader = b.createWebsocketUpgrader() - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - b.logger.Ctx(ctx).Error("websocket upgrade error", zap.Error(err), - zap.String("team_id", teamID)) - return - } - c := &connection{config: &b.config, send: make(chan []byte, 256), ws: ws} - SessionId, cookieErr := b.validateSessionCookie(w, r) if cookieErr != nil && cookieErr.Error() != "COOKIE_NOT_FOUND" { - b.handleSocketClose(ctx, ws, 4001, "unauthorized") - return + authErr := wshub.AuthError{ + Code: 4001, + Message: "unauthorized", + } + return &authErr } if SessionId != "" { var userErr error User, userErr = b.AuthService.GetSessionUser(ctx, SessionId) if userErr != nil { - b.handleSocketClose(ctx, ws, 4001, "unauthorized") - return + authErr := wshub.AuthError{ + Code: 4001, + Message: "unauthorized", + } + return &authErr } } else { UserID, err := b.validateUserCookie(w, r) if err != nil { - b.handleSocketClose(ctx, ws, 4001, "unauthorized") - return + authErr := wshub.AuthError{ + Code: 4001, + Message: "unauthorized", + } + return &authErr } var userErr error User, userErr = b.UserService.GetGuestUser(ctx, UserID) if userErr != nil { - b.handleSocketClose(ctx, ws, 4001, "unauthorized") - return + authErr := wshub.AuthError{ + Code: 4001, + Message: "unauthorized", + } + return &authErr } } // make sure team is legit - _, retroErr := b.TeamService.TeamGet(context.Background(), teamID) + _, retroErr := b.TeamService.TeamGet(context.Background(), roomID) if retroErr != nil { - b.handleSocketClose(ctx, ws, 4004, "team not found") - return + authErr := wshub.AuthError{ + Code: 4004, + Message: "team not found", + } + return &authErr } // make sure user is a team user - _, UserErr := b.TeamService.TeamUserRole(ctx, User.Id, teamID) + _, UserErr := b.TeamService.TeamUserRole(ctx, User.Id, roomID) if UserErr != nil { b.logger.Ctx(ctx).Error("REQUIRES_TEAM_USER", zap.Error(UserErr), - zap.String("team_id", teamID), zap.String("session_user_id", User.Id)) - b.handleSocketClose(ctx, ws, 4005, "REQUIRES_TEAM_USER") - return - } + zap.String("team_id", roomID), zap.String("session_user_id", User.Id)) - ss := subscription{&b.config, c, teamID, User.Id} - h.register <- ss + authErr := wshub.AuthError{ + Code: 4005, + Message: "REQUIRES_TEAM_USER", + } + return &authErr + } - initEvent := createSocketEvent("init", "", User.Id) - _ = c.write(websocket.TextMessage, initEvent) + sub := b.hub.NewSubscriber(c.Ws, User.Id, roomID) - go ss.writePump() - go ss.readPump(b, ctx) - } -} + initEvent := wshub.CreateSocketEvent("init", "", User.Id) + _ = sub.Conn.Write(websocket.TextMessage, initEvent) -// APIEvent handles api driven events into the arena (if active) -func (b *Service) APIEvent(ctx context.Context, arenaID string, UserID, eventType string, eventValue string) error { - // find event handler and execute otherwise invalid event - if _, ok := b.eventHandlers[eventType]; ok { - msg, eventErr, _ := b.eventHandlers[eventType](ctx, arenaID, UserID, eventValue) - if eventErr != nil { - return eventErr - } + go sub.WritePump() + go sub.ReadPump(ctx, b.hub) - if _, ok := h.arenas[arenaID]; ok { - m := message{msg, arenaID} - h.broadcast <- m - } - } + return nil + }) +} - return nil +// APIEvent handles api driven events into the team checkin (if active) +func (b *Service) APIEvent(ctx context.Context, teamID string, UserID, eventType string, eventValue string) error { + return b.hub.ProcessAPIEventHandler(ctx, UserID, teamID, eventType, eventValue) } diff --git a/internal/http/checkin/events.go b/internal/http/checkin/events.go index e20dd894..3308983c 100644 --- a/internal/http/checkin/events.go +++ b/internal/http/checkin/events.go @@ -3,6 +3,8 @@ package checkin import ( "context" "encoding/json" + + "github.com/StevenWeathers/thunderdome-planning-poker/internal/wshub" ) // CheckinCreate creates a checkin @@ -25,7 +27,7 @@ func (b *Service) CheckinCreate(ctx context.Context, TeamID string, UserID strin return nil, err, false } - msg := createSocketEvent("checkin_added", "", "") + msg := wshub.CreateSocketEvent("checkin_added", "", "") return msg, nil, false } @@ -50,7 +52,7 @@ func (b *Service) CheckinUpdate(ctx context.Context, TeamID string, UserID strin return nil, err, false } - msg := createSocketEvent("checkin_updated", "", "") + msg := wshub.CreateSocketEvent("checkin_updated", "", "") return msg, nil, false } @@ -70,7 +72,7 @@ func (b *Service) CheckinDelete(ctx context.Context, TeamID string, UserID strin return nil, err, false } - msg := createSocketEvent("checkin_deleted", "", "") + msg := wshub.CreateSocketEvent("checkin_deleted", "", "") return msg, nil, false } @@ -92,7 +94,7 @@ func (b *Service) CommentCreate(ctx context.Context, TeamID string, UserID strin return nil, err, false } - msg := createSocketEvent("comment_added", "", "") + msg := wshub.CreateSocketEvent("comment_added", "", "") return msg, nil, false } @@ -114,7 +116,7 @@ func (b *Service) CommentUpdate(ctx context.Context, TeamID string, UserID strin return nil, err, false } - msg := createSocketEvent("comment_updated", "", "") + msg := wshub.CreateSocketEvent("comment_updated", "", "") return msg, nil, false } @@ -134,26 +136,7 @@ func (b *Service) CommentDelete(ctx context.Context, TeamID string, UserID strin return nil, err, false } - msg := createSocketEvent("comment_deleted", "", "") + msg := wshub.CreateSocketEvent("comment_deleted", "", "") return msg, nil, false } - -// socketEvent is the event structure used for socket messages -type socketEvent struct { - Type string `json:"type"` - Value string `json:"value"` - User string `json:"userId"` -} - -func createSocketEvent(Type string, Value string, User string) []byte { - newEvent := &socketEvent{ - Type: Type, - Value: Value, - User: User, - } - - event, _ := json.Marshal(newEvent) - - return event -} diff --git a/internal/http/checkin/hub.go b/internal/http/checkin/hub.go deleted file mode 100644 index af5becf6..00000000 --- a/internal/http/checkin/hub.go +++ /dev/null @@ -1,74 +0,0 @@ -package checkin - -type message struct { - data []byte - arena string -} - -type subscription struct { - config *Config - conn *connection - arena string - UserID string -} - -// hub maintains the set of active connections and broadcasts messages to the -// connections. -type hub struct { - // Registered connections. - arenas map[string]map[*connection]struct{} - - // Inbound messages from the connections. - broadcast chan message - - // Register requests from the connections. - register chan subscription - - // Unregister requests from connections. - unregister chan subscription -} - -var h = hub{ - broadcast: make(chan message), - register: make(chan subscription), - unregister: make(chan subscription), - arenas: make(map[string]map[*connection]struct{}), -} - -func (h *hub) run() { - for { - select { - case a := <-h.register: - connections := h.arenas[a.arena] - if connections == nil { - connections = make(map[*connection]struct{}) - h.arenas[a.arena] = connections - } - h.arenas[a.arena][a.conn] = struct{}{} - case a := <-h.unregister: - connections := h.arenas[a.arena] - if connections != nil { - if _, ok := connections[a.conn]; ok { - delete(connections, a.conn) - close(a.conn.send) - if len(connections) == 0 { - delete(h.arenas, a.arena) - } - } - } - case m := <-h.broadcast: - connections := h.arenas[m.arena] - for c := range connections { - select { - case c.send <- m.data: - default: - close(c.send) - delete(connections, c) - if len(connections) == 0 { - delete(h.arenas, m.arena) - } - } - } - } - } -} diff --git a/internal/http/poker/client.go b/internal/http/poker/client.go index cbb9b242..dca5ff02 100644 --- a/internal/http/poker/client.go +++ b/internal/http/poker/client.go @@ -5,316 +5,104 @@ import ( "database/sql" "encoding/json" "errors" - "fmt" "net/http" - "net/url" - "time" - "unicode/utf8" + + "github.com/StevenWeathers/thunderdome-planning-poker/internal/wshub" "github.com/StevenWeathers/thunderdome-planning-poker/thunderdome" "go.uber.org/zap" - "github.com/gorilla/mux" "github.com/gorilla/websocket" ) -const ( - // Maximum message size allowed from peer. - maxMessageSize = 1024 * 1024 -) - -// leaderOnlyOperations contains a map of operations that only a battle leader can execute -var leaderOnlyOperations = map[string]struct{}{ - "add_plan": {}, - "revise_plan": {}, - "burn_plan": {}, - "activate_plan": {}, - "skip_plan": {}, - "end_voting": {}, - "finalize_plan": {}, - "jab_warrior": {}, - "promote_leader": {}, - "demote_leader": {}, - "revise_battle": {}, - "concede_battle": {}, -} - -// connection is a middleman between the websocket connection and the hub. -type connection struct { - config *Config - // The websocket connection. - ws *websocket.Conn - - // Buffered channel of outbound messages. - send chan []byte -} - -// readPump pumps messages from the websocket connection to the hub. -func (sub subscription) readPump(b *Service, ctx context.Context) { - var forceClosed bool - c := sub.conn - UserID := sub.UserID - BattleID := sub.arena - - defer func() { - Users := b.BattleService.RetreatUser(BattleID, UserID) - UpdatedUsers, _ := json.Marshal(Users) - - retreatEvent := createSocketEvent("warrior_retreated", string(UpdatedUsers), UserID) - m := message{retreatEvent, BattleID} - h.broadcast <- m - - h.unregister <- sub - if forceClosed { - cm := websocket.FormatCloseMessage(4002, "abandoned") - if err := c.ws.WriteControl(websocket.CloseMessage, cm, time.Now().Add(sub.config.WriteWait())); err != nil { - b.logger.Ctx(ctx).Error("abandon error", zap.Error(err), - zap.String("poker_id", BattleID), zap.String("session_user_id", UserID)) - } - } - if err := c.ws.Close(); err != nil { - b.logger.Ctx(ctx).Error("close error", zap.Error(err), - zap.String("poker_id", BattleID), zap.String("session_user_id", UserID)) - } - }() - c.ws.SetReadLimit(maxMessageSize) - _ = c.ws.SetReadDeadline(time.Now().Add(sub.config.PongWait())) - c.ws.SetPongHandler(func(string) error { - _ = c.ws.SetReadDeadline(time.Now().Add(sub.config.PongWait())) - return nil - }) - - for { - var badEvent bool - var eventErr error - _, msg, err := c.ws.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - b.logger.Ctx(ctx).Error("unexpected close error", zap.Error(err), - zap.String("poker_id", BattleID), zap.String("session_user_id", UserID)) - } - break - } - - keyVal := make(map[string]string) - err = json.Unmarshal(msg, &keyVal) - if err != nil { - badEvent = true - b.logger.Error("unexpected battle event json error", zap.Error(err), - zap.String("poker_id", BattleID), zap.String("session_user_id", UserID)) - } - - eventType := keyVal["type"] - eventValue := keyVal["value"] - - // confirm leader for any operation that requires it - if _, ok := leaderOnlyOperations[eventType]; ok && !badEvent { - err := b.BattleService.ConfirmFacilitator(BattleID, UserID) - if err != nil { - badEvent = true - } - } - - // find event handler and execute otherwise invalid event - if _, ok := b.eventHandlers[eventType]; ok && !badEvent { - msg, eventErr, forceClosed = b.eventHandlers[eventType](ctx, BattleID, UserID, eventValue) - if eventErr != nil { - badEvent = true - - // don't log forceClosed events e.g. Abandon - if !forceClosed { - b.logger.Ctx(ctx).Error("close error", zap.Error(eventErr), - zap.String("poker_id", BattleID), zap.String("session_user_id", UserID), - zap.String("poker_event_type", eventType)) - } - } - } - - if !badEvent { - m := message{msg, sub.arena} - h.broadcast <- m - } - - if forceClosed { - break - } - } -} - -// write a message with the given message type and payload. -func (c *connection) write(mt int, payload []byte) error { - _ = c.ws.SetWriteDeadline(time.Now().Add(c.config.WriteWait())) - return c.ws.WriteMessage(mt, payload) -} - -// writePump pumps messages from the hub to the websocket connection. -func (sub *subscription) writePump() { - c := sub.conn - ticker := time.NewTicker(sub.config.PingPeriod()) - defer func() { - ticker.Stop() - _ = c.ws.Close() - }() - for { - select { - case message, ok := <-c.send: - if !ok { - _ = c.write(websocket.CloseMessage, []byte{}) - return - } - if err := c.write(websocket.TextMessage, message); err != nil { - return - } - case <-ticker.C: - if err := c.write(websocket.PingMessage, nil); err != nil { - return - } - } - } -} - -func (b *Service) createWebsocketUpgrader() websocket.Upgrader { - return websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return checkOrigin(r, b.config.AppDomain, b.config.WebsocketSubdomain) - }, - } -} - -func checkOrigin(r *http.Request, appDomain string, subDomain string) bool { - origin := r.Header.Get("Origin") - if len(origin) == 0 { - return true - } - originUrl, err := url.Parse(origin) - if err != nil { - return false - } - appDomainCheck := equalASCIIFold(originUrl.Host, appDomain) - subDomainCheck := equalASCIIFold(originUrl.Host, fmt.Sprintf("%s.%s", subDomain, appDomain)) - hostCheck := equalASCIIFold(originUrl.Host, r.Host) - - return appDomainCheck || subDomainCheck || hostCheck -} - -// equalASCIIFold returns true if s is equal to t with ASCII case folding as -// defined in RFC 4790. -// Taken from Gorilla Websocket, https://github.com/gorilla/websocket/blob/main/util.go -func equalASCIIFold(s, t string) bool { - for s != "" && t != "" { - sr, size := utf8.DecodeRuneInString(s) - s = s[size:] - tr, size := utf8.DecodeRuneInString(t) - t = t[size:] - if sr == tr { - continue - } - if 'A' <= sr && sr <= 'Z' { - sr = sr + 'a' - 'A' - } - if 'A' <= tr && tr <= 'Z' { - tr = tr + 'a' - 'A' - } - if sr != tr { - return false - } - } - return s == t -} - -// handleSocketUnauthorized sets the format close message and closes the websocket -func (b *Service) handleSocketClose(ctx context.Context, ws *websocket.Conn, closeCode int, text string) { - cm := websocket.FormatCloseMessage(closeCode, text) - if err := ws.WriteMessage(websocket.CloseMessage, cm); err != nil { - b.logger.Ctx(ctx).Error("unauthorized close error", zap.Error(err)) - } - if err := ws.Close(); err != nil { - b.logger.Ctx(ctx).Error("close error", zap.Error(err)) - } -} - // ServeBattleWs handles websocket requests from the peer. func (b *Service) ServeBattleWs() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - battleID := vars["battleId"] + return b.hub.WebSocketHandler("battleId", func(w http.ResponseWriter, r *http.Request, c *wshub.Connection, roomID string) *wshub.AuthError { ctx := r.Context() var User *thunderdome.User - var UserAuthed bool - - // upgrade to WebSocket connection - var upgrader = b.createWebsocketUpgrader() - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - b.logger.Ctx(ctx).Error("websocket upgrade error", zap.Error(err), - zap.String("poker_id", battleID)) - return - } - c := &connection{config: &b.config, send: make(chan []byte, 256), ws: ws} SessionId, cookieErr := b.validateSessionCookie(w, r) if cookieErr != nil && cookieErr.Error() != "COOKIE_NOT_FOUND" { - b.handleSocketClose(ctx, ws, 4001, "unauthorized") - return + authErr := wshub.AuthError{ + Code: 4001, + Message: "unauthorized", + } + return &authErr } if SessionId != "" { var userErr error User, userErr = b.AuthService.GetSessionUser(ctx, SessionId) if userErr != nil { - b.handleSocketClose(ctx, ws, 4001, "unauthorized") - return + authErr := wshub.AuthError{ + Code: 4001, + Message: "unauthorized", + } + return &authErr } } else { UserID, err := b.validateUserCookie(w, r) if err != nil { - b.handleSocketClose(ctx, ws, 4001, "unauthorized") - return + authErr := wshub.AuthError{ + Code: 4001, + Message: "unauthorized", + } + return &authErr } var userErr error User, userErr = b.UserService.GetGuestUser(ctx, UserID) if userErr != nil { - b.handleSocketClose(ctx, ws, 4001, "unauthorized") - return + authErr := wshub.AuthError{ + Code: 4001, + Message: "unauthorized", + } + return &authErr } } // make sure battle is legit - battle, battleErr := b.BattleService.GetGame(battleID, User.Id) + battle, battleErr := b.BattleService.GetGame(roomID, User.Id) if battleErr != nil { - b.handleSocketClose(ctx, ws, 4004, "battle not found") - return + authErr := wshub.AuthError{ + Code: 4004, + Message: "poker game not found", + } + return &authErr } // check users battle active status - UserErr := b.BattleService.GetUserActiveStatus(battleID, User.Id) + UserErr := b.BattleService.GetUserActiveStatus(roomID, User.Id) if UserErr != nil && !errors.Is(UserErr, sql.ErrNoRows) { usrErrMsg := UserErr.Error() + var authErr wshub.AuthError if usrErrMsg == "DUPLICATE_BATTLE_USER" { - b.handleSocketClose(ctx, ws, 4003, "duplicate session") + authErr = wshub.AuthError{ + Code: 4003, + Message: "duplicate session", + } } else { b.logger.Ctx(ctx).Error("error finding user", zap.Error(UserErr), - zap.String("poker_id", battleID), zap.String("session_user_id", User.Id)) - b.handleSocketClose(ctx, ws, 4005, "internal error") - } - return - } + zap.String("poker_id", roomID), zap.String("session_user_id", User.Id)) - if battle.JoinCode != "" && (UserErr != nil && errors.Is(UserErr, sql.ErrNoRows)) { - jcrEvent := createSocketEvent("join_code_required", "", User.Id) - _ = c.write(websocket.TextMessage, jcrEvent) + authErr = wshub.AuthError{ + Code: 4005, + Message: "internal error", + } + } + return &authErr + } else if (UserErr != nil && errors.Is(UserErr, sql.ErrNoRows)) && battle.JoinCode != "" { + jcrEvent := wshub.CreateSocketEvent("join_code_required", "", User.Id) + _ = c.Write(websocket.TextMessage, jcrEvent) for { - _, msg, err := c.ws.ReadMessage() + _, msg, err := c.Ws.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { b.logger.Ctx(ctx).Error("unexpected close error", zap.Error(err), - zap.String("poker_id", battleID), zap.String("session_user_id", User.Id)) + zap.String("poker_id", roomID), zap.String("session_user_id", User.Id)) } break } @@ -322,66 +110,47 @@ func (b *Service) ServeBattleWs() http.HandlerFunc { keyVal := make(map[string]string) err = json.Unmarshal(msg, &keyVal) if err != nil { - b.logger.Error("unexpected battle message error", zap.Error(err), - zap.String("poker_id", battleID), zap.String("session_user_id", User.Id)) + b.logger.Error("unexpected message error", zap.Error(err), + zap.String("poker_id", roomID), zap.String("session_user_id", User.Id)) } - if keyVal["type"] == "auth_battle" && keyVal["value"] == battle.JoinCode { - UserAuthed = true + if keyVal["type"] == "auth_game" && keyVal["value"] == battle.JoinCode { + // join code is valid, continue to room break - } else if keyVal["type"] == "auth_battle" { - authIncorrect := createSocketEvent("join_code_incorrect", "", User.Id) - _ = c.write(websocket.TextMessage, authIncorrect) + } else if keyVal["type"] == "auth_game" { + authIncorrect := wshub.CreateSocketEvent("join_code_incorrect", "", User.Id) + _ = c.Write(websocket.TextMessage, authIncorrect) } } - } else { - UserAuthed = true } - if UserAuthed { - ss := subscription{&b.config, c, battleID, User.Id} - h.register <- ss - - Users, _ := b.BattleService.AddUser(ss.arena, User.Id) - UpdatedUsers, _ := json.Marshal(Users) + sub := b.hub.NewSubscriber(c.Ws, User.Id, roomID) - Battle, _ := json.Marshal(battle) - initEvent := createSocketEvent("init", string(Battle), User.Id) - _ = c.write(websocket.TextMessage, initEvent) + Users, _ := b.BattleService.AddUser(roomID, User.Id) + UpdatedUsers, _ := json.Marshal(Users) - joinedEvent := createSocketEvent("warrior_joined", string(UpdatedUsers), User.Id) - m := message{joinedEvent, ss.arena} - h.broadcast <- m + Battle, _ := json.Marshal(battle) + initEvent := wshub.CreateSocketEvent("init", string(Battle), User.Id) + _ = sub.Conn.Write(websocket.TextMessage, initEvent) - go ss.writePump() - go ss.readPump(b, ctx) - } - } -} + userJoinedEvent := wshub.CreateSocketEvent("user_joined", string(UpdatedUsers), User.Id) + b.hub.Broadcast(wshub.Message{Data: userJoinedEvent, Room: roomID}) -// APIEvent handles api driven events into the arena (if active) -func (b *Service) APIEvent(ctx context.Context, arenaID string, UserID, eventType string, eventValue string) error { + go sub.WritePump() + go sub.ReadPump(ctx, b.hub) - // confirm leader for any operation that requires it - if _, ok := leaderOnlyOperations[eventType]; ok { - err := b.BattleService.ConfirmFacilitator(arenaID, UserID) - if err != nil { - return err - } - } + return nil + }) +} - // find event handler and execute otherwise invalid event - if _, ok := b.eventHandlers[eventType]; ok { - msg, eventErr, _ := b.eventHandlers[eventType](ctx, arenaID, UserID, eventValue) - if eventErr != nil { - return eventErr - } +func (b *Service) RetreatUser(roomID string, userID string) string { + Users := b.BattleService.RetreatUser(roomID, userID) + UpdatedUsers, _ := json.Marshal(Users) - if _, ok := h.arenas[arenaID]; ok { - m := message{msg, arenaID} - h.broadcast <- m - } - } + return string(UpdatedUsers) +} - return nil +// APIEvent handles api driven events into the poker game (if active) +func (b *Service) APIEvent(ctx context.Context, pokerID string, UserID, eventType string, eventValue string) error { + return b.hub.ProcessAPIEventHandler(ctx, UserID, pokerID, eventType, eventValue) } diff --git a/internal/http/poker/events.go b/internal/http/poker/events.go index 3eda1f0c..b17022f0 100644 --- a/internal/http/poker/events.go +++ b/internal/http/poker/events.go @@ -4,11 +4,13 @@ import ( "context" "encoding/json" "errors" + + "github.com/StevenWeathers/thunderdome-planning-poker/internal/wshub" ) // UserNudge handles notifying user that they need to vote func (b *Service) UserNudge(ctx context.Context, BattleID string, UserID string, EventValue string) ([]byte, error, bool) { - msg := createSocketEvent("jab_warrior", EventValue, UserID) + msg := wshub.CreateSocketEvent("jab_warrior", EventValue, UserID) return msg, nil, false } @@ -30,7 +32,7 @@ func (b *Service) UserVote(ctx context.Context, BattleID string, UserID string, Storys, AllVoted := b.BattleService.SetVote(BattleID, UserID, wv.StoryID, wv.VoteValue) updatedStorys, _ := json.Marshal(Storys) - msg = createSocketEvent("vote_activity", string(updatedStorys), UserID) + msg = wshub.CreateSocketEvent("vote_activity", string(updatedStorys), UserID) if AllVoted && wv.AutoFinishVoting { plans, err := b.BattleService.EndStoryVoting(BattleID, wv.StoryID) @@ -38,7 +40,7 @@ func (b *Service) UserVote(ctx context.Context, BattleID string, UserID string, return nil, err, false } updatedStorys, _ := json.Marshal(plans) - msg = createSocketEvent("voting_ended", string(updatedStorys), "") + msg = wshub.CreateSocketEvent("voting_ended", string(updatedStorys), "") } return msg, nil, false @@ -54,7 +56,7 @@ func (b *Service) UserVoteRetract(ctx context.Context, BattleID string, UserID s } updatedStorys, _ := json.Marshal(plans) - msg := createSocketEvent("vote_retracted", string(updatedStorys), UserID) + msg := wshub.CreateSocketEvent("vote_retracted", string(updatedStorys), UserID) return msg, nil, false } @@ -67,7 +69,7 @@ func (b *Service) UserPromote(ctx context.Context, BattleID string, UserID strin } leadersJson, _ := json.Marshal(leaders) - msg := createSocketEvent("leaders_updated", string(leadersJson), "") + msg := wshub.CreateSocketEvent("leaders_updated", string(leadersJson), "") return msg, nil, false } @@ -80,7 +82,7 @@ func (b *Service) UserDemote(ctx context.Context, BattleID string, UserID string } leadersJson, _ := json.Marshal(leaders) - msg := createSocketEvent("leaders_updated", string(leadersJson), "") + msg := wshub.CreateSocketEvent("leaders_updated", string(leadersJson), "") return msg, nil, false } @@ -99,7 +101,7 @@ func (b *Service) UserPromoteSelf(ctx context.Context, BattleID string, UserID s } leadersJson, _ := json.Marshal(leaders) - msg := createSocketEvent("leaders_updated", string(leadersJson), "") + msg := wshub.CreateSocketEvent("leaders_updated", string(leadersJson), "") return msg, nil, false } else { @@ -122,7 +124,7 @@ func (b *Service) UserSpectatorToggle(ctx context.Context, BattleID string, User } usersJson, _ := json.Marshal(users) - msg := createSocketEvent("users_updated", string(usersJson), "") + msg := wshub.CreateSocketEvent("users_updated", string(usersJson), "") return msg, nil, false } @@ -134,7 +136,7 @@ func (b *Service) StoryVoteEnd(ctx context.Context, BattleID string, UserID stri return nil, err, false } updatedStorys, _ := json.Marshal(plans) - msg := createSocketEvent("voting_ended", string(updatedStorys), "") + msg := wshub.CreateSocketEvent("voting_ended", string(updatedStorys), "") return msg, nil, false } @@ -174,7 +176,7 @@ func (b *Service) Revise(ctx context.Context, BattleID string, UserID string, Ev rb.LeaderCode = "" updatedBattle, _ := json.Marshal(rb) - msg := createSocketEvent("battle_revised", string(updatedBattle), "") + msg := wshub.CreateSocketEvent("battle_revised", string(updatedBattle), "") return msg, nil, false } @@ -185,7 +187,7 @@ func (b *Service) Delete(ctx context.Context, BattleID string, UserID string, Ev if err != nil { return nil, err, false } - msg := createSocketEvent("battle_conceded", "", "") + msg := wshub.CreateSocketEvent("battle_conceded", "", "") return msg, nil, false } @@ -211,7 +213,7 @@ func (b *Service) StoryAdd(ctx context.Context, BattleID string, UserID string, return nil, err, false } updatedStorys, _ := json.Marshal(plans) - msg := createSocketEvent("plan_added", string(updatedStorys), "") + msg := wshub.CreateSocketEvent("plan_added", string(updatedStorys), "") return msg, nil, false } @@ -238,7 +240,7 @@ func (b *Service) StoryRevise(ctx context.Context, BattleID string, UserID strin return nil, err, false } updatedStorys, _ := json.Marshal(plans) - msg := createSocketEvent("plan_revised", string(updatedStorys), "") + msg := wshub.CreateSocketEvent("plan_revised", string(updatedStorys), "") return msg, nil, false } @@ -250,7 +252,7 @@ func (b *Service) StoryDelete(ctx context.Context, BattleID string, UserID strin return nil, err, false } updatedStorys, _ := json.Marshal(plans) - msg := createSocketEvent("plan_burned", string(updatedStorys), "") + msg := wshub.CreateSocketEvent("plan_burned", string(updatedStorys), "") return msg, nil, false } @@ -271,7 +273,7 @@ func (b *Service) StoryArrange(ctx context.Context, BattleID string, UserID stri return nil, err, false } updatedStorys, _ := json.Marshal(plans) - msg := createSocketEvent("story_arranged", string(updatedStorys), "") + msg := wshub.CreateSocketEvent("story_arranged", string(updatedStorys), "") return msg, nil, false } @@ -283,7 +285,7 @@ func (b *Service) StoryActivate(ctx context.Context, BattleID string, UserID str return nil, err, false } updatedStorys, _ := json.Marshal(plans) - msg := createSocketEvent("plan_activated", string(updatedStorys), "") + msg := wshub.CreateSocketEvent("plan_activated", string(updatedStorys), "") return msg, nil, false } @@ -295,7 +297,7 @@ func (b *Service) StorySkip(ctx context.Context, BattleID string, UserID string, return nil, err, false } updatedStorys, _ := json.Marshal(plans) - msg := createSocketEvent("plan_skipped", string(updatedStorys), "") + msg := wshub.CreateSocketEvent("plan_skipped", string(updatedStorys), "") return msg, nil, false } @@ -316,7 +318,7 @@ func (b *Service) StoryFinalize(ctx context.Context, BattleID string, UserID str return nil, err, false } updatedStorys, _ := json.Marshal(plans) - msg := createSocketEvent("plan_finalized", string(updatedStorys), "") + msg := wshub.CreateSocketEvent("plan_finalized", string(updatedStorys), "") return msg, nil, false } @@ -330,22 +332,3 @@ func (b *Service) Abandon(ctx context.Context, BattleID string, UserID string, E return nil, errors.New("ABANDONED_BATTLE"), true } - -// socketEvent is the event structure used for socket messages -type socketEvent struct { - Type string `json:"type"` - Value string `json:"value"` - User string `json:"warriorId"` -} - -func createSocketEvent(Type string, Value string, User string) []byte { - newEvent := &socketEvent{ - Type: Type, - Value: Value, - User: User, - } - - event, _ := json.Marshal(newEvent) - - return event -} diff --git a/internal/http/poker/hub.go b/internal/http/poker/hub.go deleted file mode 100644 index 666a39b0..00000000 --- a/internal/http/poker/hub.go +++ /dev/null @@ -1,74 +0,0 @@ -package poker - -type message struct { - data []byte - arena string -} - -type subscription struct { - config *Config - conn *connection - arena string - UserID string -} - -// hub maintains the set of active connections and broadcasts messages to the -// connections. -type hub struct { - // Registered connections. - arenas map[string]map[*connection]struct{} - - // Inbound messages from the connections. - broadcast chan message - - // Register requests from the connections. - register chan subscription - - // Unregister requests from connections. - unregister chan subscription -} - -var h = hub{ - broadcast: make(chan message), - register: make(chan subscription), - unregister: make(chan subscription), - arenas: make(map[string]map[*connection]struct{}), -} - -func (h *hub) run() { - for { - select { - case a := <-h.register: - connections := h.arenas[a.arena] - if connections == nil { - connections = make(map[*connection]struct{}) - h.arenas[a.arena] = connections - } - h.arenas[a.arena][a.conn] = struct{}{} - case a := <-h.unregister: - connections := h.arenas[a.arena] - if connections != nil { - if _, ok := connections[a.conn]; ok { - delete(connections, a.conn) - close(a.conn.send) - if len(connections) == 0 { - delete(h.arenas, a.arena) - } - } - } - case m := <-h.broadcast: - connections := h.arenas[m.arena] - for c := range connections { - select { - case c.send <- m.data: - default: - close(c.send) - delete(connections, c) - if len(connections) == 0 { - delete(h.arenas, m.arena) - } - } - } - } - } -} diff --git a/internal/http/poker/poker.go b/internal/http/poker/poker.go index 3486d8f5..04cd3ba0 100644 --- a/internal/http/poker/poker.go +++ b/internal/http/poker/poker.go @@ -4,7 +4,8 @@ package poker import ( "context" "net/http" - "time" + + "github.com/StevenWeathers/thunderdome-planning-poker/internal/wshub" "github.com/StevenWeathers/thunderdome-planning-poker/thunderdome" "github.com/uptrace/opentelemetry-go-extra/otelzap" @@ -13,32 +14,16 @@ import ( type Config struct { // Time allowed to write a message to the peer. WriteWaitSec int - // Time allowed to read the next pong message from the peer. PongWaitSec int - // Send pings to peer with this period. Must be less than pongWait. PingPeriodSec int - // App Domain (for Websocket origin check) AppDomain string - // Websocket Subdomain (for Websocket origin check) WebsocketSubdomain string } -func (c *Config) WriteWait() time.Duration { - return time.Duration(c.WriteWaitSec) * time.Second -} - -func (c *Config) PingPeriod() time.Duration { - return time.Duration(c.PingPeriodSec) * time.Second -} - -func (c *Config) PongWait() time.Duration { - return time.Duration(c.PongWaitSec) * time.Second -} - type AuthDataSvc interface { GetSessionUser(ctx context.Context, SessionId string) (*thunderdome.User, error) } @@ -53,10 +38,10 @@ type Service struct { logger *otelzap.Logger validateSessionCookie func(w http.ResponseWriter, r *http.Request) (string, error) validateUserCookie func(w http.ResponseWriter, r *http.Request) (string, error) - eventHandlers map[string]func(context.Context, string, string, string) ([]byte, error, bool) UserService UserDataSvc AuthService AuthDataSvc BattleService thunderdome.PokerDataSvc + hub *wshub.Hub } // New returns a new battle with websocket hub/client and event handlers @@ -77,7 +62,13 @@ func New( BattleService: battleService, } - b.eventHandlers = map[string]func(context.Context, string, string, string) ([]byte, error, bool){ + b.hub = wshub.NewHub(logger, wshub.Config{ + AppDomain: config.AppDomain, + WebsocketSubdomain: config.WebsocketSubdomain, + WriteWaitSec: config.WriteWaitSec, + PongWaitSec: config.PongWaitSec, + PingPeriodSec: config.PingPeriodSec, + }, map[string]func(context.Context, string, string, string) ([]byte, error, bool){ "jab_warrior": b.UserNudge, "vote": b.UserVote, "retract_vote": b.UserVoteRetract, @@ -96,9 +87,26 @@ func New( "revise_battle": b.Revise, "concede_battle": b.Delete, "abandon_battle": b.Abandon, - } - - go h.run() + }, + map[string]struct{}{ + "add_plan": {}, + "revise_plan": {}, + "burn_plan": {}, + "activate_plan": {}, + "skip_plan": {}, + "end_voting": {}, + "finalize_plan": {}, + "jab_warrior": {}, + "promote_leader": {}, + "demote_leader": {}, + "revise_battle": {}, + "concede_battle": {}, + }, + b.BattleService.ConfirmFacilitator, + b.RetreatUser, + ) + + go b.hub.Run() return b } diff --git a/internal/http/retro/client.go b/internal/http/retro/client.go index 3aacd365..a2625ead 100644 --- a/internal/http/retro/client.go +++ b/internal/http/retro/client.go @@ -5,311 +5,102 @@ import ( "database/sql" "encoding/json" "errors" - "fmt" "net/http" - "net/url" - "time" - "unicode/utf8" + "github.com/StevenWeathers/thunderdome-planning-poker/internal/wshub" "github.com/StevenWeathers/thunderdome-planning-poker/thunderdome" "go.uber.org/zap" - "github.com/gorilla/mux" "github.com/gorilla/websocket" ) -const ( - // Maximum message size allowed from peer. - maxMessageSize = 1024 * 1024 -) - -// ownerOnlyOperations contains a map of operations that only a retro leader can execute -var ownerOnlyOperations = map[string]struct{}{ - "advance_phase": {}, - "add_facilitator": {}, - "remove_facilitator": {}, - "edit_retro": {}, - "concede_retro": {}, - "phase_time_ran_out": {}, - "phase_all_ready": {}, -} - -// connection is a middleman between the websocket connection and the hub. -type connection struct { - config *Config - // The websocket connection. - ws *websocket.Conn - - // Buffered channel of outbound messages. - send chan []byte -} - -// readPump pumps messages from the websocket connection to the hub. -func (sub subscription) readPump(b *Service, ctx context.Context) { - var forceClosed bool - c := sub.conn - UserID := sub.UserID - RetroID := sub.arena - - defer func() { - Users := b.RetroService.RetroRetreatUser(RetroID, UserID) - UpdatedUsers, _ := json.Marshal(Users) - - retreatEvent := createSocketEvent("user_left", string(UpdatedUsers), UserID) - m := message{retreatEvent, RetroID} - h.broadcast <- m - - h.unregister <- sub - if forceClosed { - cm := websocket.FormatCloseMessage(4002, "abandoned") - if err := c.ws.WriteControl(websocket.CloseMessage, cm, time.Now().Add(sub.config.WriteWait())); err != nil { - b.logger.Ctx(ctx).Error("abandon error", zap.Error(err), - zap.String("session_user_id", UserID), zap.String("retro_id", RetroID)) - } - } - if err := c.ws.Close(); err != nil { - b.logger.Ctx(ctx).Error("close error", zap.Error(err), - zap.String("session_user_id", UserID), zap.String("retro_id", RetroID)) - } - }() - c.ws.SetReadLimit(maxMessageSize) - _ = c.ws.SetReadDeadline(time.Now().Add(sub.config.PongWait())) - c.ws.SetPongHandler(func(string) error { - _ = c.ws.SetReadDeadline(time.Now().Add(sub.config.PongWait())) - return nil - }) - - for { - var badEvent bool - var eventErr error - _, msg, err := c.ws.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - b.logger.Ctx(ctx).Error("unexpected close error", zap.Error(err), - zap.String("session_user_id", UserID), zap.String("retro_id", RetroID)) - } - break - } - - keyVal := make(map[string]string) - err = json.Unmarshal(msg, &keyVal) - if err != nil { - badEvent = true - b.logger.Error("unexpected retro event json error", zap.Error(err), - zap.String("session_user_id", UserID), zap.String("retro_id", RetroID)) - } - - eventType := keyVal["type"] - eventValue := keyVal["value"] - - // confirm owner for any operation that requires it - if _, ok := ownerOnlyOperations[eventType]; ok && !badEvent { - err := b.RetroService.RetroConfirmFacilitator(RetroID, UserID) - if err != nil { - badEvent = true - } - } - - // find event handler and execute otherwise invalid event - if _, ok := b.eventHandlers[eventType]; ok && !badEvent { - msg, eventErr, forceClosed = b.eventHandlers[eventType](ctx, RetroID, UserID, eventValue) - if eventErr != nil { - badEvent = true - - // don't log forceClosed events e.g. Abandon - if !forceClosed { - b.logger.Ctx(ctx).Error("unexpected close error", zap.Error(eventErr), - zap.String("session_user_id", UserID), zap.String("retro_id", RetroID), - zap.String("retro_event_type", eventType)) - } - } - } - - if !badEvent { - m := message{msg, sub.arena} - h.broadcast <- m - } - - if forceClosed { - break - } - } -} - -// write a message with the given message type and payload. -func (c *connection) write(mt int, payload []byte) error { - _ = c.ws.SetWriteDeadline(time.Now().Add(c.config.WriteWait())) - return c.ws.WriteMessage(mt, payload) -} - -// writePump pumps messages from the hub to the websocket connection. -func (sub *subscription) writePump() { - c := sub.conn - ticker := time.NewTicker(sub.config.PingPeriod()) - defer func() { - ticker.Stop() - _ = c.ws.Close() - }() - for { - select { - case message, ok := <-c.send: - if !ok { - _ = c.write(websocket.CloseMessage, []byte{}) - return - } - if err := c.write(websocket.TextMessage, message); err != nil { - return - } - case <-ticker.C: - if err := c.write(websocket.PingMessage, nil); err != nil { - return - } - } - } -} - -func (b *Service) createWebsocketUpgrader() websocket.Upgrader { - return websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return checkOrigin(r, b.config.AppDomain, b.config.WebsocketSubdomain) - }, - } -} - -func checkOrigin(r *http.Request, appDomain string, subDomain string) bool { - origin := r.Header.Get("Origin") - if len(origin) == 0 { - return true - } - originUrl, err := url.Parse(origin) - if err != nil { - return false - } - appDomainCheck := equalASCIIFold(originUrl.Host, appDomain) - subDomainCheck := equalASCIIFold(originUrl.Host, fmt.Sprintf("%s.%s", subDomain, appDomain)) - hostCheck := equalASCIIFold(originUrl.Host, r.Host) - - return appDomainCheck || subDomainCheck || hostCheck -} - -// equalASCIIFold returns true if s is equal to t with ASCII case folding as -// defined in RFC 4790. -// Taken from Gorilla Websocket, https://github.com/gorilla/websocket/blob/main/util.go -func equalASCIIFold(s, t string) bool { - for s != "" && t != "" { - sr, size := utf8.DecodeRuneInString(s) - s = s[size:] - tr, size := utf8.DecodeRuneInString(t) - t = t[size:] - if sr == tr { - continue - } - if 'A' <= sr && sr <= 'Z' { - sr = sr + 'a' - 'A' - } - if 'A' <= tr && tr <= 'Z' { - tr = tr + 'a' - 'A' - } - if sr != tr { - return false - } - } - return s == t -} - -// handleSocketUnauthorized sets the format close message and closes the websocket -func (b *Service) handleSocketClose(ctx context.Context, ws *websocket.Conn, closeCode int, text string) { - cm := websocket.FormatCloseMessage(closeCode, text) - if err := ws.WriteMessage(websocket.CloseMessage, cm); err != nil { - b.logger.Ctx(ctx).Error("unauthorized close error", zap.Error(err)) - } - if err := ws.Close(); err != nil { - b.logger.Ctx(ctx).Error("close error", zap.Error(err)) - } -} - // ServeWs handles websocket requests from the peer. func (b *Service) ServeWs() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - retroID := vars["retroId"] + return b.hub.WebSocketHandler("retroId", func(w http.ResponseWriter, r *http.Request, c *wshub.Connection, roomID string) *wshub.AuthError { ctx := r.Context() var User *thunderdome.User - var UserAuthed bool - - // upgrade to WebSocket connection - var upgrader = b.createWebsocketUpgrader() - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - b.logger.Ctx(ctx).Error("websocket upgrade error", zap.Error(err), - zap.String("retro_id", retroID)) - return - } - c := &connection{config: &b.config, send: make(chan []byte, 256), ws: ws} SessionId, cookieErr := b.validateSessionCookie(w, r) if cookieErr != nil && cookieErr.Error() != "COOKIE_NOT_FOUND" { - b.handleSocketClose(ctx, ws, 4001, "unauthorized") - return + authErr := wshub.AuthError{ + Code: 4001, + Message: "unauthorized", + } + return &authErr } if SessionId != "" { var userErr error User, userErr = b.AuthService.GetSessionUser(ctx, SessionId) if userErr != nil { - b.handleSocketClose(ctx, ws, 4001, "unauthorized") - return + authErr := wshub.AuthError{ + Code: 4001, + Message: "unauthorized", + } + return &authErr } } else { UserID, err := b.validateUserCookie(w, r) if err != nil { - b.handleSocketClose(ctx, ws, 4001, "unauthorized") - return + authErr := wshub.AuthError{ + Code: 4001, + Message: "unauthorized", + } + return &authErr } var userErr error User, userErr = b.UserService.GetGuestUser(ctx, UserID) if userErr != nil { - b.handleSocketClose(ctx, ws, 4001, "unauthorized") - return + authErr := wshub.AuthError{ + Code: 4001, + Message: "unauthorized", + } + return &authErr } } // make sure retro is legit - retro, retroErr := b.RetroService.RetroGet(retroID, User.Id) + retro, retroErr := b.RetroService.RetroGet(roomID, User.Id) if retroErr != nil { - b.handleSocketClose(ctx, ws, 4004, "retro not found") - return + authErr := wshub.AuthError{ + Code: 4004, + Message: "retro not found", + } + return &authErr } // check users retro active status - UserErr := b.RetroService.GetRetroUserActiveStatus(retroID, User.Id) + UserErr := b.RetroService.GetRetroUserActiveStatus(roomID, User.Id) if UserErr != nil && !errors.Is(UserErr, sql.ErrNoRows) { usrErrMsg := UserErr.Error() + var authErr wshub.AuthError if usrErrMsg == "DUPLICATE_RETRO_USER" { - b.handleSocketClose(ctx, ws, 4003, "duplicate session") + authErr = wshub.AuthError{ + Code: 4003, + Message: "duplicate session", + } } else { b.logger.Ctx(ctx).Error("error finding user", zap.Error(UserErr), - zap.String("retro_id", retroID)) - b.handleSocketClose(ctx, ws, 4005, "internal error") + zap.String("retro_id", roomID), zap.String("session_user_id", User.Id)) + authErr = wshub.AuthError{ + Code: 4005, + Message: "internal error", + } } - return - } - - if retro.JoinCode != "" && (UserErr != nil && errors.Is(UserErr, sql.ErrNoRows)) { - jcrEvent := createSocketEvent("join_code_required", "", User.Id) - _ = c.write(websocket.TextMessage, jcrEvent) + return &authErr + } else if retro.JoinCode != "" && (UserErr != nil && errors.Is(UserErr, sql.ErrNoRows)) { + jcrEvent := wshub.CreateSocketEvent("join_code_required", "", User.Id) + _ = c.Write(websocket.TextMessage, jcrEvent) for { - _, msg, err := c.ws.ReadMessage() + _, msg, err := c.Ws.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { b.logger.Ctx(ctx).Error("unexpected close error", zap.Error(err), - zap.String("retro_id", retroID), zap.String("session_user_id", User.Id)) + zap.String("retro_id", roomID), zap.String("session_user_id", User.Id)) } break } @@ -317,65 +108,47 @@ func (b *Service) ServeWs() http.HandlerFunc { keyVal := make(map[string]string) err = json.Unmarshal(msg, &keyVal) if err != nil { - b.logger.Error("unexpected retro message error", zap.Error(err), - zap.String("retro_id", retroID), zap.String("session_user_id", User.Id)) + b.logger.Error("unexpected message error", zap.Error(err), + zap.String("retro_id", roomID), zap.String("session_user_id", User.Id)) } if keyVal["type"] == "auth_retro" && keyVal["value"] == retro.JoinCode { - UserAuthed = true + // join code is valid, continue to room break } else if keyVal["type"] == "auth_retro" { - authIncorrect := createSocketEvent("join_code_incorrect", "", User.Id) - _ = c.write(websocket.TextMessage, authIncorrect) + authIncorrect := wshub.CreateSocketEvent("join_code_incorrect", "", User.Id) + _ = c.Write(websocket.TextMessage, authIncorrect) } } - } else { - UserAuthed = true } - if UserAuthed { - ss := subscription{&b.config, c, retroID, User.Id} - h.register <- ss + sub := b.hub.NewSubscriber(c.Ws, User.Id, roomID) - Users, _ := b.RetroService.RetroAddUser(ss.arena, User.Id) - UpdatedUsers, _ := json.Marshal(Users) + Users, _ := b.RetroService.RetroAddUser(roomID, User.Id) + UpdatedUsers, _ := json.Marshal(Users) - Retro, _ := json.Marshal(retro) - initEvent := createSocketEvent("init", string(Retro), User.Id) - _ = c.write(websocket.TextMessage, initEvent) + Retro, _ := json.Marshal(retro) + initEvent := wshub.CreateSocketEvent("init", string(Retro), User.Id) + _ = sub.Conn.Write(websocket.TextMessage, initEvent) - joinedEvent := createSocketEvent("user_joined", string(UpdatedUsers), User.Id) - m := message{joinedEvent, ss.arena} - h.broadcast <- m + userJoinedEvent := wshub.CreateSocketEvent("user_joined", string(UpdatedUsers), User.Id) + b.hub.Broadcast(wshub.Message{Data: userJoinedEvent, Room: roomID}) - go ss.writePump() - go ss.readPump(b, ctx) - } - } -} + go sub.WritePump() + go sub.ReadPump(ctx, b.hub) -// APIEvent handles api driven events into the arena (if active) -func (b *Service) APIEvent(ctx context.Context, arenaID string, UserID, eventType string, eventValue string) error { - // confirm leader for any operation that requires it - if _, ok := ownerOnlyOperations[eventType]; ok { - err := b.RetroService.RetroConfirmFacilitator(arenaID, UserID) - if err != nil { - return err - } - } + return nil + }) +} - // find event handler and execute otherwise invalid event - if _, ok := b.eventHandlers[eventType]; ok { - msg, eventErr, _ := b.eventHandlers[eventType](ctx, arenaID, UserID, eventValue) - if eventErr != nil { - return eventErr - } +func (b *Service) RetreatUser(roomID string, userID string) string { + Users := b.RetroService.RetroRetreatUser(roomID, userID) + UpdatedUsers, _ := json.Marshal(Users) - if _, ok := h.arenas[arenaID]; ok { - m := message{msg, arenaID} - h.broadcast <- m - } - } + return string(UpdatedUsers) +} - return nil +// APIEvent handles api driven events into the retro (if active) +func (b *Service) APIEvent(ctx context.Context, retroID string, UserID, eventType string, eventValue string) error { + return b.hub.ProcessAPIEventHandler(ctx, UserID, retroID, eventType, eventValue) } diff --git a/internal/http/retro/events.go b/internal/http/retro/events.go index f96a90ca..0df7d21f 100644 --- a/internal/http/retro/events.go +++ b/internal/http/retro/events.go @@ -5,6 +5,8 @@ import ( "encoding/json" "errors" + "github.com/StevenWeathers/thunderdome-planning-poker/internal/wshub" + "github.com/StevenWeathers/thunderdome-planning-poker/thunderdome" "go.uber.org/zap" ) @@ -27,7 +29,7 @@ func (b *Service) CreateItem(ctx context.Context, RetroID string, UserID string, } updatedItems, _ := json.Marshal(items) - msg := createSocketEvent("items_updated", string(updatedItems), "") + msg := wshub.CreateSocketEvent("items_updated", string(updatedItems), "") return msg, nil, false } @@ -49,7 +51,7 @@ func (b *Service) ItemCommentAdd(ctx context.Context, RetroID string, UserID str } updatedItems, _ := json.Marshal(items) - msg := createSocketEvent("items_updated", string(updatedItems), "") + msg := wshub.CreateSocketEvent("items_updated", string(updatedItems), "") return msg, nil, false } @@ -71,7 +73,7 @@ func (b *Service) ItemCommentEdit(ctx context.Context, RetroID string, UserID st } updatedItems, _ := json.Marshal(items) - msg := createSocketEvent("items_updated", string(updatedItems), "") + msg := wshub.CreateSocketEvent("items_updated", string(updatedItems), "") return msg, nil, false } @@ -92,7 +94,7 @@ func (b *Service) ItemCommentDelete(ctx context.Context, RetroID string, UserID } updatedItems, _ := json.Marshal(items) - msg := createSocketEvent("items_updated", string(updatedItems), "") + msg := wshub.CreateSocketEvent("items_updated", string(updatedItems), "") return msg, nil, false } @@ -105,7 +107,7 @@ func (b *Service) UserMarkReady(ctx context.Context, RetroID string, UserID stri } updatedReadyUsers, _ := json.Marshal(readyUsers) - msg := createSocketEvent("user_marked_ready", string(updatedReadyUsers), UserID) + msg := wshub.CreateSocketEvent("user_marked_ready", string(updatedReadyUsers), UserID) return msg, nil, false } @@ -118,7 +120,7 @@ func (b *Service) UserUnMarkReady(ctx context.Context, RetroID string, UserID st } updatedReadyUsers, _ := json.Marshal(readyUsers) - msg := createSocketEvent("user_marked_unready", string(updatedReadyUsers), UserID) + msg := wshub.CreateSocketEvent("user_marked_unready", string(updatedReadyUsers), UserID) return msg, nil, false } @@ -140,7 +142,7 @@ func (b *Service) GroupItem(ctx context.Context, RetroID string, UserID string, } updatedItem, _ := json.Marshal(item) - msg := createSocketEvent("item_moved", string(updatedItem), "") + msg := wshub.CreateSocketEvent("item_moved", string(updatedItem), "") return msg, nil, false } @@ -163,7 +165,7 @@ func (b *Service) DeleteItem(ctx context.Context, RetroID string, UserID string, } updatedItems, _ := json.Marshal(items) - msg := createSocketEvent("items_updated", string(updatedItems), "") + msg := wshub.CreateSocketEvent("items_updated", string(updatedItems), "") return msg, nil, false } @@ -185,7 +187,7 @@ func (b *Service) GroupNameChange(ctx context.Context, RetroID string, UserID st } updatedGroup, _ := json.Marshal(group) - msg := createSocketEvent("group_name_updated", string(updatedGroup), "") + msg := wshub.CreateSocketEvent("group_name_updated", string(updatedGroup), "") return msg, nil, false } @@ -206,7 +208,7 @@ func (b *Service) GroupUserVote(ctx context.Context, RetroID string, UserID stri } updatedVotes, _ := json.Marshal(votes) - msg := createSocketEvent("votes_updated", string(updatedVotes), "") + msg := wshub.CreateSocketEvent("votes_updated", string(updatedVotes), "") return msg, nil, false } @@ -227,7 +229,7 @@ func (b *Service) GroupUserSubtractVote(ctx context.Context, RetroID string, Use } updatedVotes, _ := json.Marshal(votes) - msg := createSocketEvent("votes_updated", string(updatedVotes), "") + msg := wshub.CreateSocketEvent("votes_updated", string(updatedVotes), "") return msg, nil, false } @@ -248,7 +250,7 @@ func (b *Service) CreateAction(ctx context.Context, RetroID string, UserID strin } updatedItems, _ := json.Marshal(items) - msg := createSocketEvent("action_updated", string(updatedItems), "") + msg := wshub.CreateSocketEvent("action_updated", string(updatedItems), "") return msg, nil, false } @@ -271,7 +273,7 @@ func (b *Service) UpdateAction(ctx context.Context, RetroID string, UserID strin } updatedItems, _ := json.Marshal(items) - msg := createSocketEvent("action_updated", string(updatedItems), "") + msg := wshub.CreateSocketEvent("action_updated", string(updatedItems), "") return msg, nil, false } @@ -293,7 +295,7 @@ func (b *Service) ActionAddAssignee(ctx context.Context, RetroID string, UserID } updatedItems, _ := json.Marshal(items) - msg := createSocketEvent("action_updated", string(updatedItems), "") + msg := wshub.CreateSocketEvent("action_updated", string(updatedItems), "") return msg, nil, false } @@ -315,7 +317,7 @@ func (b *Service) ActionRemoveAssignee(ctx context.Context, RetroID string, User } updatedItems, _ := json.Marshal(items) - msg := createSocketEvent("action_updated", string(updatedItems), "") + msg := wshub.CreateSocketEvent("action_updated", string(updatedItems), "") return msg, nil, false } @@ -336,7 +338,7 @@ func (b *Service) DeleteAction(ctx context.Context, RetroID string, UserID strin } updatedItems, _ := json.Marshal(items) - msg := createSocketEvent("action_updated", string(updatedItems), "") + msg := wshub.CreateSocketEvent("action_updated", string(updatedItems), "") return msg, nil, false } @@ -357,7 +359,7 @@ func (b *Service) AdvancePhase(ctx context.Context, RetroID string, UserID strin } updatedItems, _ := json.Marshal(retro) - msg := createSocketEvent("phase_updated", string(updatedItems), "") + msg := wshub.CreateSocketEvent("phase_updated", string(updatedItems), "") // if retro is completed send retro email to attendees if rs.Phase == "completed" { @@ -383,7 +385,7 @@ func (b *Service) PhaseTimeout(ctx context.Context, RetroID string, UserID strin } updatedItems, _ := json.Marshal(retro) - msg := createSocketEvent("phase_updated", string(updatedItems), "") + msg := wshub.CreateSocketEvent("phase_updated", string(updatedItems), "") return msg, nil, false } @@ -404,7 +406,7 @@ func (b *Service) PhaseAllReady(ctx context.Context, RetroID string, UserID stri } updatedItems, _ := json.Marshal(retro) - msg := createSocketEvent("phase_updated", string(updatedItems), "") + msg := wshub.CreateSocketEvent("phase_updated", string(updatedItems), "") return msg, nil, false } @@ -425,7 +427,7 @@ func (b *Service) FacilitatorAdd(ctx context.Context, RetroID string, UserID str } updatedFacilitators, _ := json.Marshal(facilitators) - msg := createSocketEvent("facilitators_updated", string(updatedFacilitators), "") + msg := wshub.CreateSocketEvent("facilitators_updated", string(updatedFacilitators), "") return msg, nil, false } @@ -446,7 +448,7 @@ func (b *Service) FacilitatorRemove(ctx context.Context, RetroID string, UserID } updatedFacilitators, _ := json.Marshal(facilitators) - msg := createSocketEvent("facilitators_updated", string(updatedFacilitators), "") + msg := wshub.CreateSocketEvent("facilitators_updated", string(updatedFacilitators), "") return msg, nil, false } @@ -465,7 +467,7 @@ func (b *Service) FacilitatorSelf(ctx context.Context, RetroID string, UserID st } updatedFacilitators, _ := json.Marshal(facilitators) - msg := createSocketEvent("facilitators_updated", string(updatedFacilitators), "") + msg := wshub.CreateSocketEvent("facilitators_updated", string(updatedFacilitators), "") return msg, nil, false } else { @@ -502,7 +504,7 @@ func (b *Service) EditRetro(ctx context.Context, RetroID string, UserID string, } updatedRetro, _ := json.Marshal(rb) - msg := createSocketEvent("retro_edited", string(updatedRetro), "") + msg := wshub.CreateSocketEvent("retro_edited", string(updatedRetro), "") return msg, nil, false } @@ -513,7 +515,7 @@ func (b *Service) Delete(ctx context.Context, RetroID string, UserID string, Eve if err != nil { return nil, err, false } - msg := createSocketEvent("conceded", "", "") + msg := wshub.CreateSocketEvent("conceded", "", "") return msg, nil, false } @@ -547,22 +549,3 @@ func (b *Service) SendCompletedEmails(retro *thunderdome.Retro) { } } } - -// socketEvent is the event structure used for socket messages -type socketEvent struct { - Type string `json:"type"` - Value string `json:"value"` - User string `json:"userId"` -} - -func createSocketEvent(Type string, Value string, User string) []byte { - newEvent := &socketEvent{ - Type: Type, - Value: Value, - User: User, - } - - event, _ := json.Marshal(newEvent) - - return event -} diff --git a/internal/http/retro/hub.go b/internal/http/retro/hub.go deleted file mode 100644 index 5bbe67eb..00000000 --- a/internal/http/retro/hub.go +++ /dev/null @@ -1,74 +0,0 @@ -package retro - -type message struct { - data []byte - arena string -} - -type subscription struct { - config *Config - conn *connection - arena string - UserID string -} - -// hub maintains the set of active connections and broadcasts messages to the -// connections. -type hub struct { - // Registered connections. - arenas map[string]map[*connection]struct{} - - // Inbound messages from the connections. - broadcast chan message - - // Register requests from the connections. - register chan subscription - - // Unregister requests from connections. - unregister chan subscription -} - -var h = hub{ - broadcast: make(chan message), - register: make(chan subscription), - unregister: make(chan subscription), - arenas: make(map[string]map[*connection]struct{}), -} - -func (h *hub) run() { - for { - select { - case a := <-h.register: - connections := h.arenas[a.arena] - if connections == nil { - connections = make(map[*connection]struct{}) - h.arenas[a.arena] = connections - } - h.arenas[a.arena][a.conn] = struct{}{} - case a := <-h.unregister: - connections := h.arenas[a.arena] - if connections != nil { - if _, ok := connections[a.conn]; ok { - delete(connections, a.conn) - close(a.conn.send) - if len(connections) == 0 { - delete(h.arenas, a.arena) - } - } - } - case m := <-h.broadcast: - connections := h.arenas[m.arena] - for c := range connections { - select { - case c.send <- m.data: - default: - close(c.send) - delete(connections, c) - if len(connections) == 0 { - delete(h.arenas, m.arena) - } - } - } - } - } -} diff --git a/internal/http/retro/retro.go b/internal/http/retro/retro.go index 0dd28f58..c9bbefe1 100644 --- a/internal/http/retro/retro.go +++ b/internal/http/retro/retro.go @@ -3,8 +3,8 @@ package retro import ( "context" "net/http" - "time" + "github.com/StevenWeathers/thunderdome-planning-poker/internal/wshub" "github.com/StevenWeathers/thunderdome-planning-poker/thunderdome" "github.com/uptrace/opentelemetry-go-extra/otelzap" ) @@ -26,18 +26,6 @@ type Config struct { WebsocketSubdomain string } -func (c *Config) WriteWait() time.Duration { - return time.Duration(c.WriteWaitSec) * time.Second -} - -func (c *Config) PingPeriod() time.Duration { - return time.Duration(c.PingPeriodSec) * time.Second -} - -func (c *Config) PongWait() time.Duration { - return time.Duration(c.PongWaitSec) * time.Second -} - type AuthDataSvc interface { GetSessionUser(ctx context.Context, SessionId string) (*thunderdome.User, error) } @@ -52,12 +40,12 @@ type Service struct { logger *otelzap.Logger validateSessionCookie func(w http.ResponseWriter, r *http.Request) (string, error) validateUserCookie func(w http.ResponseWriter, r *http.Request) (string, error) - eventHandlers map[string]func(context.Context, string, string, string) ([]byte, error, bool) UserService UserDataSvc AuthService AuthDataSvc RetroService thunderdome.RetroDataSvc TemplateService thunderdome.RetroTemplateDataSvc EmailService thunderdome.EmailService + hub *wshub.Hub } // New returns a new retro with websocket hub/client and event handlers @@ -82,7 +70,13 @@ func New( EmailService: emailService, } - rs.eventHandlers = map[string]func(context.Context, string, string, string) ([]byte, error, bool){ + rs.hub = wshub.NewHub(logger, wshub.Config{ + AppDomain: config.AppDomain, + WebsocketSubdomain: config.WebsocketSubdomain, + WriteWaitSec: config.WriteWaitSec, + PongWaitSec: config.PongWaitSec, + PingPeriodSec: config.PingPeriodSec, + }, map[string]func(context.Context, string, string, string) ([]byte, error, bool){ "create_item": rs.CreateItem, "user_ready": rs.UserMarkReady, "user_unready": rs.UserUnMarkReady, @@ -108,9 +102,21 @@ func New( "edit_retro": rs.EditRetro, "concede_retro": rs.Delete, "abandon_retro": rs.Abandon, - } - - go h.run() + }, + map[string]struct{}{ + "advance_phase": {}, + "add_facilitator": {}, + "remove_facilitator": {}, + "edit_retro": {}, + "concede_retro": {}, + "phase_time_ran_out": {}, + "phase_all_ready": {}, + }, + rs.RetroService.RetroConfirmFacilitator, + rs.RetreatUser, + ) + + go rs.hub.Run() return rs } diff --git a/internal/http/storyboard/client.go b/internal/http/storyboard/client.go index 7e840638..4f3baa46 100644 --- a/internal/http/storyboard/client.go +++ b/internal/http/storyboard/client.go @@ -5,308 +5,102 @@ import ( "database/sql" "encoding/json" "errors" - "fmt" "net/http" - "net/url" - "time" - "unicode/utf8" + "github.com/StevenWeathers/thunderdome-planning-poker/internal/wshub" "github.com/StevenWeathers/thunderdome-planning-poker/thunderdome" "go.uber.org/zap" - "github.com/gorilla/mux" "github.com/gorilla/websocket" ) -const ( - // Maximum message size allowed from peer. - maxMessageSize = 1024 * 1024 -) - -// ownerOnlyOperations contains a map of operations that only a storyboard leader can execute -var ownerOnlyOperations = map[string]struct{}{ - "facilitator_add": {}, - "facilitator_remove": {}, - "edit_storyboard": {}, - "concede_storyboard": {}, -} - -// connection is a middleman between the websocket connection and the hub. -type connection struct { - config *Config - // The websocket connection. - ws *websocket.Conn - - // Buffered channel of outbound messages. - send chan []byte -} - -// readPump pumps messages from the websocket connection to the hub. -func (sub subscription) readPump(b *Service, ctx context.Context) { - var forceClosed bool - c := sub.conn - UserID := sub.UserID - StoryboardID := sub.arena - - defer func() { - Users := b.StoryboardService.RetreatStoryboardUser(StoryboardID, UserID) - UpdatedUsers, _ := json.Marshal(Users) - - retreatEvent := createSocketEvent("user_left", string(UpdatedUsers), UserID) - m := message{retreatEvent, StoryboardID} - h.broadcast <- m - - h.unregister <- sub - if forceClosed { - cm := websocket.FormatCloseMessage(4002, "abandoned") - if err := c.ws.WriteControl(websocket.CloseMessage, cm, time.Now().Add(sub.config.WriteWait())); err != nil { - b.Logger.Ctx(ctx).Error("abandon error", zap.Error(err), - zap.String("session_user_id", UserID), zap.String("storyboard_id", StoryboardID)) - } - } - if err := c.ws.Close(); err != nil { - b.Logger.Ctx(ctx).Error("close error", zap.Error(err), - zap.String("session_user_id", UserID), zap.String("storyboard_id", StoryboardID)) - } - }() - c.ws.SetReadLimit(maxMessageSize) - _ = c.ws.SetReadDeadline(time.Now().Add(sub.config.PongWait())) - c.ws.SetPongHandler(func(string) error { - _ = c.ws.SetReadDeadline(time.Now().Add(sub.config.PongWait())) - return nil - }) - - for { - var badEvent bool - var eventErr error - _, msg, err := c.ws.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - b.Logger.Ctx(ctx).Error("unexpected close error", zap.Error(err), - zap.String("session_user_id", UserID), zap.String("storyboard_id", StoryboardID)) - } - break - } - - keyVal := make(map[string]string) - err = json.Unmarshal(msg, &keyVal) - if err != nil { - badEvent = true - b.Logger.Error("unexpected storyboard event json error", zap.Error(err), - zap.String("session_user_id", UserID), zap.String("storyboard_id", StoryboardID)) - } - - eventType := keyVal["type"] - eventValue := keyVal["value"] - - // confirm owner for any operation that requires it - if _, ok := ownerOnlyOperations[eventType]; ok && !badEvent { - err := b.StoryboardService.ConfirmStoryboardFacilitator(StoryboardID, UserID) - if err != nil { - badEvent = true - } - } - - // find event handler and execute otherwise invalid event - if _, ok := b.EventHandlers[eventType]; ok && !badEvent { - msg, eventErr, forceClosed = b.EventHandlers[eventType](ctx, StoryboardID, UserID, eventValue) - if eventErr != nil { - badEvent = true - - // don't log forceClosed events e.g. Abandon - if !forceClosed { - b.Logger.Ctx(ctx).Error("unexpected close error", zap.Error(eventErr), - zap.String("session_user_id", UserID), zap.String("storyboard_id", StoryboardID), - zap.String("storyboard_event_type", eventType)) - } - } - } - - if !badEvent { - m := message{msg, sub.arena} - h.broadcast <- m - } - - if forceClosed { - break - } - } -} - -// write a message with the given message type and payload. -func (c *connection) write(mt int, payload []byte) error { - _ = c.ws.SetWriteDeadline(time.Now().Add(c.config.WriteWait())) - return c.ws.WriteMessage(mt, payload) -} - -// writePump pumps messages from the hub to the websocket connection. -func (sub *subscription) writePump() { - c := sub.conn - ticker := time.NewTicker(sub.config.PingPeriod()) - defer func() { - ticker.Stop() - _ = c.ws.Close() - }() - for { - select { - case message, ok := <-c.send: - if !ok { - _ = c.write(websocket.CloseMessage, []byte{}) - return - } - if err := c.write(websocket.TextMessage, message); err != nil { - return - } - case <-ticker.C: - if err := c.write(websocket.PingMessage, nil); err != nil { - return - } - } - } -} - -func (b *Service) createWebsocketUpgrader() websocket.Upgrader { - return websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return checkOrigin(r, b.config.AppDomain, b.config.WebsocketSubdomain) - }, - } -} - -func checkOrigin(r *http.Request, appDomain string, subDomain string) bool { - origin := r.Header.Get("Origin") - if len(origin) == 0 { - return true - } - originUrl, err := url.Parse(origin) - if err != nil { - return false - } - appDomainCheck := equalASCIIFold(originUrl.Host, appDomain) - subDomainCheck := equalASCIIFold(originUrl.Host, fmt.Sprintf("%s.%s", subDomain, appDomain)) - hostCheck := equalASCIIFold(originUrl.Host, r.Host) - - return appDomainCheck || subDomainCheck || hostCheck -} - -// equalASCIIFold returns true if s is equal to t with ASCII case folding as -// defined in RFC 4790. -// Taken from Gorilla Websocket, https://github.com/gorilla/websocket/blob/main/util.go -func equalASCIIFold(s, t string) bool { - for s != "" && t != "" { - sr, size := utf8.DecodeRuneInString(s) - s = s[size:] - tr, size := utf8.DecodeRuneInString(t) - t = t[size:] - if sr == tr { - continue - } - if 'A' <= sr && sr <= 'Z' { - sr = sr + 'a' - 'A' - } - if 'A' <= tr && tr <= 'Z' { - tr = tr + 'a' - 'A' - } - if sr != tr { - return false - } - } - return s == t -} - -// handleSocketUnauthorized sets the format close message and closes the websocket -func (b *Service) handleSocketClose(ctx context.Context, ws *websocket.Conn, closeCode int, text string) { - cm := websocket.FormatCloseMessage(closeCode, text) - if err := ws.WriteMessage(websocket.CloseMessage, cm); err != nil { - b.Logger.Ctx(ctx).Error("unauthorized close error", zap.Error(err)) - } - if err := ws.Close(); err != nil { - b.Logger.Ctx(ctx).Error("close error", zap.Error(err)) - } -} - // ServeWs handles websocket requests from the peer. func (b *Service) ServeWs() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - storyboardID := vars["storyboardId"] + return b.hub.WebSocketHandler("storyboardId", func(w http.ResponseWriter, r *http.Request, c *wshub.Connection, roomID string) *wshub.AuthError { ctx := r.Context() var User *thunderdome.User - var UserAuthed bool - - // upgrade to WebSocket connection - var upgrader = b.createWebsocketUpgrader() - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - b.Logger.Ctx(ctx).Error("websocket upgrade error", zap.Error(err), - zap.String("storyboard_id", storyboardID)) - return - } - c := &connection{config: &b.config, send: make(chan []byte, 256), ws: ws} - SessionId, cookieErr := b.ValidateSessionCookie(w, r) + SessionId, cookieErr := b.validateSessionCookie(w, r) if cookieErr != nil && cookieErr.Error() != "COOKIE_NOT_FOUND" { - b.handleSocketClose(ctx, ws, 4001, "unauthorized") - return + authErr := wshub.AuthError{ + Code: 4001, + Message: "unauthorized", + } + return &authErr } if SessionId != "" { var userErr error User, userErr = b.AuthService.GetSessionUser(ctx, SessionId) if userErr != nil { - b.handleSocketClose(ctx, ws, 4001, "unauthorized") - return + authErr := wshub.AuthError{ + Code: 4001, + Message: "unauthorized", + } + return &authErr } } else { - UserID, err := b.ValidateUserCookie(w, r) + UserID, err := b.validateUserCookie(w, r) if err != nil { - b.handleSocketClose(ctx, ws, 4001, "unauthorized") - return + authErr := wshub.AuthError{ + Code: 4001, + Message: "unauthorized", + } + return &authErr } var userErr error User, userErr = b.UserService.GetGuestUser(ctx, UserID) if userErr != nil { - b.handleSocketClose(ctx, ws, 4001, "unauthorized") - return + authErr := wshub.AuthError{ + Code: 4001, + Message: "unauthorized", + } + return &authErr } } // make sure storyboard is legit - storyboard, storyboardErr := b.StoryboardService.GetStoryboard(storyboardID, User.Id) + storyboard, storyboardErr := b.StoryboardService.GetStoryboard(roomID, User.Id) if storyboardErr != nil { - b.handleSocketClose(ctx, ws, 4004, "storyboard not found") - return + authErr := wshub.AuthError{ + Code: 4004, + Message: "storyboard not found", + } + return &authErr } // check users storyboard active status - UserErr := b.StoryboardService.GetStoryboardUserActiveStatus(storyboardID, User.Id) + UserErr := b.StoryboardService.GetStoryboardUserActiveStatus(roomID, User.Id) if UserErr != nil && !errors.Is(UserErr, sql.ErrNoRows) { usrErrMsg := UserErr.Error() + var authErr wshub.AuthError if usrErrMsg == "DUPLICATE_STORYBOARD_USER" { - b.handleSocketClose(ctx, ws, 4003, "duplicate session") + authErr = wshub.AuthError{ + Code: 4003, + Message: "duplicate session", + } } else { - b.Logger.Ctx(ctx).Error("error finding user", zap.Error(UserErr), - zap.String("storyboard_id", storyboardID), zap.String("session_user_id", User.Id)) - b.handleSocketClose(ctx, ws, 4005, "internal error") + b.logger.Ctx(ctx).Error("error finding user", zap.Error(UserErr), + zap.String("storyboard_id", roomID), zap.String("session_user_id", User.Id)) + authErr = wshub.AuthError{ + Code: 4005, + Message: "internal error", + } } - return - } - - if storyboard.JoinCode != "" && (UserErr != nil && errors.Is(UserErr, sql.ErrNoRows)) { - jcrEvent := createSocketEvent("join_code_required", "", User.Id) - _ = c.write(websocket.TextMessage, jcrEvent) + return &authErr + } else if storyboard.JoinCode != "" && (UserErr != nil && errors.Is(UserErr, sql.ErrNoRows)) { + jcrEvent := wshub.CreateSocketEvent("join_code_required", "", User.Id) + _ = c.Write(websocket.TextMessage, jcrEvent) for { - _, msg, err := c.ws.ReadMessage() + _, msg, err := c.Ws.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - b.Logger.Ctx(ctx).Error("unexpected close error", zap.Error(err), - zap.String("storyboard_id", storyboardID), zap.String("session_user_id", User.Id)) + b.logger.Ctx(ctx).Error("unexpected close error", zap.Error(err), + zap.String("storyboard_id", roomID), zap.String("session_user_id", User.Id)) } break } @@ -314,65 +108,47 @@ func (b *Service) ServeWs() http.HandlerFunc { keyVal := make(map[string]string) err = json.Unmarshal(msg, &keyVal) if err != nil { - b.Logger.Error("unexpected storyboard message error", zap.Error(err), - zap.String("storyboard_id", storyboardID), zap.String("session_user_id", User.Id)) + b.logger.Error("unexpected message error", zap.Error(err), + zap.String("retro_id", roomID), zap.String("session_user_id", User.Id)) } if keyVal["type"] == "auth_storyboard" && keyVal["value"] == storyboard.JoinCode { - UserAuthed = true + // join code is valid, continue to room break } else if keyVal["type"] == "auth_storyboard" { - authIncorrect := createSocketEvent("join_code_incorrect", "", User.Id) - _ = c.write(websocket.TextMessage, authIncorrect) + authIncorrect := wshub.CreateSocketEvent("join_code_incorrect", "", User.Id) + _ = c.Write(websocket.TextMessage, authIncorrect) } } - } else { - UserAuthed = true } - if UserAuthed { - ss := subscription{&b.config, c, storyboardID, User.Id} - h.register <- ss + sub := b.hub.NewSubscriber(c.Ws, User.Id, roomID) - Users, _ := b.StoryboardService.AddUserToStoryboard(ss.arena, User.Id) - UpdatedUsers, _ := json.Marshal(Users) + Users, _ := b.StoryboardService.AddUserToStoryboard(roomID, User.Id) + UpdatedUsers, _ := json.Marshal(Users) - Storyboard, _ := json.Marshal(storyboard) - initEvent := createSocketEvent("init", string(Storyboard), User.Id) - _ = c.write(websocket.TextMessage, initEvent) + Storyboard, _ := json.Marshal(storyboard) + initEvent := wshub.CreateSocketEvent("init", string(Storyboard), User.Id) + _ = sub.Conn.Write(websocket.TextMessage, initEvent) - joinedEvent := createSocketEvent("user_joined", string(UpdatedUsers), User.Id) - m := message{joinedEvent, ss.arena} - h.broadcast <- m + userJoinedEvent := wshub.CreateSocketEvent("user_joined", string(UpdatedUsers), User.Id) + b.hub.Broadcast(wshub.Message{Data: userJoinedEvent, Room: roomID}) - go ss.writePump() - go ss.readPump(b, ctx) - } - } -} + go sub.WritePump() + go sub.ReadPump(ctx, b.hub) -// APIEvent handles api driven events into the arena (if active) -func (b *Service) APIEvent(ctx context.Context, arenaID string, UserID, eventType string, eventValue string) error { - // confirm leader for any operation that requires it - if _, ok := ownerOnlyOperations[eventType]; ok { - err := b.StoryboardService.ConfirmStoryboardFacilitator(arenaID, UserID) - if err != nil { - return err - } - } + return nil + }) +} - // find event handler and execute otherwise invalid event - if _, ok := b.EventHandlers[eventType]; ok { - msg, eventErr, _ := b.EventHandlers[eventType](ctx, arenaID, UserID, eventValue) - if eventErr != nil { - return eventErr - } +func (b *Service) RetreatUser(roomID string, userID string) string { + Users := b.StoryboardService.RetreatStoryboardUser(roomID, userID) + UpdatedUsers, _ := json.Marshal(Users) - if _, ok := h.arenas[arenaID]; ok { - m := message{msg, arenaID} - h.broadcast <- m - } - } + return string(UpdatedUsers) +} - return nil +// APIEvent handles api driven events into the storyboard (if active) +func (b *Service) APIEvent(ctx context.Context, storyboardID string, UserID, eventType string, eventValue string) error { + return b.hub.ProcessAPIEventHandler(ctx, UserID, storyboardID, eventType, eventValue) } diff --git a/internal/http/storyboard/events.go b/internal/http/storyboard/events.go index 457c96ea..a3c9aa24 100644 --- a/internal/http/storyboard/events.go +++ b/internal/http/storyboard/events.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "errors" + + "github.com/StevenWeathers/thunderdome-planning-poker/internal/wshub" ) // AddGoal handles adding a goal to storyboard @@ -13,7 +15,7 @@ func (b *Service) AddGoal(ctx context.Context, StoryboardID string, UserID strin return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("goal_added", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("goal_added", string(updatedGoals), "") return msg, nil, false } @@ -33,7 +35,7 @@ func (b *Service) ReviseGoal(ctx context.Context, StoryboardID string, UserID st return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("goal_revised", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("goal_revised", string(updatedGoals), "") return msg, nil, false } @@ -45,7 +47,7 @@ func (b *Service) DeleteGoal(ctx context.Context, StoryboardID string, UserID st return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("goal_deleted", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("goal_deleted", string(updatedGoals), "") return msg, nil, false } @@ -64,7 +66,7 @@ func (b *Service) AddColumn(ctx context.Context, StoryboardID string, UserID str return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("column_added", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("column_added", string(updatedGoals), "") return msg, nil, false } @@ -85,7 +87,7 @@ func (b *Service) ReviseColumn(ctx context.Context, StoryboardID string, UserID return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("column_updated", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("column_updated", string(updatedGoals), "") return msg, nil, false } @@ -97,7 +99,7 @@ func (b *Service) DeleteColumn(ctx context.Context, StoryboardID string, UserID return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("story_deleted", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("story_deleted", string(updatedGoals), "") return msg, nil, false } @@ -118,7 +120,7 @@ func (b *Service) ColumnPersonaAdd(ctx context.Context, StoryboardID string, Use return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("column_updated", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("column_updated", string(updatedGoals), "") return msg, nil, false } @@ -139,7 +141,7 @@ func (b *Service) ColumnPersonaRemove(ctx context.Context, StoryboardID string, return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("column_updated", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("column_updated", string(updatedGoals), "") return msg, nil, false } @@ -159,7 +161,7 @@ func (b *Service) AddStory(ctx context.Context, StoryboardID string, UserID stri return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("story_added", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("story_added", string(updatedGoals), "") return msg, nil, false } @@ -179,7 +181,7 @@ func (b *Service) UpdateStoryName(ctx context.Context, StoryboardID string, User return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("story_updated", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("story_updated", string(updatedGoals), "") return msg, nil, false } @@ -199,7 +201,7 @@ func (b *Service) UpdateStoryContent(ctx context.Context, StoryboardID string, U return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("story_updated", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("story_updated", string(updatedGoals), "") return msg, nil, false } @@ -219,7 +221,7 @@ func (b *Service) UpdateStoryColor(ctx context.Context, StoryboardID string, Use return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("story_updated", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("story_updated", string(updatedGoals), "") return msg, nil, false } @@ -240,7 +242,7 @@ func (b *Service) UpdateStoryPoints(ctx context.Context, StoryboardID string, Us return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("story_updated", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("story_updated", string(updatedGoals), "") return msg, nil, false } @@ -261,7 +263,7 @@ func (b *Service) UpdateStoryClosed(ctx context.Context, StoryboardID string, Us return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("story_updated", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("story_updated", string(updatedGoals), "") return msg, nil, false } @@ -281,7 +283,7 @@ func (b *Service) UpdateStoryLink(ctx context.Context, StoryboardID string, User return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("story_updated", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("story_updated", string(updatedGoals), "") return msg, nil, false } @@ -303,7 +305,7 @@ func (b *Service) MoveStory(ctx context.Context, StoryboardID string, UserID str return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("story_moved", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("story_moved", string(updatedGoals), "") return msg, nil, false } @@ -315,7 +317,7 @@ func (b *Service) DeleteStory(ctx context.Context, StoryboardID string, UserID s return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("story_deleted", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("story_deleted", string(updatedGoals), "") return msg, nil, false } @@ -336,7 +338,7 @@ func (b *Service) AddStoryComment(ctx context.Context, StoryboardID string, User return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("story_updated", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("story_updated", string(updatedGoals), "") return msg, nil, false } @@ -357,7 +359,7 @@ func (b *Service) EditStoryComment(ctx context.Context, StoryboardID string, Use return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("story_updated", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("story_updated", string(updatedGoals), "") return msg, nil, false } @@ -377,7 +379,7 @@ func (b *Service) DeleteStoryComment(ctx context.Context, StoryboardID string, U return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("story_updated", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("story_updated", string(updatedGoals), "") return msg, nil, false } @@ -399,7 +401,7 @@ func (b *Service) AddPersona(ctx context.Context, StoryboardID string, UserID st return nil, err, false } updatedPersonas, _ := json.Marshal(personas) - msg := createSocketEvent("personas_updated", string(updatedPersonas), "") + msg := wshub.CreateSocketEvent("personas_updated", string(updatedPersonas), "") return msg, nil, false } @@ -422,7 +424,7 @@ func (b *Service) UpdatePersona(ctx context.Context, StoryboardID string, UserID return nil, err, false } updatedPersonas, _ := json.Marshal(personas) - msg := createSocketEvent("personas_updated", string(updatedPersonas), "") + msg := wshub.CreateSocketEvent("personas_updated", string(updatedPersonas), "") return msg, nil, false } @@ -434,7 +436,7 @@ func (b *Service) DeletePersona(ctx context.Context, StoryboardID string, UserID return nil, err, false } updatedGoals, _ := json.Marshal(goals) - msg := createSocketEvent("personas_updated", string(updatedGoals), "") + msg := wshub.CreateSocketEvent("personas_updated", string(updatedGoals), "") return msg, nil, false } @@ -454,7 +456,7 @@ func (b *Service) FacilitatorAdd(ctx context.Context, StoryboardID string, UserI return nil, err, false } updatedStoryboard, _ := json.Marshal(storyboard) - msg := createSocketEvent("storyboard_updated", string(updatedStoryboard), "") + msg := wshub.CreateSocketEvent("storyboard_updated", string(updatedStoryboard), "") return msg, nil, false } @@ -474,7 +476,7 @@ func (b *Service) FacilitatorRemove(ctx context.Context, StoryboardID string, Us return nil, err, false } updatedStoryboard, _ := json.Marshal(storyboard) - msg := createSocketEvent("storyboard_updated", string(updatedStoryboard), "") + msg := wshub.CreateSocketEvent("storyboard_updated", string(updatedStoryboard), "") return msg, nil, false } @@ -493,7 +495,7 @@ func (b *Service) FacilitatorSelf(ctx context.Context, StoryboardID string, User } updatedStoryboard, _ := json.Marshal(storyboard) - msg := createSocketEvent("storyboard_updated", string(updatedStoryboard), "") + msg := wshub.CreateSocketEvent("storyboard_updated", string(updatedStoryboard), "") return msg, nil, false } else { @@ -508,7 +510,7 @@ func (b *Service) ReviseColorLegend(ctx context.Context, StoryboardID string, Us return nil, err, false } updatedStoryboard, _ := json.Marshal(storyboard) - msg := createSocketEvent("storyboard_updated", string(updatedStoryboard), "") + msg := wshub.CreateSocketEvent("storyboard_updated", string(updatedStoryboard), "") return msg, nil, false } @@ -536,7 +538,7 @@ func (b *Service) EditStoryboard(ctx context.Context, StoryboardID string, UserI } updatedStoryboard, _ := json.Marshal(rb) - msg := createSocketEvent("storyboard_edited", string(updatedStoryboard), "") + msg := wshub.CreateSocketEvent("storyboard_edited", string(updatedStoryboard), "") return msg, nil, false } @@ -547,7 +549,7 @@ func (b *Service) Delete(ctx context.Context, StoryboardID string, UserID string if err != nil { return nil, err, false } - msg := createSocketEvent("storyboard_conceded", "", "") + msg := wshub.CreateSocketEvent("storyboard_conceded", "", "") return msg, nil, false } @@ -561,22 +563,3 @@ func (b *Service) Abandon(ctx context.Context, StoryboardID string, UserID strin return nil, errors.New("ABANDONED_STORYBOARD"), true } - -// socketEvent is the event structure used for socket messages -type socketEvent struct { - Type string `json:"type"` - Value string `json:"value"` - User string `json:"userId"` -} - -func createSocketEvent(Type string, Value string, User string) []byte { - newEvent := &socketEvent{ - Type: Type, - Value: Value, - User: User, - } - - event, _ := json.Marshal(newEvent) - - return event -} diff --git a/internal/http/storyboard/hub.go b/internal/http/storyboard/hub.go deleted file mode 100644 index 5ad692ee..00000000 --- a/internal/http/storyboard/hub.go +++ /dev/null @@ -1,74 +0,0 @@ -package storyboard - -type message struct { - data []byte - arena string -} - -type subscription struct { - config *Config - conn *connection - arena string - UserID string -} - -// hub maintains the set of active connections and broadcasts messages to the -// connections. -type hub struct { - // Registered connections. - arenas map[string]map[*connection]struct{} - - // Inbound messages from the connections. - broadcast chan message - - // Register requests from the connections. - register chan subscription - - // Unregister requests from connections. - unregister chan subscription -} - -var h = hub{ - broadcast: make(chan message), - register: make(chan subscription), - unregister: make(chan subscription), - arenas: make(map[string]map[*connection]struct{}), -} - -func (h *hub) run() { - for { - select { - case a := <-h.register: - connections := h.arenas[a.arena] - if connections == nil { - connections = make(map[*connection]struct{}) - h.arenas[a.arena] = connections - } - h.arenas[a.arena][a.conn] = struct{}{} - case a := <-h.unregister: - connections := h.arenas[a.arena] - if connections != nil { - if _, ok := connections[a.conn]; ok { - delete(connections, a.conn) - close(a.conn.send) - if len(connections) == 0 { - delete(h.arenas, a.arena) - } - } - } - case m := <-h.broadcast: - connections := h.arenas[m.arena] - for c := range connections { - select { - case c.send <- m.data: - default: - close(c.send) - delete(connections, c) - if len(connections) == 0 { - delete(h.arenas, m.arena) - } - } - } - } - } -} diff --git a/internal/http/storyboard/storyboard.go b/internal/http/storyboard/storyboard.go index 708fc19c..a50fa660 100644 --- a/internal/http/storyboard/storyboard.go +++ b/internal/http/storyboard/storyboard.go @@ -3,7 +3,8 @@ package storyboard import ( "context" "net/http" - "time" + + "github.com/StevenWeathers/thunderdome-planning-poker/internal/wshub" "github.com/StevenWeathers/thunderdome-planning-poker/thunderdome" "github.com/uptrace/opentelemetry-go-extra/otelzap" @@ -26,18 +27,6 @@ type Config struct { WebsocketSubdomain string } -func (c *Config) WriteWait() time.Duration { - return time.Duration(c.WriteWaitSec) * time.Second -} - -func (c *Config) PingPeriod() time.Duration { - return time.Duration(c.PingPeriodSec) * time.Second -} - -func (c *Config) PongWait() time.Duration { - return time.Duration(c.PongWaitSec) * time.Second -} - type AuthDataSvc interface { GetSessionUser(ctx context.Context, SessionId string) (*thunderdome.User, error) } @@ -49,13 +38,13 @@ type UserDataSvc interface { // Service provides storyboard service type Service struct { config Config - Logger *otelzap.Logger - ValidateSessionCookie func(w http.ResponseWriter, r *http.Request) (string, error) - ValidateUserCookie func(w http.ResponseWriter, r *http.Request) (string, error) - EventHandlers map[string]func(context.Context, string, string, string) ([]byte, error, bool) + logger *otelzap.Logger + validateSessionCookie func(w http.ResponseWriter, r *http.Request) (string, error) + validateUserCookie func(w http.ResponseWriter, r *http.Request) (string, error) UserService UserDataSvc AuthService AuthDataSvc StoryboardService thunderdome.StoryboardDataSvc + hub *wshub.Hub } // New returns a new storyboard with websocket hub/client and event handlers @@ -69,15 +58,21 @@ func New( ) *Service { sb := &Service{ config: config, - Logger: logger, - ValidateSessionCookie: validateSessionCookie, - ValidateUserCookie: validateUserCookie, + logger: logger, + validateSessionCookie: validateSessionCookie, + validateUserCookie: validateUserCookie, UserService: userService, AuthService: authService, StoryboardService: storyboardService, } - sb.EventHandlers = map[string]func(context.Context, string, string, string) ([]byte, error, bool){ + sb.hub = wshub.NewHub(logger, wshub.Config{ + AppDomain: config.AppDomain, + WebsocketSubdomain: config.WebsocketSubdomain, + WriteWaitSec: config.WriteWaitSec, + PongWaitSec: config.PongWaitSec, + PingPeriodSec: config.PingPeriodSec, + }, map[string]func(context.Context, string, string, string) ([]byte, error, bool){ "add_goal": sb.AddGoal, "revise_goal": sb.ReviseGoal, "delete_goal": sb.DeleteGoal, @@ -108,9 +103,18 @@ func New( "edit_storyboard": sb.EditStoryboard, "concede_storyboard": sb.Delete, "abandon_storyboard": sb.Abandon, - } - - go h.run() + }, + map[string]struct{}{ + "facilitator_add": {}, + "facilitator_remove": {}, + "edit_storyboard": {}, + "concede_storyboard": {}, + }, + sb.StoryboardService.ConfirmStoryboardFacilitator, + sb.RetreatUser, + ) + + go sb.hub.Run() return sb } diff --git a/internal/wshub/config.go b/internal/wshub/config.go new file mode 100644 index 00000000..32ad0aa1 --- /dev/null +++ b/internal/wshub/config.go @@ -0,0 +1,40 @@ +package wshub + +import "time" + +type Config struct { + // Time allowed to write a message to the peer. + WriteWaitSec int + // Time allowed to read the next pong message from the peer. + PongWaitSec int + // Send pings to peer with this period. Must be less than pongWait. + PingPeriodSec int + // App Domain (for Websocket origin check) + AppDomain string + // Websocket Subdomain (for Websocket origin check) + WebsocketSubdomain string +} + +func (c *Config) WriteWait() time.Duration { + waitSec := c.WriteWaitSec + if waitSec <= 0 { + waitSec = 10 // prevents panic: non-positive interval for NewTicker + } + return time.Duration(waitSec) * time.Second +} + +func (c *Config) PingPeriod() time.Duration { + periodSec := c.PingPeriodSec + if periodSec <= 0 { + periodSec = 54 // prevents panic: non-positive interval for NewTicker + } + return time.Duration(periodSec) * time.Second +} + +func (c *Config) PongWait() time.Duration { + waitSec := c.PongWaitSec + if waitSec <= 0 { + waitSec = 60 // prevents panic: non-positive interval for NewTicker + } + return time.Duration(waitSec) * time.Second +} diff --git a/internal/wshub/connection.go b/internal/wshub/connection.go new file mode 100644 index 00000000..3c332c3e --- /dev/null +++ b/internal/wshub/connection.go @@ -0,0 +1,26 @@ +package wshub + +import ( + "time" + + "github.com/gorilla/websocket" +) + +type Connection struct { + // The websocket connection. + Ws *websocket.Conn + // Buffered channel of outbound messages. + send chan []byte + WriteWait time.Duration + PingPeriod time.Duration + PongWait time.Duration +} + +func (c *Connection) Send() chan<- []byte { return c.send } +func (c *Connection) Close() { c.Ws.Close() } + +// Write a message with the given message type and payload. +func (c *Connection) Write(mt int, payload []byte) error { + _ = c.Ws.SetWriteDeadline(time.Now().Add(c.WriteWait)) + return c.Ws.WriteMessage(mt, payload) +} diff --git a/internal/wshub/http.go b/internal/wshub/http.go new file mode 100644 index 00000000..b2521534 --- /dev/null +++ b/internal/wshub/http.go @@ -0,0 +1,98 @@ +package wshub + +import ( + "context" + "fmt" + "net/http" + + "github.com/gorilla/mux" + + "github.com/gorilla/websocket" + "go.uber.org/zap" +) + +// CreateWebsocketUpgrader creates a websocket.Upgrader with the given AppDomain and WebsocketSubdomain +func (h *Hub) CreateWebsocketUpgrader() websocket.Upgrader { + return websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return checkOrigin(r, h.config.AppDomain, h.config.WebsocketSubdomain) + }, + } +} + +// HandleSocketClose sets the format close message and closes the websocket +func (h *Hub) HandleSocketClose(ctx context.Context, ws *websocket.Conn, closeCode int, text string) { + cm := websocket.FormatCloseMessage(closeCode, text) + if err := ws.WriteMessage(websocket.CloseMessage, cm); err != nil { + h.logger.Ctx(ctx).Error("unauthorized close error", zap.Error(err)) + } + if err := ws.Close(); err != nil { + h.logger.Ctx(ctx).Error("close error", zap.Error(err)) + } +} + +type AuthError struct { + Code int + Message string +} + +func (e *AuthError) Error() string { + return fmt.Sprintf(e.Message) +} + +// WebSocketHandler creates a http.HandlerFunc for handling WebSocket connections +func (h *Hub) WebSocketHandler( + roomIDVar string, + authFunc func(w http.ResponseWriter, r *http.Request, c *Connection, roomID string) *AuthError, +) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + vars := mux.Vars(r) + RoomID := vars[roomIDVar] + + // upgrade to WebSocket connection + var upgrader = h.CreateWebsocketUpgrader() + ws, err := upgrader.Upgrade(w, r, nil) + if err != nil { + h.logger.Ctx(ctx).Error("websocket upgrade error", zap.Error(err), + zap.String("room_id", RoomID)) + return + } + c := h.NewConnection(ws) + + authErr := authFunc(w, r, &c, RoomID) + if authErr != nil { + h.HandleSocketClose(ctx, c.Ws, authErr.Code, authErr.Error()) + return + } + } +} + +// ProcessAPIEventHandler processes an event from the API through the websocket hub. +func (h *Hub) ProcessAPIEventHandler(ctx context.Context, userID, roomID, eventType string, eventValue string) error { + // find event handler and execute otherwise invalid event + if _, ok := h.eventHandlers[eventType]; ok { + // confirm leader for any operation that requires it + if h.confirmFacilitator != nil { + if _, ok := h.facilitatorOnlyOperations[eventType]; ok { + err := h.confirmFacilitator(roomID, userID) + if err != nil { + return err + } + } + } + + msg, eventErr, _ := h.eventHandlers[eventType](ctx, roomID, userID, eventValue) + if eventErr != nil { + return eventErr + } + + if h.RoomExists(roomID) { + h.Broadcast(Message{Data: msg, Room: roomID}) + } + } + + return nil +} diff --git a/internal/wshub/http_test.go b/internal/wshub/http_test.go new file mode 100644 index 00000000..6a44ce52 --- /dev/null +++ b/internal/wshub/http_test.go @@ -0,0 +1,107 @@ +package wshub + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/uptrace/opentelemetry-go-extra/otelzap" + "go.uber.org/zap" +) + +// TestHandleSocketClose tests the handling of socket closure +func TestHandleSocketClose(t *testing.T) { + hub := NewHub(otelzap.New(zap.NewNop()), Config{}, nil, nil, nil, nil) + + // Create channels for synchronization + serverReady := make(chan struct{}) + clientClosed := make(chan struct{}) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{} + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Fatalf("Failed to upgrade connection: %v", err) + } + defer conn.Close() + + // Signal that the server is ready + close(serverReady) + + // Wait for the client to close the connection + <-clientClosed + + hub.HandleSocketClose(context.Background(), conn, websocket.CloseNormalClosure, "test close") + })) + defer server.Close() + + url := "ws" + strings.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(url, nil) + if err != nil { + t.Fatalf("Failed to connect to WebSocket server: %v", err) + } + + // Wait for the server to be ready + <-serverReady + + // Close the connection from the client side + err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "client closing")) + if err != nil { + t.Fatalf("Failed to send close message: %v", err) + } + + // Signal that the client has closed the connection + close(clientClosed) + + // Wait a short time for the server to process the close + time.Sleep(100 * time.Millisecond) + + // Attempt to read from the closed connection + _, _, err = conn.ReadMessage() + + // Check for the expected error + if err == nil { + t.Fatal("Expected an error, but got nil") + } + + // The error message might vary, so we'll check for both possible outcomes + expectedErrors := []string{ + "websocket: close 1000 (normal)", + "use of closed network connection", + } + + errorMatched := false + for _, expectedError := range expectedErrors { + if strings.Contains(err.Error(), expectedError) { + errorMatched = true + break + } + } + + assert.True(t, errorMatched, "Error should be one of the expected closure messages") +} + +// TestWebSocketHandler tests the WebSocket handler +func TestWebSocketHandler(t *testing.T) { + hub := NewHub(otelzap.New(zap.NewNop()), Config{}, nil, nil, nil, nil) + + authFunc := func(w http.ResponseWriter, r *http.Request, c *Connection, roomID string) *AuthError { + return nil + } + + handler := hub.WebSocketHandler("roomID", authFunc) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler.ServeHTTP(w, r) + })) + defer server.Close() + + url := "ws" + strings.TrimPrefix(server.URL, "http") + _, _, err := websocket.DefaultDialer.Dial(url, nil) + assert.NoError(t, err) +} diff --git a/internal/wshub/hub.go b/internal/wshub/hub.go new file mode 100644 index 00000000..dd3278d3 --- /dev/null +++ b/internal/wshub/hub.go @@ -0,0 +1,152 @@ +package wshub + +import ( + "context" + + "github.com/gorilla/websocket" + "github.com/uptrace/opentelemetry-go-extra/otelzap" +) + +const ( + // Maximum message size allowed from peer. + maxMessageSize = 1024 * 1024 +) + +// Message represents a message sent to the websocket hub. +type Message struct { + Data []byte `json:"data"` + Room string `json:"room"` +} + +type roomExistsRequest struct { + room string + response chan bool +} + +// Hub maintains the set of active connections and broadcasts messages to the connections. +type Hub struct { + rooms map[string]map[Connection]struct{} + broadcast chan Message + register chan Subscription + unregister chan Subscription + roomExists chan roomExistsRequest + logger *otelzap.Logger + config *Config + eventHandlers map[string]func(context.Context, string, string, string) ([]byte, error, bool) + facilitatorOnlyOperations map[string]struct{} + confirmFacilitator func(roomId string, userId string) error + retreatUser func(roomId string, userId string) string +} + +// NewHub creates a new websocket hub. +func NewHub( + logger *otelzap.Logger, + config Config, + eventHandlers map[string]func(context.Context, string, string, string) ([]byte, error, bool), + facilitatorOnlyOperations map[string]struct{}, + confirmFacilitator func(roomId string, userId string) error, + retreatUser func(roomId string, userId string) string, +) *Hub { + return &Hub{ + broadcast: make(chan Message), + register: make(chan Subscription), + unregister: make(chan Subscription), + rooms: make(map[string]map[Connection]struct{}), + roomExists: make(chan roomExistsRequest), + logger: logger, + config: &config, + eventHandlers: eventHandlers, + facilitatorOnlyOperations: facilitatorOnlyOperations, + confirmFacilitator: confirmFacilitator, + retreatUser: retreatUser, + } +} + +// Run starts the hub. +func (h *Hub) Run() { + for { + select { + case sub := <-h.register: + if _, ok := h.rooms[sub.RoomID]; !ok { + h.rooms[sub.RoomID] = make(map[Connection]struct{}) + } + h.rooms[sub.RoomID][sub.Conn] = struct{}{} + + case sub := <-h.unregister: + if _, ok := h.rooms[sub.RoomID]; ok { + if _, ok := h.rooms[sub.RoomID][sub.Conn]; ok { + delete(h.rooms[sub.RoomID], sub.Conn) + sub.Conn.Close() + if len(h.rooms[sub.RoomID]) == 0 { + delete(h.rooms, sub.RoomID) + } + } + } + + case m := <-h.broadcast: + if connections, ok := h.rooms[m.Room]; ok { + for conn := range connections { + select { + case conn.Send() <- m.Data: + default: + close(conn.Send()) + delete(connections, conn) + if len(connections) == 0 { + delete(h.rooms, m.Room) + } + } + } + } + + case req := <-h.roomExists: + _, exists := h.rooms[req.room] + req.response <- exists + } + } +} + +// Register adds a subscription to the room. +func (h *Hub) Register(sub Subscription) { + h.register <- sub +} + +// Unregister removes a subscription from the room. +func (h *Hub) Unregister(sub Subscription) { + h.unregister <- sub +} + +// Broadcast sends a message to all connections in the room. +func (h *Hub) Broadcast(msg Message) { + h.broadcast <- msg +} + +// RoomExists checks if a room exists in the hub. +func (h *Hub) RoomExists(room string) bool { + response := make(chan bool) + h.roomExists <- roomExistsRequest{room: room, response: response} + return <-response +} + +// NewConnection creates a new websocket connection. +func (h *Hub) NewConnection(ws *websocket.Conn) Connection { + return Connection{ + send: make(chan []byte, 256), + Ws: ws, + PingPeriod: h.config.PingPeriod(), + WriteWait: h.config.WriteWait(), + PongWait: h.config.PongWait(), + } +} + +// NewSubscriber creates a new subscription to the room for the given websocket connection. +func (h *Hub) NewSubscriber(ws *websocket.Conn, userID string, roomID string) Subscription { + sub := Subscription{ + Conn: h.NewConnection(ws), + RoomID: roomID, + UserID: userID, + } + + h.Register(sub) + + return sub +} diff --git a/internal/wshub/hub_test.go b/internal/wshub/hub_test.go new file mode 100644 index 00000000..587026af --- /dev/null +++ b/internal/wshub/hub_test.go @@ -0,0 +1,46 @@ +package wshub + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/uptrace/opentelemetry-go-extra/otelzap" + "go.uber.org/zap" +) + +// TestNewHub tests the creation of a new Hub +func TestNewHub(t *testing.T) { + config := Config{ + WriteWaitSec: 10, + PongWaitSec: 60, + PingPeriodSec: 54, + AppDomain: "example.com", + WebsocketSubdomain: "ws", + } + eventHandlers := make(map[string]func(context.Context, string, string, string) ([]byte, error, bool)) + facilitatorOnlyOperations := make(map[string]struct{}) + confirmFacilitator := func(roomId string, userId string) error { return nil } + retreatUser := func(roomId string, userId string) string { return "" } + + hub := NewHub(otelzap.New(zap.NewNop()), config, eventHandlers, facilitatorOnlyOperations, confirmFacilitator, retreatUser) + + assert.NotNil(t, hub) + assert.Equal(t, &config, hub.config) + //assert.Equal(t, logger, hub.logger) + assert.NotNil(t, hub.rooms) + assert.NotNil(t, hub.broadcast) + assert.NotNil(t, hub.register) + assert.NotNil(t, hub.unregister) +} + +// TestCreateWebsocketUpgrader tests the creation of a websocket upgrader +func TestCreateWebsocketUpgrader(t *testing.T) { + hub := NewHub(otelzap.New(zap.NewNop()), Config{AppDomain: "example.com", WebsocketSubdomain: "ws"}, nil, nil, nil, nil) + upgrader := hub.CreateWebsocketUpgrader() + + assert.NotNil(t, upgrader) + assert.Equal(t, 1024, upgrader.ReadBufferSize) + assert.Equal(t, 1024, upgrader.WriteBufferSize) + assert.NotNil(t, upgrader.CheckOrigin) +} diff --git a/internal/wshub/subscription.go b/internal/wshub/subscription.go new file mode 100644 index 00000000..0a032f76 --- /dev/null +++ b/internal/wshub/subscription.go @@ -0,0 +1,140 @@ +package wshub + +import ( + "context" + "encoding/json" + "time" + + "github.com/gorilla/websocket" + "go.uber.org/zap" +) + +type Subscription struct { + Conn Connection + RoomID string + UserID string +} + +// WritePump pumps messages from the Hub to the websocket connection. +func (s *Subscription) WritePump() { + ticker := time.NewTicker(s.Conn.PingPeriod) + defer func() { + ticker.Stop() + _ = s.Conn.Ws.Close() + }() + for { + select { + case message, ok := <-s.Conn.send: + if !ok { + _ = s.Conn.Write(websocket.CloseMessage, []byte{}) + return + } + if err := s.Conn.Write(websocket.TextMessage, message); err != nil { + return + } + case <-ticker.C: + if err := s.Conn.Write(websocket.PingMessage, nil); err != nil { + return + } + } + } +} + +func (s *Subscription) ReadPump( + ctx context.Context, + hub *Hub, +) { + ctx = context.WithoutCancel(ctx) + forceClosed := false + defer func() { + var UpdatedUsers string + if hub.retreatUser != nil { + UpdatedUsers = hub.retreatUser(s.RoomID, s.UserID) + } + + if forceClosed { + cm := websocket.FormatCloseMessage(4002, "abandoned") + if err := s.Conn.Ws.WriteControl(websocket.CloseMessage, cm, time.Now().Add(s.Conn.WriteWait)); err != nil { + hub.logger.Ctx(ctx).Error("abandon error", zap.Error(err), + zap.String("room_id", s.RoomID), zap.String("session_user_id", s.UserID)) + } + } + _ = s.Conn.Ws.Close() // close connection, don't care about error in attempting to close unclosed connection + + hub.Unregister(*s) + + if hub.retreatUser != nil { + userLeaveEvent := CreateSocketEvent("user_left", UpdatedUsers, s.UserID) + if hub.RoomExists(s.RoomID) { + hub.Broadcast(Message{Data: userLeaveEvent, Room: s.RoomID}) + } + } + }() + + s.Conn.Ws.SetReadLimit(maxMessageSize) + _ = s.Conn.Ws.SetReadDeadline(time.Now().Add(s.Conn.PongWait)) + s.Conn.Ws.SetPongHandler(func(string) error { + _ = s.Conn.Ws.SetReadDeadline(time.Now().Add(s.Conn.PongWait)) + return nil + }) + + // Read messages from the websocket connection. + for { + var badEvent bool + var eventErr error + _, msg, err := s.Conn.Ws.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError( + err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, + ) { + hub.logger.Ctx(ctx).Error("unexpected close error", zap.Error(err), + zap.String("room_id", s.RoomID), zap.String("session_user_id", s.UserID)) + } + break + } + + keyVal := make(map[string]string) + err = json.Unmarshal(msg, &keyVal) + if err != nil { + badEvent = true + hub.logger.Error("unexpected room event json error", zap.Error(err), + zap.String("room_id", s.RoomID), zap.String("session_user_id", s.UserID)) + } + + eventType := keyVal["type"] + eventValue := keyVal["value"] + + // confirm leader for any operation that requires it (if the room requires) + if hub.confirmFacilitator != nil { + if _, ok := hub.facilitatorOnlyOperations[eventType]; ok && !badEvent { + err := hub.confirmFacilitator(s.RoomID, s.UserID) + if err != nil { + badEvent = true + } + } + } + + // find event handler and execute otherwise invalid event + if _, ok := hub.eventHandlers[eventType]; ok && !badEvent { + msg, eventErr, forceClosed = hub.eventHandlers[eventType](ctx, s.RoomID, s.UserID, eventValue) + if eventErr != nil { + badEvent = true + + // don't log forceClosed events e.g. Abandon + if !forceClosed { + hub.logger.Ctx(ctx).Error("close error", zap.Error(eventErr), + zap.String("room_id", s.RoomID), zap.String("session_user_id", s.UserID), + zap.String("room_event_type", eventType)) + } + } + } + + if !badEvent && hub.RoomExists(s.RoomID) { + hub.Broadcast(Message{Data: msg, Room: s.RoomID}) + } + + if forceClosed { + break + } + } +} diff --git a/internal/wshub/util.go b/internal/wshub/util.go new file mode 100644 index 00000000..5b4443dc --- /dev/null +++ b/internal/wshub/util.go @@ -0,0 +1,69 @@ +package wshub + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "unicode/utf8" +) + +// EqualASCIIFold returns true if s is equal to t with ASCII case folding as +// defined in RFC 4790. +// Taken from Gorilla Websocket, https://github.com/gorilla/websocket/blob/main/util.go +func equalASCIIFold(s, t string) bool { + for s != "" && t != "" { + sr, size := utf8.DecodeRuneInString(s) + s = s[size:] + tr, size := utf8.DecodeRuneInString(t) + t = t[size:] + if sr == tr { + continue + } + if 'A' <= sr && sr <= 'Z' { + sr = sr + 'a' - 'A' + } + if 'A' <= tr && tr <= 'Z' { + tr = tr + 'a' - 'A' + } + if sr != tr { + return false + } + } + return s == t +} + +func checkOrigin(r *http.Request, appDomain string, subDomain string) bool { + origin := r.Header.Get("Origin") + if len(origin) == 0 { + return true + } + originUrl, err := url.Parse(origin) + if err != nil { + return false + } + appDomainCheck := equalASCIIFold(originUrl.Host, appDomain) + subDomainCheck := equalASCIIFold(originUrl.Host, fmt.Sprintf("%s.%s", subDomain, appDomain)) + hostCheck := equalASCIIFold(originUrl.Host, r.Host) + + return appDomainCheck || subDomainCheck || hostCheck +} + +// SocketEvent is the event structure used for socket messages +type SocketEvent struct { + Type string `json:"type"` + Value string `json:"value"` + UserID string `json:"userId"` +} + +func CreateSocketEvent(Type string, Value string, UserID string) []byte { + newEvent := &SocketEvent{ + Type: Type, + Value: Value, + UserID: UserID, + } + + event, _ := json.Marshal(newEvent) + + return event +} diff --git a/internal/wshub/util_test.go b/internal/wshub/util_test.go new file mode 100644 index 00000000..ef723544 --- /dev/null +++ b/internal/wshub/util_test.go @@ -0,0 +1,41 @@ +package wshub + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestCreateSocketEvent tests the creation of socket events +func TestCreateSocketEvent(t *testing.T) { + event := CreateSocketEvent("test_type", "test_value", "user1") + + var socketEvent SocketEvent + err := json.Unmarshal(event, &socketEvent) + assert.NoError(t, err) + assert.Equal(t, "test_type", socketEvent.Type) + assert.Equal(t, "test_value", socketEvent.Value) + assert.Equal(t, "user1", socketEvent.UserID) +} + +// TestEqualASCIIFold tests the ASCII folding comparison +func TestEqualASCIIFold(t *testing.T) { + assert.True(t, equalASCIIFold("test", "TEST")) + assert.True(t, equalASCIIFold("Test", "tEST")) + assert.False(t, equalASCIIFold("test", "test1")) +} + +// TestCheckOrigin tests the origin checking function +func TestCheckOrigin(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.com", nil) + r.Header.Set("Origin", "http://example.com") + assert.True(t, checkOrigin(r, "example.com", "ws")) + + r.Header.Set("Origin", "http://ws.example.com") + assert.True(t, checkOrigin(r, "example.com", "ws")) + + r.Header.Set("Origin", "http://other.com") + assert.False(t, checkOrigin(r, "example.com", "ws")) +} diff --git a/ui/src/pages/poker/PokerGame.svelte b/ui/src/pages/poker/PokerGame.svelte index f8c73c8d..984a84b2 100644 --- a/ui/src/pages/poker/PokerGame.svelte +++ b/ui/src/pages/poker/PokerGame.svelte @@ -59,7 +59,7 @@ let socketReconnecting: boolean = false; let points: Array = ['1', '2', '3', '5', '8', '13', '?']; let vote: string = ''; - let battle: PokerGame = { + let pokerGame: PokerGame = { leaders: [], autoFinishVoting: false, createdDate: undefined, @@ -75,8 +75,8 @@ teamId: '', }; let currentStory = { ...defaultStory }; - let showEditBattle: boolean = false; - let showDeleteBattle: boolean = false; + let showEditGame: boolean = false; + let showDeleteGame: boolean = false; let isSpectator: boolean = false; let voteStartTime: Date = new Date(); @@ -93,15 +93,15 @@ break; case 'init': { JoinPassRequired = false; - battle = JSON.parse(parsedEvent.value); - points = battle.pointValuesAllowed; + pokerGame = JSON.parse(parsedEvent.value); + points = pokerGame.pointValuesAllowed; const { spectator = false } = - battle.users.find(w => w.id === $user.id) || {}; + pokerGame.users.find(w => w.id === $user.id) || {}; isSpectator = spectator; - if (battle.activePlanId !== '') { - const activePlan = battle.plans.find( - p => p.id === battle.activePlanId, + if (pokerGame.activePlanId !== '') { + const activePlan = pokerGame.plans.find( + p => p.id === pokerGame.activePlanId, ); const warriorVote = activePlan.votes.find( v => v.warriorId === $user.id, @@ -116,10 +116,10 @@ eventTag('join', 'battle', ''); break; } - case 'warrior_joined': { - battle.users = JSON.parse(parsedEvent.value); - const joinedWarrior = battle.users.find( - w => w.id === parsedEvent.warriorId, + case 'user_joined': { + pokerGame.users = JSON.parse(parsedEvent.value); + const joinedWarrior = pokerGame.users.find( + w => w.id === parsedEvent.userId, ); if (joinedWarrior.id === $user.id) { isSpectator = joinedWarrior.spectator; @@ -134,11 +134,11 @@ } break; } - case 'warrior_retreated': - const leftWarrior = battle.users.find( - w => w.id === parsedEvent.warriorId, + case 'user_left': + const leftWarrior = pokerGame.users.find( + w => w.id === parsedEvent.userId, ); - battle.users = JSON.parse(parsedEvent.value); + pokerGame.users = JSON.parse(parsedEvent.value); if ($user.notificationsEnabled) { notifications.danger( @@ -150,15 +150,15 @@ } break; case 'users_updated': - battle.users = JSON.parse(parsedEvent.value); - const updatedWarrior = battle.users.find(w => w.id === $user.id); + pokerGame.users = JSON.parse(parsedEvent.value); + const updatedWarrior = pokerGame.users.find(w => w.id === $user.id); isSpectator = updatedWarrior.spectator; break; case 'plan_added': - battle.plans = JSON.parse(parsedEvent.value); + pokerGame.plans = JSON.parse(parsedEvent.value); break; case 'story_arranged': - battle.plans = JSON.parse(parsedEvent.value); + pokerGame.plans = JSON.parse(parsedEvent.value); break; case 'plan_activated': const updatedPlans = JSON.parse(parsedEvent.value); @@ -166,25 +166,25 @@ currentStory = activePlan; voteStartTime = new Date(activePlan.voteStartTime); - battle.plans = updatedPlans; - battle.activePlanId = activePlan.id; - battle.votingLocked = false; + pokerGame.plans = updatedPlans; + pokerGame.activePlanId = activePlan.id; + pokerGame.votingLocked = false; vote = ''; break; case 'plan_skipped': const updatedPlans2 = JSON.parse(parsedEvent.value); currentStory = { ...defaultStory }; - battle.plans = updatedPlans2; - battle.activePlanId = ''; - battle.votingLocked = true; + pokerGame.plans = updatedPlans2; + pokerGame.activePlanId = ''; + pokerGame.votingLocked = true; vote = ''; if ($user.notificationsEnabled) { notifications.warning($LL.planSkipped()); } break; case 'vote_activity': - const votedWarrior = battle.users.find( - w => w.id === parsedEvent.warriorId, + const votedWarrior = pokerGame.users.find( + w => w.id === parsedEvent.userId, ); if ($user.notificationsEnabled) { notifications.success( @@ -195,11 +195,11 @@ ); } - battle.plans = JSON.parse(parsedEvent.value); + pokerGame.plans = JSON.parse(parsedEvent.value); break; case 'vote_retracted': - const devotedWarrior = battle.users.find( - w => w.id === parsedEvent.warriorId, + const devotedWarrior = pokerGame.users.find( + w => w.id === parsedEvent.userId, ); if ($user.notificationsEnabled) { notifications.warning( @@ -210,23 +210,23 @@ ); } - battle.plans = JSON.parse(parsedEvent.value); + pokerGame.plans = JSON.parse(parsedEvent.value); break; case 'voting_ended': - battle.plans = JSON.parse(parsedEvent.value); - battle.votingLocked = true; + pokerGame.plans = JSON.parse(parsedEvent.value); + pokerGame.votingLocked = true; break; case 'plan_finalized': - battle.plans = JSON.parse(parsedEvent.value); - battle.activePlanId = ''; + pokerGame.plans = JSON.parse(parsedEvent.value); + pokerGame.activePlanId = ''; currentStory = { ...defaultStory }; vote = ''; break; case 'plan_revised': - battle.plans = JSON.parse(parsedEvent.value); - if (battle.activePlanId !== '') { - const activePlan = battle.plans.find( - p => p.id === battle.activePlanId, + pokerGame.plans = JSON.parse(parsedEvent.value); + if (pokerGame.activePlanId !== '') { + const activePlan = pokerGame.plans.find( + p => p.id === pokerGame.activePlanId, ); currentStory = activePlan; } @@ -235,28 +235,29 @@ const postBurnPlans = JSON.parse(parsedEvent.value); if ( - battle.activePlanId !== '' && - postBurnPlans.filter(p => p.id === battle.activePlanId).length === 0 + pokerGame.activePlanId !== '' && + postBurnPlans.filter(p => p.id === pokerGame.activePlanId).length === + 0 ) { - battle.activePlanId = ''; + pokerGame.activePlanId = ''; currentStory = { ...defaultStory }; } - battle.plans = postBurnPlans; + pokerGame.plans = postBurnPlans; break; case 'leaders_updated': - battle.leaders = parsedEvent.value; + pokerGame.leaders = parsedEvent.value; break; case 'battle_revised': const revisedBattle = JSON.parse(parsedEvent.value); - battle.name = revisedBattle.battleName; + pokerGame.name = revisedBattle.battleName; points = revisedBattle.pointValuesAllowed; - battle.autoFinishVoting = revisedBattle.autoFinishVoting; - battle.pointAverageRounding = revisedBattle.pointAverageRounding; - battle.joinCode = revisedBattle.joinCode; - battle.hideVoterIdentity = revisedBattle.hideVoterIdentity; - battle.teamId = revisedBattle.teamId; + pokerGame.autoFinishVoting = revisedBattle.autoFinishVoting; + pokerGame.pointAverageRounding = revisedBattle.pointAverageRounding; + pokerGame.joinCode = revisedBattle.joinCode; + pokerGame.hideVoterIdentity = revisedBattle.hideVoterIdentity; + pokerGame.teamId = revisedBattle.teamId; break; case 'battle_conceded': // poker over, goodbye. @@ -264,7 +265,9 @@ router.route(appRoutes.games); break; case 'jab_warrior': - const userToNudge = battle.users.find(w => w.id === parsedEvent.value); + const userToNudge = pokerGame.users.find( + w => w.id === parsedEvent.value, + ); notifications.info( `${$LL.warriorNudgeMessage({ name: userToNudge.name, @@ -339,9 +342,9 @@ const handleVote = event => { vote = event.detail.point; const voteValue = { - planId: battle.activePlanId, + planId: pokerGame.activePlanId, voteValue: vote, - autoFinishVoting: battle.autoFinishVoting, + autoFinishVoting: pokerGame.autoFinishVoting, }; sendSocketEvent('vote', JSON.stringify(voteValue)); @@ -351,19 +354,19 @@ const handleUnvote = () => { vote = ''; - sendSocketEvent('retract_vote', battle.activePlanId); + sendSocketEvent('retract_vote', pokerGame.activePlanId); eventTag('retract_vote', 'battle', vote); }; // Determine if the warrior has voted on active Plan yet function didVote(warriorId) { if ( - battle.activePlanId === '' || - (battle.votingLocked && battle.hideVoterIdentity) + pokerGame.activePlanId === '' || + (pokerGame.votingLocked && pokerGame.hideVoterIdentity) ) { return false; } - const plan = battle.plans.find(p => p.id === battle.activePlanId); + const plan = pokerGame.plans.find(p => p.id === pokerGame.activePlanId); const voted = plan.votes.find(w => w.warriorId === warriorId); return voted !== undefined; @@ -372,13 +375,13 @@ // Determine if we are showing users vote function showVote(warriorId) { if ( - battle.hideVoterIdentity || - battle.activePlanId === '' || - battle.votingLocked === false + pokerGame.hideVoterIdentity || + pokerGame.activePlanId === '' || + pokerGame.votingLocked === false ) { return ''; } - const story = battle.plans.find(p => p.id === battle.activePlanId); + const story = pokerGame.plans.find(p => p.id === pokerGame.activePlanId); const voted = story.votes.find(w => w.warriorId === warriorId); return voted !== undefined ? voted.vote : ''; @@ -394,7 +397,9 @@ vote: '', count: 0, }; - const activePlan = battle.plans.find(p => p.id === battle.activePlanId); + const activePlan = pokerGame.plans.find( + p => p.id === pokerGame.activePlanId, + ); if (activePlan.votes.length > 0) { const reversedPoints = [...points] @@ -405,7 +410,8 @@ // build a count of each vote activePlan.votes.forEach(v => { - const voteWarrior = battle.users.find(w => w.id === v.warriorId) || {}; + const voteWarrior = + pokerGame.users.find(w => w.id === v.warriorId) || {}; const { spectator = false } = voteWarrior; if (typeof voteCounts[v.vote] !== 'undefined' && !spectator) { @@ -426,15 +432,15 @@ } $: highestVoteCount = - battle.activePlanId !== '' && battle.votingLocked === true + pokerGame.activePlanId !== '' && pokerGame.votingLocked === true ? getHighestVote() : ''; $: showVotingResults = - battle.activePlanId !== '' && battle.votingLocked === true; + pokerGame.activePlanId !== '' && pokerGame.votingLocked === true; - $: isLeader = battle.leaders.includes($user.id); + $: isLeader = pokerGame.leaders.includes($user.id); - function concedeBattle() { + function concedeGame() { eventTag('concede_battle', 'battle', '', () => { sendSocketEvent('concede_battle', ''); }); @@ -446,23 +452,23 @@ }); } - function toggleEditBattle() { - showEditBattle = !showEditBattle; + function toggleEditGame() { + showEditGame = !showEditGame; } - const toggleDeleteBattle = () => { - showDeleteBattle = !showDeleteBattle; + const toggleDeleteGame = () => { + showDeleteGame = !showDeleteGame; }; - function handleBattleEdit(revisedBattle) { + function handleGameEdit(revisedBattle) { sendSocketEvent('revise_battle', JSON.stringify(revisedBattle)); eventTag('revise_battle', 'battle', ''); - toggleEditBattle(); - battle.leaderCode = revisedBattle.leaderCode; + toggleEditGame(); + pokerGame.leaderCode = revisedBattle.leaderCode; } function authBattle(joinPasscode) { - sendSocketEvent('auth_battle', joinPasscode); + sendSocketEvent('auth_game', joinPasscode); eventTag('auth_battle', 'battle', ''); } @@ -477,7 +483,7 @@ {$LL.battle()} - {battle.name} | {$LL.appName()} @@ -519,14 +525,14 @@ class="text-gray-700 dark:text-gray-300 text-3xl font-semibold font-rajdhani leading-tight" data-testid="battle-name" > - {battle.name} + {pokerGame.name}
@@ -538,9 +544,10 @@
{:else} @@ -552,7 +559,7 @@ active="{vote === point}" on:voted="{handleVote}" on:voteRetraction="{handleUnvote}" - isLocked="{battle.votingLocked || isSpectator}" + isLocked="{pokerGame.votingLocked || isSpectator}" /> {/each} @@ -560,13 +567,13 @@ {/if} @@ -580,15 +587,15 @@ - {#each battle.users as war (war.id)} + {#each pokerGame.users as war (war.id)} {#if war.active} @@ -611,22 +618,22 @@
{#if isLeader}
{$LL.battleEdit()} {$LL.battleDelete()} @@ -647,28 +654,28 @@
- {#if showEditBattle} + {#if showEditGame} {/if} - {#if showDeleteBattle} + {#if showDeleteGame}