Skip to content

Commit

Permalink
fix watch functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
JyotinderSingh committed Nov 8, 2024
1 parent 1db7b4c commit e32bf3c
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 39 deletions.
10 changes: 5 additions & 5 deletions integration_tests/commands/resp/getwatch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TestGETWATCH(t *testing.T) {
}
assert.Equal(t, 3, len(castedValue))
assert.Equal(t, "GET", castedValue[0])
assert.Equal(t, "1768826704", castedValue[1])
assert.Equal(t, "2714318480", castedValue[1])
assert.Equal(t, tc.val, castedValue[2])
}
}
Expand All @@ -103,7 +103,7 @@ func TestGETWATCHWithSDK(t *testing.T) {
firstMsg, err := watch.Watch(context.Background(), "GET", getWatchKey)
assert.Nil(t, err)
assert.Equal(t, firstMsg.Command, "GET")
assert.Equal(t, firstMsg.Fingerprint, "1768826704")
assert.Equal(t, firstMsg.Fingerprint, "2714318480")
channels[i] = watch.Channel()
}

Expand All @@ -114,7 +114,7 @@ func TestGETWATCHWithSDK(t *testing.T) {
for _, channel := range channels {
v := <-channel
assert.Equal(t, "GET", v.Command) // command
assert.Equal(t, "1768826704", v.Fingerprint) // Fingerprint
assert.Equal(t, "2714318480", v.Fingerprint) // Fingerprint
assert.Equal(t, tc.val, v.Data.(string)) // data
}
}
Expand All @@ -134,7 +134,7 @@ func TestGETWATCHWithSDK2(t *testing.T) {
firstMsg, err := watch.GetWatch(context.Background(), getWatchKey)
assert.Nil(t, err)
assert.Equal(t, firstMsg.Command, "GET")
assert.Equal(t, firstMsg.Fingerprint, "1768826704")
assert.Equal(t, firstMsg.Fingerprint, "2714318480")
channels[i] = watch.Channel()
}

Expand All @@ -145,7 +145,7 @@ func TestGETWATCHWithSDK2(t *testing.T) {
for _, channel := range channels {
v := <-channel
assert.Equal(t, "GET", v.Command) // command
assert.Equal(t, "1768826704", v.Fingerprint) // Fingerprint
assert.Equal(t, "2714318480", v.Fingerprint) // Fingerprint
assert.Equal(t, tc.val, v.Data.(string)) // data
}
}
Expand Down
4 changes: 2 additions & 2 deletions integration_tests/commands/resp/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ func RunTestServer(wg *sync.WaitGroup, opt TestServerOptions) {
config.DiceConfig.AsyncServer.Port = 9739
}

queryWatchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Memory.KeysLimit)
cmdWatchChan := make(chan dstore.CmdWatchEvent, config.DiceConfig.Memory.KeysLimit)
queryWatchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Performance.WatchChanBufSize)
cmdWatchChan := make(chan dstore.CmdWatchEvent, config.DiceConfig.Performance.WatchChanBufSize)
cmdWatchSubscriptionChan := make(chan watchmanager.WatchSubscription)
gec := make(chan error)
shardManager := shard.NewShardManager(1, queryWatchChan, cmdWatchChan, gec)
Expand Down
10 changes: 5 additions & 5 deletions integration_tests/commands/resp/zrangewatch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func TestZRANGEWATCH(t *testing.T) {
}
assert.Equal(t, 3, len(castedValue))
assert.Equal(t, "ZRANGE", castedValue[0])
assert.Equal(t, "2491069200", castedValue[1])
assert.Equal(t, "1178068413", castedValue[1])
assert.DeepEqual(t, tc.result, castedValue[2])
}
}
Expand Down Expand Up @@ -124,7 +124,7 @@ func TestZRANGEWATCHWithSDK(t *testing.T) {
firstMsg, err := watch.Watch(context.Background(), "ZRANGE", zrangeWatchKey, "0", "-1", "REV", "WITHSCORES")
assert.NilError(t, err)
assert.Equal(t, firstMsg.Command, "ZRANGE")
assert.Equal(t, firstMsg.Fingerprint, "2491069200")
assert.Equal(t, firstMsg.Fingerprint, "1178068413")
channels[i] = watch.Channel()
}

Expand All @@ -139,7 +139,7 @@ func TestZRANGEWATCHWithSDK(t *testing.T) {
v := <-channel

assert.Equal(t, "ZRANGE", v.Command) // command
assert.Equal(t, "2491069200", v.Fingerprint) // Fingerprint
assert.Equal(t, "1178068413", v.Fingerprint) // Fingerprint
assert.DeepEqual(t, tc.result, v.Data) // data
}
}
Expand All @@ -158,7 +158,7 @@ func TestZRANGEWATCHWithSDK2(t *testing.T) {
firstMsg, err := conn.ZRangeWatch(context.Background(), zrangeWatchKey, "0", "-1", "REV", "WITHSCORES")
assert.NilError(t, err)
assert.Equal(t, firstMsg.Command, "ZRANGE")
assert.Equal(t, firstMsg.Fingerprint, "2491069200")
assert.Equal(t, firstMsg.Fingerprint, "1178068413")
channels[i] = conn.Channel()
}

Expand All @@ -173,7 +173,7 @@ func TestZRANGEWATCHWithSDK2(t *testing.T) {
v := <-channel

assert.Equal(t, "ZRANGE", v.Command)
assert.Equal(t, "2491069200", v.Fingerprint)
assert.Equal(t, "1178068413", v.Fingerprint)
assert.DeepEqual(t, tc.result, v.Data)
}
}
Expand Down
23 changes: 11 additions & 12 deletions internal/server/resp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,22 @@ type Server struct {
shardManager *shard.ShardManager
watchManager *watchmanager.Manager
cmdWatchSubscriptionChan chan watchmanager.WatchSubscription
cmdWatchChan chan dstore.CmdWatchEvent
globalErrorChan chan error
wl wal.AbstractWAL
}

func NewServer(shardManager *shard.ShardManager, workerManager *worker.WorkerManager,
cmdWatchSubscriptionChan chan watchmanager.WatchSubscription, cmdWatchChan chan dstore.CmdWatchEvent, globalErrChan chan error, wl wal.AbstractWAL) *Server {
return &Server{
Host: config.DiceConfig.AsyncServer.Addr,
Port: config.DiceConfig.AsyncServer.Port,
connBacklogSize: DefaultConnBacklogSize,
workerManager: workerManager,
shardManager: shardManager,
watchManager: watchmanager.NewManager(cmdWatchSubscriptionChan),
cmdWatchChan: cmdWatchChan,
globalErrorChan: globalErrChan,
wl: wl,
Host: config.DiceConfig.AsyncServer.Addr,
Port: config.DiceConfig.AsyncServer.Port,
connBacklogSize: DefaultConnBacklogSize,
workerManager: workerManager,
shardManager: shardManager,
watchManager: watchmanager.NewManager(cmdWatchSubscriptionChan, cmdWatchChan),
cmdWatchSubscriptionChan: cmdWatchSubscriptionChan,
globalErrorChan: globalErrChan,
wl: wl,
}
}

Expand All @@ -81,11 +80,11 @@ func (s *Server) Run(ctx context.Context) (err error) {
errChan := make(chan error, 1)
wg := &sync.WaitGroup{}

if s.cmdWatchChan != nil {
if s.cmdWatchSubscriptionChan != nil {
wg.Add(1)
go func() {
defer wg.Done()
s.watchManager.Run(ctx, s.cmdWatchChan)
s.watchManager.Run(ctx)
}()
}

Expand Down
12 changes: 7 additions & 5 deletions internal/watchmanager/watch_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type (
tcpSubscriptionMap map[uint32]map[chan *cmd.DiceDBCmd]struct{} // tcpSubscriptionMap is a map of fingerprint -> [client1Chan, client2Chan, ...]
fingerprintCmdMap map[uint32]*cmd.DiceDBCmd // fingerprintCmdMap is a map of fingerprint -> DiceDBCmd
cmdWatchSubscriptionChan chan WatchSubscription // cmdWatchSubscriptionChan is the channel to send/receive watch subscription requests.
cmdWatchChan chan dstore.CmdWatchEvent // cmdWatchChan is the channel to send/receive watch events.
}
)

Expand All @@ -34,31 +35,32 @@ var (
}
)

func NewManager(cmdWatchSubscriptionChan chan WatchSubscription) *Manager {
func NewManager(cmdWatchSubscriptionChan chan WatchSubscription, cmdWatchChan chan dstore.CmdWatchEvent) *Manager {
return &Manager{
querySubscriptionMap: make(map[string]map[uint32]struct{}),
tcpSubscriptionMap: make(map[uint32]map[chan *cmd.DiceDBCmd]struct{}),
fingerprintCmdMap: make(map[uint32]*cmd.DiceDBCmd),
cmdWatchSubscriptionChan: cmdWatchSubscriptionChan,
cmdWatchChan: cmdWatchChan,
}
}

// Run starts the watch manager, listening for subscription requests and events
func (m *Manager) Run(ctx context.Context, cmdWatchChan chan dstore.CmdWatchEvent) {
func (m *Manager) Run(ctx context.Context) {
var wg sync.WaitGroup

wg.Add(1)

go func() {
defer wg.Done()
m.listenForEvents(ctx, cmdWatchChan)
m.listenForEvents(ctx)
}()

<-ctx.Done()
wg.Wait()
}

func (m *Manager) listenForEvents(ctx context.Context, cmdWatchChan chan dstore.CmdWatchEvent) {
func (m *Manager) listenForEvents(ctx context.Context) {
for {
select {
case <-ctx.Done():
Expand All @@ -69,7 +71,7 @@ func (m *Manager) listenForEvents(ctx context.Context, cmdWatchChan chan dstore.
} else {
m.handleUnsubscription(sub)
}
case watchEvent := <-cmdWatchChan:
case watchEvent := <-m.cmdWatchChan:
m.handleWatchEvent(watchEvent)
}
}
Expand Down
21 changes: 11 additions & 10 deletions internal/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,17 @@ func NewWorker(wid string, responseChan, preprocessingChan chan *ops.StoreRespon
ioHandler iohandler.IOHandler, parser requestparser.Parser,
shardManager *shard.ShardManager, gec chan error, wl wal.AbstractWAL) *BaseWorker {
return &BaseWorker{
id: wid,
ioHandler: ioHandler,
parser: parser,
shardManager: shardManager,
globalErrorChan: gec,
responseChan: responseChan,
preprocessingChan: preprocessingChan,
Session: auth.NewSession(),
adhocReqChan: make(chan *cmd.DiceDBCmd, config.DiceConfig.Performance.AdhocReqChanBufSize),
wl: wl,
id: wid,
ioHandler: ioHandler,
parser: parser,
shardManager: shardManager,
globalErrorChan: gec,
responseChan: responseChan,
preprocessingChan: preprocessingChan,
Session: auth.NewSession(),
adhocReqChan: make(chan *cmd.DiceDBCmd, config.DiceConfig.Performance.AdhocReqChanBufSize),
cmdWatchSubscriptionChan: cmdWatchSubscriptionChan,
wl: wl,
}
}

Expand Down

0 comments on commit e32bf3c

Please sign in to comment.