From e32bf3c6141016eb8162e03e8c5f0b2cb84d6a41 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh Date: Sat, 9 Nov 2024 03:05:55 +0530 Subject: [PATCH] fix watch functionality --- .../commands/resp/getwatch_test.go | 10 ++++---- integration_tests/commands/resp/setup.go | 4 ++-- .../commands/resp/zrangewatch_test.go | 10 ++++---- internal/server/resp/server.go | 23 +++++++++---------- internal/watchmanager/watch_manager.go | 12 ++++++---- internal/worker/worker.go | 21 +++++++++-------- 6 files changed, 41 insertions(+), 39 deletions(-) diff --git a/integration_tests/commands/resp/getwatch_test.go b/integration_tests/commands/resp/getwatch_test.go index 537644b8f..739a26819 100644 --- a/integration_tests/commands/resp/getwatch_test.go +++ b/integration_tests/commands/resp/getwatch_test.go @@ -83,7 +83,7 @@ func TestGETWATCH(t *testing.T) { } assert.Equal(t, 3, len(castedValue)) assert.Equal(t, "GET", castedValue[0]) - assert.Equal(t, "1768826704", castedValue[1]) + assert.Equal(t, "2714318480", castedValue[1]) assert.Equal(t, tc.val, castedValue[2]) } } @@ -103,7 +103,7 @@ func TestGETWATCHWithSDK(t *testing.T) { firstMsg, err := watch.Watch(context.Background(), "GET", getWatchKey) assert.Nil(t, err) assert.Equal(t, firstMsg.Command, "GET") - assert.Equal(t, firstMsg.Fingerprint, "1768826704") + assert.Equal(t, firstMsg.Fingerprint, "2714318480") channels[i] = watch.Channel() } @@ -114,7 +114,7 @@ func TestGETWATCHWithSDK(t *testing.T) { for _, channel := range channels { v := <-channel assert.Equal(t, "GET", v.Command) // command - assert.Equal(t, "1768826704", v.Fingerprint) // Fingerprint + assert.Equal(t, "2714318480", v.Fingerprint) // Fingerprint assert.Equal(t, tc.val, v.Data.(string)) // data } } @@ -134,7 +134,7 @@ func TestGETWATCHWithSDK2(t *testing.T) { firstMsg, err := watch.GetWatch(context.Background(), getWatchKey) assert.Nil(t, err) assert.Equal(t, firstMsg.Command, "GET") - assert.Equal(t, firstMsg.Fingerprint, "1768826704") + assert.Equal(t, firstMsg.Fingerprint, "2714318480") channels[i] = watch.Channel() } @@ -145,7 +145,7 @@ func TestGETWATCHWithSDK2(t *testing.T) { for _, channel := range channels { v := <-channel assert.Equal(t, "GET", v.Command) // command - assert.Equal(t, "1768826704", v.Fingerprint) // Fingerprint + assert.Equal(t, "2714318480", v.Fingerprint) // Fingerprint assert.Equal(t, tc.val, v.Data.(string)) // data } } diff --git a/integration_tests/commands/resp/setup.go b/integration_tests/commands/resp/setup.go index a1293a8c3..ac7d906a8 100644 --- a/integration_tests/commands/resp/setup.go +++ b/integration_tests/commands/resp/setup.go @@ -121,8 +121,8 @@ func RunTestServer(wg *sync.WaitGroup, opt TestServerOptions) { config.DiceConfig.AsyncServer.Port = 9739 } - queryWatchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Memory.KeysLimit) - cmdWatchChan := make(chan dstore.CmdWatchEvent, config.DiceConfig.Memory.KeysLimit) + queryWatchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Performance.WatchChanBufSize) + cmdWatchChan := make(chan dstore.CmdWatchEvent, config.DiceConfig.Performance.WatchChanBufSize) cmdWatchSubscriptionChan := make(chan watchmanager.WatchSubscription) gec := make(chan error) shardManager := shard.NewShardManager(1, queryWatchChan, cmdWatchChan, gec) diff --git a/integration_tests/commands/resp/zrangewatch_test.go b/integration_tests/commands/resp/zrangewatch_test.go index 8f56eb3eb..d7ba06cfc 100644 --- a/integration_tests/commands/resp/zrangewatch_test.go +++ b/integration_tests/commands/resp/zrangewatch_test.go @@ -76,7 +76,7 @@ func TestZRANGEWATCH(t *testing.T) { } assert.Equal(t, 3, len(castedValue)) assert.Equal(t, "ZRANGE", castedValue[0]) - assert.Equal(t, "2491069200", castedValue[1]) + assert.Equal(t, "1178068413", castedValue[1]) assert.DeepEqual(t, tc.result, castedValue[2]) } } @@ -124,7 +124,7 @@ func TestZRANGEWATCHWithSDK(t *testing.T) { firstMsg, err := watch.Watch(context.Background(), "ZRANGE", zrangeWatchKey, "0", "-1", "REV", "WITHSCORES") assert.NilError(t, err) assert.Equal(t, firstMsg.Command, "ZRANGE") - assert.Equal(t, firstMsg.Fingerprint, "2491069200") + assert.Equal(t, firstMsg.Fingerprint, "1178068413") channels[i] = watch.Channel() } @@ -139,7 +139,7 @@ func TestZRANGEWATCHWithSDK(t *testing.T) { v := <-channel assert.Equal(t, "ZRANGE", v.Command) // command - assert.Equal(t, "2491069200", v.Fingerprint) // Fingerprint + assert.Equal(t, "1178068413", v.Fingerprint) // Fingerprint assert.DeepEqual(t, tc.result, v.Data) // data } } @@ -158,7 +158,7 @@ func TestZRANGEWATCHWithSDK2(t *testing.T) { firstMsg, err := conn.ZRangeWatch(context.Background(), zrangeWatchKey, "0", "-1", "REV", "WITHSCORES") assert.NilError(t, err) assert.Equal(t, firstMsg.Command, "ZRANGE") - assert.Equal(t, firstMsg.Fingerprint, "2491069200") + assert.Equal(t, firstMsg.Fingerprint, "1178068413") channels[i] = conn.Channel() } @@ -173,7 +173,7 @@ func TestZRANGEWATCHWithSDK2(t *testing.T) { v := <-channel assert.Equal(t, "ZRANGE", v.Command) - assert.Equal(t, "2491069200", v.Fingerprint) + assert.Equal(t, "1178068413", v.Fingerprint) assert.DeepEqual(t, tc.result, v.Data) } } diff --git a/internal/server/resp/server.go b/internal/server/resp/server.go index b0834e2a8..17244d076 100644 --- a/internal/server/resp/server.go +++ b/internal/server/resp/server.go @@ -48,7 +48,6 @@ type Server struct { shardManager *shard.ShardManager watchManager *watchmanager.Manager cmdWatchSubscriptionChan chan watchmanager.WatchSubscription - cmdWatchChan chan dstore.CmdWatchEvent globalErrorChan chan error wl wal.AbstractWAL } @@ -56,15 +55,15 @@ type Server struct { func NewServer(shardManager *shard.ShardManager, workerManager *worker.WorkerManager, cmdWatchSubscriptionChan chan watchmanager.WatchSubscription, cmdWatchChan chan dstore.CmdWatchEvent, globalErrChan chan error, wl wal.AbstractWAL) *Server { return &Server{ - Host: config.DiceConfig.AsyncServer.Addr, - Port: config.DiceConfig.AsyncServer.Port, - connBacklogSize: DefaultConnBacklogSize, - workerManager: workerManager, - shardManager: shardManager, - watchManager: watchmanager.NewManager(cmdWatchSubscriptionChan), - cmdWatchChan: cmdWatchChan, - globalErrorChan: globalErrChan, - wl: wl, + Host: config.DiceConfig.AsyncServer.Addr, + Port: config.DiceConfig.AsyncServer.Port, + connBacklogSize: DefaultConnBacklogSize, + workerManager: workerManager, + shardManager: shardManager, + watchManager: watchmanager.NewManager(cmdWatchSubscriptionChan, cmdWatchChan), + cmdWatchSubscriptionChan: cmdWatchSubscriptionChan, + globalErrorChan: globalErrChan, + wl: wl, } } @@ -81,11 +80,11 @@ func (s *Server) Run(ctx context.Context) (err error) { errChan := make(chan error, 1) wg := &sync.WaitGroup{} - if s.cmdWatchChan != nil { + if s.cmdWatchSubscriptionChan != nil { wg.Add(1) go func() { defer wg.Done() - s.watchManager.Run(ctx, s.cmdWatchChan) + s.watchManager.Run(ctx) }() } diff --git a/internal/watchmanager/watch_manager.go b/internal/watchmanager/watch_manager.go index 12451c654..18fe5369f 100644 --- a/internal/watchmanager/watch_manager.go +++ b/internal/watchmanager/watch_manager.go @@ -22,6 +22,7 @@ type ( tcpSubscriptionMap map[uint32]map[chan *cmd.DiceDBCmd]struct{} // tcpSubscriptionMap is a map of fingerprint -> [client1Chan, client2Chan, ...] fingerprintCmdMap map[uint32]*cmd.DiceDBCmd // fingerprintCmdMap is a map of fingerprint -> DiceDBCmd cmdWatchSubscriptionChan chan WatchSubscription // cmdWatchSubscriptionChan is the channel to send/receive watch subscription requests. + cmdWatchChan chan dstore.CmdWatchEvent // cmdWatchChan is the channel to send/receive watch events. } ) @@ -34,31 +35,32 @@ var ( } ) -func NewManager(cmdWatchSubscriptionChan chan WatchSubscription) *Manager { +func NewManager(cmdWatchSubscriptionChan chan WatchSubscription, cmdWatchChan chan dstore.CmdWatchEvent) *Manager { return &Manager{ querySubscriptionMap: make(map[string]map[uint32]struct{}), tcpSubscriptionMap: make(map[uint32]map[chan *cmd.DiceDBCmd]struct{}), fingerprintCmdMap: make(map[uint32]*cmd.DiceDBCmd), cmdWatchSubscriptionChan: cmdWatchSubscriptionChan, + cmdWatchChan: cmdWatchChan, } } // Run starts the watch manager, listening for subscription requests and events -func (m *Manager) Run(ctx context.Context, cmdWatchChan chan dstore.CmdWatchEvent) { +func (m *Manager) Run(ctx context.Context) { var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() - m.listenForEvents(ctx, cmdWatchChan) + m.listenForEvents(ctx) }() <-ctx.Done() wg.Wait() } -func (m *Manager) listenForEvents(ctx context.Context, cmdWatchChan chan dstore.CmdWatchEvent) { +func (m *Manager) listenForEvents(ctx context.Context) { for { select { case <-ctx.Done(): @@ -69,7 +71,7 @@ func (m *Manager) listenForEvents(ctx context.Context, cmdWatchChan chan dstore. } else { m.handleUnsubscription(sub) } - case watchEvent := <-cmdWatchChan: + case watchEvent := <-m.cmdWatchChan: m.handleWatchEvent(watchEvent) } } diff --git a/internal/worker/worker.go b/internal/worker/worker.go index f9ce02b04..a8f03a999 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -58,16 +58,17 @@ func NewWorker(wid string, responseChan, preprocessingChan chan *ops.StoreRespon ioHandler iohandler.IOHandler, parser requestparser.Parser, shardManager *shard.ShardManager, gec chan error, wl wal.AbstractWAL) *BaseWorker { return &BaseWorker{ - id: wid, - ioHandler: ioHandler, - parser: parser, - shardManager: shardManager, - globalErrorChan: gec, - responseChan: responseChan, - preprocessingChan: preprocessingChan, - Session: auth.NewSession(), - adhocReqChan: make(chan *cmd.DiceDBCmd, config.DiceConfig.Performance.AdhocReqChanBufSize), - wl: wl, + id: wid, + ioHandler: ioHandler, + parser: parser, + shardManager: shardManager, + globalErrorChan: gec, + responseChan: responseChan, + preprocessingChan: preprocessingChan, + Session: auth.NewSession(), + adhocReqChan: make(chan *cmd.DiceDBCmd, config.DiceConfig.Performance.AdhocReqChanBufSize), + cmdWatchSubscriptionChan: cmdWatchSubscriptionChan, + wl: wl, } }