Skip to content

Commit a96fbf7

Browse files
authored
Merge pull request #123 from No-SilverBullet/master
fix:add read mutex in gettyWSConn(websocket) struct to prevent data race in ReadMessage()
2 parents 2769505 + 56834ab commit a96fbf7

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

connection.go

+17-5
Original file line numberDiff line numberDiff line change
@@ -489,8 +489,9 @@ func (u *gettyUDPConn) CloseConn(_ int) {
489489

490490
type gettyWSConn struct {
491491
gettyConn
492-
conn *websocket.Conn
493-
lock sync.Mutex
492+
writeLock sync.Mutex
493+
readLock sync.Mutex
494+
conn *websocket.Conn
494495
}
495496

496497
// create websocket connection
@@ -565,7 +566,7 @@ func (w *gettyWSConn) handlePong(string) error {
565566
func (w *gettyWSConn) recv() ([]byte, error) {
566567
// Pls do not set read deadline when using ReadMessage. AlexStocks 20180310
567568
// gorilla/websocket/conn.go:NextReader will always fail when got a timeout error.
568-
_, b, e := w.conn.ReadMessage() // the first return value is message type.
569+
_, b, e := w.threadSafeReadMessage() // the first return value is message type.
569570
if e == nil {
570571
w.readBytes.Add((uint32)(len(b)))
571572
} else {
@@ -641,10 +642,21 @@ func (w *gettyWSConn) CloseConn(waitSec int) {
641642

642643
// uses a mutex to ensure that only one thread can send a message at a time, preventing race conditions.
643644
func (w *gettyWSConn) threadSafeWriteMessage(messageType int, data []byte) error {
644-
w.lock.Lock()
645-
defer w.lock.Unlock()
645+
w.writeLock.Lock()
646+
defer w.writeLock.Unlock()
646647
if err := w.conn.WriteMessage(messageType, data); err != nil {
647648
return err
648649
}
649650
return nil
650651
}
652+
653+
// uses a mutex to ensure that only one thread can read a message at a time, preventing race conditions.
654+
func (w *gettyWSConn) threadSafeReadMessage() (int, []byte, error) {
655+
w.readLock.Lock()
656+
defer w.readLock.Unlock()
657+
messageType, readBytes, err := w.conn.ReadMessage()
658+
if err != nil {
659+
return messageType, nil, err
660+
}
661+
return messageType, readBytes, nil
662+
}

0 commit comments

Comments
 (0)