Skip to content

Commit

Permalink
support cf_ws_frame_ping_pong url param (#306)
Browse files Browse the repository at this point in the history
  • Loading branch information
FZambia authored Jun 26, 2023
1 parent cbb6427 commit c6ba9d3
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 8 deletions.
81 changes: 73 additions & 8 deletions handler_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,20 @@ func (s *WebsocketHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
incTransportConnect(transportWebsocket)

var protoType = ProtocolTypeJSON
var useFramePingPong bool

if r.URL.RawQuery != "" {
query := r.URL.Query()
if query.Get("format") == "protobuf" || query.Get("cf_protocol") == "protobuf" {
protoType = ProtocolTypeProtobuf
}
if query.Get("cf_ws_frame_ping_pong") == "true" {
// This is a way for tools like Postman, wscat and others to maintain
// active connection to the Centrifuge-based server without the need to
// respond to app-level pings. We rely on native websocket ping/pong
// frames in this case.
useFramePingPong = true
}
}

compression := s.config.Compression
Expand Down Expand Up @@ -144,6 +152,15 @@ func (s *WebsocketHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
protoType = ProtocolTypeProtobuf
}

if useFramePingPong {
pongWait := framePingInterval * 10 / 9
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
conn.SetPongHandler(func(string) error {
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
return nil
})
}

// Separate goroutine for better GC of caller's data.
go func() {
opts := websocketTransportOptions{
Expand All @@ -154,7 +171,7 @@ func (s *WebsocketHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
}

graceCh := make(chan struct{})
transport := newWebsocketTransport(conn, opts, graceCh)
transport := newWebsocketTransport(conn, opts, graceCh, useFramePingPong)

select {
case <-s.node.NotifyShutdown():
Expand Down Expand Up @@ -191,6 +208,11 @@ func (s *WebsocketHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
}
}

if useFramePingPong {
conn.SetPingHandler(nil)
conn.SetPongHandler(nil)
}

_ = conn.SetReadDeadline(time.Now().Add(closeFrameWait))
for {
if _, _, err := conn.NextReader(); err != nil {
Expand Down Expand Up @@ -245,12 +267,13 @@ const (
// websocketTransport is a wrapper struct over websocket connection to fit session
// interface so client will accept it.
type websocketTransport struct {
mu sync.RWMutex
conn *websocket.Conn
closed bool
closeCh chan struct{}
graceCh chan struct{}
opts websocketTransportOptions
mu sync.RWMutex
conn *websocket.Conn
closed bool
closeCh chan struct{}
graceCh chan struct{}
opts websocketTransportOptions
nativePingTimer *time.Timer
}

type websocketTransportOptions struct {
Expand All @@ -260,13 +283,16 @@ type websocketTransportOptions struct {
compressionMinSize int
}

func newWebsocketTransport(conn *websocket.Conn, opts websocketTransportOptions, graceCh chan struct{}) *websocketTransport {
func newWebsocketTransport(conn *websocket.Conn, opts websocketTransportOptions, graceCh chan struct{}, useNativePingPong bool) *websocketTransport {
transport := &websocketTransport{
conn: conn,
closeCh: make(chan struct{}),
graceCh: graceCh,
opts: opts,
}
if useNativePingPong {
transport.addPing()
}
return transport
}

Expand Down Expand Up @@ -303,6 +329,15 @@ func (t *websocketTransport) DisabledPushFlags() uint64 {

// PingPongConfig ...
func (t *websocketTransport) PingPongConfig() PingPongConfig {
t.mu.RLock()
useNativePingPong := t.nativePingTimer != nil
t.mu.RUnlock()
if useNativePingPong {
return PingPongConfig{
PingInterval: -1,
PongTimeout: -1,
}
}
return t.opts.pingPong
}

Expand Down Expand Up @@ -372,6 +407,9 @@ func (t *websocketTransport) Close(disconnect Disconnect) error {
}
t.closed = true
close(t.closeCh)
if t.nativePingTimer != nil {
t.nativePingTimer.Stop()
}
t.mu.Unlock()

if disconnect.Code != DisconnectConnectionClosed.Code {
Expand All @@ -396,6 +434,33 @@ func (t *websocketTransport) Close(disconnect Disconnect) error {
return t.conn.Close()
}

var framePingInterval = 25 * time.Second

func (t *websocketTransport) ping() {
select {
case <-t.closeCh:
return
default:
deadline := time.Now().Add(framePingInterval / 2)
err := t.conn.WriteControl(websocket.PingMessage, nil, deadline)
if err != nil {
_ = t.Close(DisconnectWriteError)
return
}
t.addPing()
}
}

func (t *websocketTransport) addPing() {
t.mu.Lock()
if t.closed {
t.mu.Unlock()
return
}
t.nativePingTimer = time.AfterFunc(framePingInterval, t.ping)
t.mu.Unlock()
}

func sameHostOriginCheck(n *Node) func(r *http.Request) bool {
return func(r *http.Request) bool {
err := checkSameHost(r)
Expand Down
57 changes: 57 additions & 0 deletions handler_websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,63 @@ func TestWebsocketHandlerPing(t *testing.T) {
}
}

func TestWebsocketHandler_FramePingPong(t *testing.T) {
t.Parallel()
framePingInterval = time.Second
n, _ := New(Config{})
require.NoError(t, n.Run())
defer func() { _ = n.Shutdown(context.Background()) }()
mux := http.NewServeMux()
mux.Handle("/connection/websocket", NewWebsocketHandler(n, WebsocketConfig{}))
server := httptest.NewServer(mux)
defer server.Close()

n.OnConnecting(func(ctx context.Context, event ConnectEvent) (ConnectReply, error) {
return ConnectReply{
Credentials: &Credentials{
UserID: "test",
},
}, nil
})

url := "ws" + server.URL[4:]

conn, resp, err := websocket.DefaultDialer.Dial(url+"/connection/websocket?cf_ws_frame_ping_pong=true", nil)
require.NoError(t, err)
defer func() { _ = resp.Body.Close() }()
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
require.NotNil(t, conn)
defer func() { _ = conn.Close() }()

closeCh := make(chan struct{})

conn.SetPingHandler(func(_ string) error {
close(closeCh)
return nil
})

err = conn.WriteMessage(websocket.TextMessage, []byte(`{"id": 1, "connect": {}}`))
require.NoError(t, err)

go func() {
for {
_, msg, err := conn.ReadMessage()
if err != nil {
break
}
if strings.Contains(string(msg), "{}") {
require.Fail(t, "unexpected app-level ping")
}
}
}()

select {
case <-closeCh:
case <-time.After(5 * time.Second):
require.Fail(t, "timeout waiting for frame ping")
}
}

func TestWebsocketHandlerCustomDisconnect(t *testing.T) {
n, _ := New(Config{})
require.NoError(t, n.Run())
Expand Down

0 comments on commit c6ba9d3

Please sign in to comment.