Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix issue120 and issue103: add mutex to prevent data race in websocket write message and use named return values for WritePkg method #121

Merged
merged 8 commits into from
May 28, 2024
19 changes: 15 additions & 4 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ func (u *gettyUDPConn) CloseConn(_ int) {
type gettyWSConn struct {
gettyConn
conn *websocket.Conn
lock sync.Mutex
}

// create websocket connection
Expand Down Expand Up @@ -608,7 +609,7 @@ func (w *gettyWSConn) Send(pkg interface{}) (int, error) {
}

w.updateWriteDeadline()
if err = w.conn.WriteMessage(websocket.BinaryMessage, p); err == nil {
if err = w.threadSafeWriteMessage(websocket.BinaryMessage, p); err == nil {
w.writeBytes.Add((uint32)(len(p)))
w.writePkgNum.Add(1)
}
Expand All @@ -617,18 +618,18 @@ func (w *gettyWSConn) Send(pkg interface{}) (int, error) {

func (w *gettyWSConn) writePing() error {
w.updateWriteDeadline()
return perrors.WithStack(w.conn.WriteMessage(websocket.PingMessage, []byte{}))
return perrors.WithStack(w.threadSafeWriteMessage(websocket.PingMessage, []byte{}))
}

func (w *gettyWSConn) writePong(message []byte) error {
w.updateWriteDeadline()
return perrors.WithStack(w.conn.WriteMessage(websocket.PongMessage, message))
return perrors.WithStack(w.threadSafeWriteMessage(websocket.PongMessage, message))
}

// close websocket connection
func (w *gettyWSConn) CloseConn(waitSec int) {
w.updateWriteDeadline()
w.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "bye-bye!!!"))
w.threadSafeWriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "bye-bye!!!"))
conn := w.conn.UnderlyingConn()
if tcpConn, ok := conn.(*net.TCPConn); ok {
tcpConn.SetLinger(waitSec)
Expand All @@ -637,3 +638,13 @@ func (w *gettyWSConn) CloseConn(waitSec int) {
}
w.conn.Close()
}

// uses a mutex to ensure that only one thread can send a message at a time, preventing race conditions.
func (w *gettyWSConn) threadSafeWriteMessage(messageType int, data []byte) error {
w.lock.Lock()
defer w.lock.Unlock()
if err := w.conn.WriteMessage(messageType, data); err != nil {
return err
}
return nil
}
Loading