diff --git a/go.mod b/go.mod index aaf9e557e7..259e900555 100644 --- a/go.mod +++ b/go.mod @@ -70,7 +70,7 @@ require ( github.com/golang/protobuf v1.5.3 // indirect github.com/google/pprof v0.0.0-20230926050212-f7f687d19a98 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect - github.com/igm/sockjs-go/v3 v3.0.2 // indirect + github.com/igm/sockjs-go/v3 v3.0.2 github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/klauspost/compress v1.17.4 // indirect diff --git a/internal/sockjs/cancelctx.go b/internal/sockjs/cancelctx.go new file mode 100644 index 0000000000..54c3996419 --- /dev/null +++ b/internal/sockjs/cancelctx.go @@ -0,0 +1,34 @@ +package sockjs + +import ( + "context" + "time" +) + +// customCancelContext wraps context and cancels as soon as channel closed. +type customCancelContext struct { + context.Context + ch <-chan struct{} +} + +// Deadline not used. +func (c customCancelContext) Deadline() (time.Time, bool) { return time.Time{}, false } + +// Done returns channel that will be closed as soon as connection closed. +func (c customCancelContext) Done() <-chan struct{} { return c.ch } + +// Err returns context error. +func (c customCancelContext) Err() error { + select { + case <-c.ch: + return context.Canceled + default: + return nil + } +} + +// NewCancelContext returns a wrapper context around original context that will +// be canceled on channel close. +func NewCancelContext(ctx context.Context, ch <-chan struct{}) context.Context { + return customCancelContext{Context: ctx, ch: ch} +} diff --git a/internal/sockjs/handler_sockjs.go b/internal/sockjs/handler_sockjs.go new file mode 100644 index 0000000000..ffc631ef73 --- /dev/null +++ b/internal/sockjs/handler_sockjs.go @@ -0,0 +1,278 @@ +package sockjs + +import ( + "net/http" + "sync" + "time" + + "github.com/centrifugal/centrifuge" + + "github.com/centrifugal/protocol" + "github.com/gorilla/websocket" + "github.com/igm/sockjs-go/v3/sockjs" +) + +// Config represents config for SockJS handler. +type Config struct { + // HandlerPrefix sets prefix for SockJS handler endpoint path. + HandlerPrefix string + + // URL is an address to SockJS client javascript library. Required for iframe-based + // transports to work. This URL should lead to the same SockJS client version as used + // for connecting on the client side. + URL string + + // CheckOrigin allows deciding whether to use CORS or not in XHR case. + // When false returned then CORS headers won't be set. + CheckOrigin func(*http.Request) bool + + // WebsocketCheckOrigin allows setting custom CheckOrigin func for underlying + // Gorilla Websocket based websocket.Upgrader. + WebsocketCheckOrigin func(*http.Request) bool + + // WebsocketReadBufferSize is a parameter that is used for raw websocket.Upgrader. + // If set to zero reasonable default value will be used. + WebsocketReadBufferSize int + + // WebsocketWriteBufferSize is a parameter that is used for raw websocket.Upgrader. + // If set to zero reasonable default value will be used. + WebsocketWriteBufferSize int + + // WebsocketUseWriteBufferPool enables using buffer pool for writes in Websocket transport. + WebsocketUseWriteBufferPool bool + + // WebsocketWriteTimeout is maximum time of write message operation. + // Slow client will be disconnected. + // By default, 1 * time.Second will be used. + WebsocketWriteTimeout time.Duration + + centrifuge.PingPongConfig +} + +// Handler accepts SockJS connections. SockJS has a bunch of fallback +// transports when WebSocket connection is not supported. It comes with additional +// costs though: small protocol framing overhead, lack of binary support, more +// goroutines per connection, and you need to use sticky session mechanism on +// your load balancer in case you are using HTTP-based SockJS fallbacks and have +// more than one Centrifuge Node on a backend (so SockJS to be able to emulate +// bidirectional protocol). So if you can afford it - use WebsocketHandler only. +type Handler struct { + node *centrifuge.Node + config Config + handler http.Handler +} + +var writeBufferPool = &sync.Pool{} + +// NewHandler creates new Handler. +func NewHandler(node *centrifuge.Node, config Config) *Handler { + options := sockjs.DefaultOptions + + wsUpgrader := &websocket.Upgrader{ + ReadBufferSize: config.WebsocketReadBufferSize, + WriteBufferSize: config.WebsocketWriteBufferSize, + Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {}, + } + wsUpgrader.CheckOrigin = config.WebsocketCheckOrigin + if config.WebsocketUseWriteBufferPool { + wsUpgrader.WriteBufferPool = writeBufferPool + } else { + wsUpgrader.WriteBufferSize = config.WebsocketWriteBufferSize + } + options.WebsocketUpgrader = wsUpgrader + + // Override sockjs url. It's important to use the same SockJS + // library version on client and server sides when using iframe + // based SockJS transports, otherwise SockJS will raise error + // about version mismatch. + options.SockJSURL = config.URL + options.CheckOrigin = config.CheckOrigin + + wsWriteTimeout := config.WebsocketWriteTimeout + if wsWriteTimeout == 0 { + wsWriteTimeout = 1 * time.Second + } + options.WebsocketWriteTimeout = wsWriteTimeout + + s := &Handler{ + node: node, + config: config, + } + + options.HeartbeatDelay = 0 + s.handler = sockjs.NewHandler(config.HandlerPrefix, options, s.sockJSHandler) + return s +} + +func (s *Handler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + s.handler.ServeHTTP(rw, r) +} + +// sockJSHandler called when new client connection comes to SockJS endpoint. +func (s *Handler) sockJSHandler(sess sockjs.Session) { + s.handleSession(sess) +} + +// sockJSHandler called when new client connection comes to SockJS endpoint. +func (s *Handler) handleSession(sess sockjs.Session) { + // Separate goroutine for better GC of caller's data. + go func() { + transport := newSockjsTransport(sess, sockjsTransportOptions{ + pingPong: s.config.PingPongConfig, + }) + + select { + case <-s.node.NotifyShutdown(): + _ = transport.Close(centrifuge.DisconnectShutdown) + return + default: + } + + ctxCh := make(chan struct{}) + defer close(ctxCh) + c, closeFn, err := centrifuge.NewClient(NewCancelContext(sess.Request().Context(), ctxCh), s.node, transport) + if err != nil { + s.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "error creating client", map[string]any{"transport": transportSockJS})) + return + } + defer func() { _ = closeFn() }() + + if s.node.LogEnabled(centrifuge.LogLevelDebug) { + s.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelDebug, "client connection established", map[string]any{"client": c.ID(), "transport": transportSockJS})) + defer func(started time.Time) { + s.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelDebug, "client connection completed", map[string]any{"client": c.ID(), "transport": transportSockJS, "duration": time.Since(started)})) + }(time.Now()) + } + + var needWaitLoop bool + + for { + if msg, err := sess.Recv(); err == nil { + reader := GetStringReader(msg) + if ok := centrifuge.HandleReadFrame(c, reader); !ok { + PutStringReader(reader) + needWaitLoop = true + break + } + PutStringReader(reader) + continue + } + break + } + + if needWaitLoop { + // One extra loop till we get an error from session, + // this is required to wait until close frame will be sent + // into connection inside Client implementation and transport + // closed with proper disconnect reason. + for { + if _, err := sess.Recv(); err != nil { + break + } + } + } + }() +} + +const ( + transportSockJS = "sockjs" +) + +type sockjsTransportOptions struct { + pingPong centrifuge.PingPongConfig +} + +type sockjsTransport struct { + mu sync.RWMutex + closeCh chan struct{} + session sockjs.Session + opts sockjsTransportOptions + closed bool +} + +func newSockjsTransport(s sockjs.Session, opts sockjsTransportOptions) *sockjsTransport { + t := &sockjsTransport{ + session: s, + closeCh: make(chan struct{}), + opts: opts, + } + return t +} + +// Name returns name of transport. +func (t *sockjsTransport) Name() string { + return transportSockJS +} + +// Protocol returns transport protocol. +func (t *sockjsTransport) Protocol() centrifuge.ProtocolType { + return centrifuge.ProtocolTypeJSON +} + +// ProtocolVersion returns transport ProtocolVersion. +func (t *sockjsTransport) ProtocolVersion() centrifuge.ProtocolVersion { + return centrifuge.ProtocolVersion2 +} + +// Unidirectional returns whether transport is unidirectional. +func (t *sockjsTransport) Unidirectional() bool { + return false +} + +// Emulation ... +func (t *sockjsTransport) Emulation() bool { + return false +} + +// DisabledPushFlags ... +func (t *sockjsTransport) DisabledPushFlags() uint64 { + // SockJS has its own close frames to mimic WebSocket Close frames, + // so we don't need to send Disconnect pushes. + return centrifuge.PushFlagDisconnect +} + +// PingPongConfig ... +func (t *sockjsTransport) PingPongConfig() centrifuge.PingPongConfig { + return t.opts.pingPong +} + +// Write data to transport. +func (t *sockjsTransport) Write(message []byte) error { + select { + case <-t.closeCh: + return nil + default: + // No need to use protocol encoders here since + // SockJS only supports JSON. + return t.session.Send(string(message)) + } +} + +// WriteMany messages to transport. +func (t *sockjsTransport) WriteMany(messages ...[]byte) error { + select { + case <-t.closeCh: + return nil + default: + encoder := protocol.GetDataEncoder(protocol.Type(centrifuge.ProtocolTypeJSON)) + defer protocol.PutDataEncoder(protocol.Type(centrifuge.ProtocolTypeJSON), encoder) + for i := range messages { + _ = encoder.Encode(messages[i]) + } + return t.session.Send(string(encoder.Finish())) + } +} + +// Close closes transport. +func (t *sockjsTransport) Close(disconnect centrifuge.Disconnect) error { + t.mu.Lock() + if t.closed { + // Already closed, noop. + t.mu.Unlock() + return nil + } + t.closed = true + close(t.closeCh) + t.mu.Unlock() + return t.session.Close(disconnect.Code, disconnect.Reason) +} diff --git a/internal/sockjs/handler_sockjs_test.go b/internal/sockjs/handler_sockjs_test.go new file mode 100644 index 0000000000..dd5ff30726 --- /dev/null +++ b/internal/sockjs/handler_sockjs_test.go @@ -0,0 +1,280 @@ +package sockjs + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/centrifugal/centrifuge" + + "github.com/centrifugal/protocol" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" +) + +func sockjsData(data []byte) []byte { + quoted, _ := json.Marshal(string(data)) + return []byte(fmt.Sprintf("[%s]", string(quoted))) +} + +func TestSockjsHandler(t *testing.T) { + n, _ := centrifuge.New(centrifuge.Config{}) + require.NoError(t, n.Run()) + defer func() { _ = n.Shutdown(context.Background()) }() + mux := http.NewServeMux() + + n.OnConnecting(func(ctx context.Context, event centrifuge.ConnectEvent) (centrifuge.ConnectReply, error) { + require.Equal(t, transportSockJS, event.Transport.Name()) + require.Equal(t, centrifuge.ProtocolTypeJSON, event.Transport.Protocol()) + return centrifuge.ConnectReply{ + Credentials: ¢rifuge.Credentials{UserID: "user"}, + Data: []byte(`{"SockJS connect response": 1}`), + }, nil + }) + + doneCh := make(chan struct{}) + + n.OnConnect(func(client *centrifuge.Client) { + err := client.Send([]byte(`{"SockJS write": 1}`)) + require.NoError(t, err) + client.Disconnect(centrifuge.DisconnectForceReconnect) + }) + + mux.Handle("/connection/sockjs/", NewHandler(n, Config{ + HandlerPrefix: "/connection/sockjs", + })) + server := httptest.NewServer(mux) + defer server.Close() + + url := "ws" + server.URL[4:] + + conn, resp, err := websocket.DefaultDialer.Dial(url+"/connection/sockjs/220/fi0988475/websocket", nil) + require.NoError(t, err) + defer func() { _ = resp.Body.Close() }() + require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + require.NotNil(t, conn) + defer func() { _ = conn.Close() }() + _, p, err := conn.ReadMessage() + require.NoError(t, err) + require.Equal(t, "o", string(p)) // open frame of SockJS protocol. + + connectRequest := &protocol.ConnectRequest{ + Token: "boom", + } + cmd := &protocol.Command{ + Id: 1, + Connect: connectRequest, + } + cmdBytes, _ := json.Marshal(cmd) + err = conn.WriteMessage(websocket.TextMessage, sockjsData(cmdBytes)) + require.NoError(t, err) + + go func() { + pos := 0 + contentExpected := []string{ + "SockJS connect response", + "SockJS write", + "force reconnect", + } + + loop: + for { + _, p, err = conn.ReadMessage() + if err != nil { + break loop + } + + for { + if strings.Contains(string(p), contentExpected[pos]) { + pos++ + if pos >= len(contentExpected) { + close(doneCh) + break loop + } + } else { + break + } + } + } + }() + + waitWithTimeout(t, doneCh) +} + +func waitWithTimeout(t *testing.T, ch chan struct{}) { + t.Helper() + select { + case <-ch: + case <-time.After(3 * time.Second): + require.Fail(t, "timeout") + } +} + +func defaultNodeNoHandlers() *centrifuge.Node { + n, err := centrifuge.New(centrifuge.Config{ + LogLevel: centrifuge.LogLevelTrace, + LogHandler: func(entry centrifuge.LogEntry) {}, + }) + if err != nil { + panic(err) + } + err = n.Run() + if err != nil { + panic(err) + } + return n +} + +func TestSockjsTransportWrite(t *testing.T) { + node := defaultNodeNoHandlers() + defer func() { _ = node.Shutdown(context.Background()) }() + + node.OnConnecting(func(ctx context.Context, event centrifuge.ConnectEvent) (centrifuge.ConnectReply, error) { + require.Equal(t, event.Transport.Protocol(), centrifuge.ProtocolTypeJSON) + transport := event.Transport.(centrifuge.Transport) + // Write to transport directly - this is only valid for tests, in normal situation + // we write over client methods. + require.NoError(t, transport.Write([]byte("hello"))) + return centrifuge.ConnectReply{}, centrifuge.DisconnectForceNoReconnect + }) + + mux := http.NewServeMux() + mux.Handle("/connection/sockjs/", NewHandler(node, Config{ + HandlerPrefix: "/connection/sockjs", + })) + server := httptest.NewServer(mux) + defer server.Close() + + url := "ws" + server.URL[4:] + + conn, resp, err := websocket.DefaultDialer.Dial(url+"/connection/sockjs/220/fi0988475/websocket", nil) + require.NoError(t, err) + defer func() { _ = resp.Body.Close() }() + require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + require.NotNil(t, conn) + defer func() { _ = conn.Close() }() + _, p, err := conn.ReadMessage() + require.NoError(t, err) + require.Equal(t, "o", string(p)) // open frame of SockJS protocol. + + connectRequest := &protocol.ConnectRequest{ + Token: "boom", + } + cmd := &protocol.Command{ + Id: 1, + Connect: connectRequest, + } + cmdBytes, _ := json.Marshal(cmd) + err = conn.WriteMessage(websocket.TextMessage, sockjsData(cmdBytes)) + require.NoError(t, err) + + _, p, err = conn.ReadMessage() + require.NoError(t, err) + require.Equal(t, "a[\"hello\"]", string(p)) +} + +func TestSockjsTransportWriteMany(t *testing.T) { + node := defaultNodeNoHandlers() + defer func() { _ = node.Shutdown(context.Background()) }() + + node.OnConnecting(func(ctx context.Context, event centrifuge.ConnectEvent) (centrifuge.ConnectReply, error) { + require.Equal(t, event.Transport.Protocol(), centrifuge.ProtocolTypeJSON) + transport := event.Transport.(centrifuge.Transport) + // Write to transport directly - this is only valid for tests, in normal situation + // we write over client methods. + require.NoError(t, transport.WriteMany([]byte("1"), []byte("22"))) + return centrifuge.ConnectReply{}, centrifuge.DisconnectForceNoReconnect + }) + + mux := http.NewServeMux() + mux.Handle("/connection/sockjs/", NewHandler(node, Config{ + HandlerPrefix: "/connection/sockjs", + })) + server := httptest.NewServer(mux) + defer server.Close() + + url := "ws" + server.URL[4:] + + conn, resp, err := websocket.DefaultDialer.Dial(url+"/connection/sockjs/220/fi0988475/websocket", nil) + require.NoError(t, err) + defer func() { _ = resp.Body.Close() }() + require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + require.NotNil(t, conn) + defer func() { _ = conn.Close() }() + _, p, err := conn.ReadMessage() + require.NoError(t, err) + require.Equal(t, "o", string(p)) // open frame of SockJS protocol. + + connectRequest := &protocol.ConnectRequest{ + Token: "boom", + } + cmd := &protocol.Command{ + Id: 1, + Connect: connectRequest, + } + cmdBytes, _ := json.Marshal(cmd) + err = conn.WriteMessage(websocket.TextMessage, sockjsData(cmdBytes)) + require.NoError(t, err) + + _, p, err = conn.ReadMessage() + require.NoError(t, err) + require.Equal(t, "a[\"1\\n22\"]", string(p)) +} + +func TestSockjsHandlerURLParams(t *testing.T) { + n, _ := centrifuge.New(centrifuge.Config{}) + require.NoError(t, n.Run()) + defer func() { _ = n.Shutdown(context.Background()) }() + mux := http.NewServeMux() + + n.OnConnecting(func(ctx context.Context, event centrifuge.ConnectEvent) (centrifuge.ConnectReply, error) { + return centrifuge.ConnectReply{ + Credentials: ¢rifuge.Credentials{UserID: "user"}, + }, nil + }) + + doneCh := make(chan struct{}) + + n.OnConnect(func(client *centrifuge.Client) { + require.Equal(t, transportSockJS, client.Transport().Name()) + require.Equal(t, centrifuge.ProtocolTypeJSON, client.Transport().Protocol()) + require.Equal(t, centrifuge.ProtocolVersion2, client.Transport().ProtocolVersion()) + close(doneCh) + }) + + mux.Handle("/connection/sockjs/", NewHandler(n, Config{ + HandlerPrefix: "/connection/sockjs", + })) + server := httptest.NewServer(mux) + defer server.Close() + + url := "ws" + server.URL[4:] + + conn, resp, err := websocket.DefaultDialer.Dial(url+"/connection/sockjs/220/fi0988475/websocket?cf_protocol_version=v2", nil) + require.NoError(t, err) + defer func() { _ = resp.Body.Close() }() + require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + require.NotNil(t, conn) + defer func() { _ = conn.Close() }() + _, p, err := conn.ReadMessage() + require.NoError(t, err) + require.Equal(t, "o", string(p)) // open frame of SockJS protocol. + + connectRequest := &protocol.ConnectRequest{ + Token: "boom", + } + cmd := &protocol.Command{ + Id: 1, + Connect: connectRequest, + } + cmdBytes, _ := json.Marshal(cmd) + err = conn.WriteMessage(websocket.TextMessage, sockjsData(cmdBytes)) + require.NoError(t, err) + + waitWithTimeout(t, doneCh) +} diff --git a/internal/sockjs/pool.go b/internal/sockjs/pool.go new file mode 100644 index 0000000000..0ce2b6a00a --- /dev/null +++ b/internal/sockjs/pool.go @@ -0,0 +1,45 @@ +package sockjs + +import ( + "bytes" + "strings" + "sync" +) + +var stringReaderPool sync.Pool + +// GetStringReader from pool. +func GetStringReader(data string) *strings.Reader { + r := bytesReaderPool.Get() + if r == nil { + return strings.NewReader(data) + } + reader := r.(*strings.Reader) + reader.Reset(data) + return reader +} + +// PutStringReader to pool. +func PutStringReader(reader *strings.Reader) { + reader.Reset("") + stringReaderPool.Put(reader) +} + +var bytesReaderPool sync.Pool + +// GetBytesReader from pool. +func GetBytesReader(data []byte) *bytes.Reader { + r := bytesReaderPool.Get() + if r == nil { + return bytes.NewReader(data) + } + reader := r.(*bytes.Reader) + reader.Reset(data) + return reader +} + +// PutBytesReader to pool. +func PutBytesReader(reader *bytes.Reader) { + reader.Reset(nil) + bytesReaderPool.Put(reader) +} diff --git a/internal/sockjs/pool_test.go b/internal/sockjs/pool_test.go new file mode 100644 index 0000000000..7055a694de --- /dev/null +++ b/internal/sockjs/pool_test.go @@ -0,0 +1,34 @@ +package sockjs + +import ( + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestStringReaderPool(t *testing.T) { + r := GetStringReader("string1") + d, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, []byte("string1"), d) + PutStringReader(r) + r = GetStringReader("string2") + defer PutStringReader(r) + d, err = io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, []byte("string2"), d) +} + +func TestBytesReaderPool(t *testing.T) { + r := GetBytesReader([]byte("bytes1")) + d, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, []byte("bytes1"), d) + PutBytesReader(r) + r = GetBytesReader([]byte("bytes2")) + defer PutBytesReader(r) + d, err = io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, []byte("bytes2"), d) +} diff --git a/main.go b/main.go index 1152431961..489cdad6df 100644 --- a/main.go +++ b/main.go @@ -53,6 +53,7 @@ import ( "github.com/centrifugal/centrifugo/v5/internal/redisnatsbroker" "github.com/centrifugal/centrifugo/v5/internal/rule" "github.com/centrifugal/centrifugo/v5/internal/service" + "github.com/centrifugal/centrifugo/v5/internal/sockjs" "github.com/centrifugal/centrifugo/v5/internal/survey" "github.com/centrifugal/centrifugo/v5/internal/swaggerui" "github.com/centrifugal/centrifugo/v5/internal/telemetry" @@ -2487,9 +2488,9 @@ func uniGRPCHandlerConfig() unigrpc.Config { return unigrpc.Config{} } -func sockjsHandlerConfig() centrifuge.SockjsConfig { +func sockjsHandlerConfig() sockjs.Config { v := viper.GetViper() - cfg := centrifuge.SockjsConfig{} + cfg := sockjs.Config{} cfg.URL = v.GetString("sockjs_url") cfg.WebsocketReadBufferSize = v.GetInt("websocket_read_buffer_size") cfg.WebsocketWriteBufferSize = v.GetInt("websocket_write_buffer_size") @@ -2995,7 +2996,7 @@ func Mux(n *centrifuge.Node, ruleContainer *rule.Container, apiExecutor *api.Exe sockjsConfig := sockjsHandlerConfig() sockjsPrefix := strings.TrimRight(v.GetString("sockjs_handler_prefix"), "/") sockjsConfig.HandlerPrefix = sockjsPrefix - mux.Handle(sockjsPrefix+"/", connChain.Then(centrifuge.NewSockjsHandler(n, sockjsConfig))) + mux.Handle(sockjsPrefix+"/", connChain.Then(sockjs.NewHandler(n, sockjsConfig))) } if flags&HandlerUniWebsocket != 0 {