diff --git a/integration_tests/commands/async/setup.go b/integration_tests/commands/async/setup.go index c922c3eaa7..ec359f8f11 100644 --- a/integration_tests/commands/async/setup.go +++ b/integration_tests/commands/async/setup.go @@ -122,7 +122,7 @@ func RunTestServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerOption var err error watchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Server.KeysLimit) gec := make(chan error) - shardManager := shard.NewShardManager(1, watchChan, gec, opt.Logger) + shardManager := shard.NewShardManager(1, watchChan, nil, gec, opt.Logger) // Initialize the AsyncServer testServer := server.NewAsyncServer(shardManager, watchChan, opt.Logger) diff --git a/integration_tests/commands/http/setup.go b/integration_tests/commands/http/setup.go index b0c20bf288..a10526cb3d 100644 --- a/integration_tests/commands/http/setup.go +++ b/integration_tests/commands/http/setup.go @@ -103,7 +103,7 @@ func RunHTTPServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerOption globalErrChannel := make(chan error) watchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Server.KeysLimit) - shardManager := shard.NewShardManager(1, watchChan, globalErrChannel, opt.Logger) + shardManager := shard.NewShardManager(1, watchChan, nil, globalErrChannel, opt.Logger) queryWatcherLocal := querymanager.NewQueryManager(opt.Logger) config.HTTPPort = opt.Port // Initialize the HTTPServer diff --git a/integration_tests/commands/websocket/setup.go b/integration_tests/commands/websocket/setup.go index 406f5fd525..10f718c77d 100644 --- a/integration_tests/commands/websocket/setup.go +++ b/integration_tests/commands/websocket/setup.go @@ -99,7 +99,7 @@ func RunWebsocketServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerO // Initialize the WebsocketServer globalErrChannel := make(chan error) watchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Server.KeysLimit) - shardManager := shard.NewShardManager(1, watchChan, globalErrChannel, opt.Logger) + shardManager := shard.NewShardManager(1, watchChan, nil, globalErrChannel, opt.Logger) config.WebsocketPort = opt.Port testServer := server.NewWebSocketServer(shardManager, watchChan, opt.Logger) diff --git a/internal/cmd/cmds.go b/internal/cmd/cmds.go index fd830c71ea..423eb68e00 100644 --- a/internal/cmd/cmds.go +++ b/internal/cmd/cmds.go @@ -2,8 +2,9 @@ package cmd import ( "fmt" - "github.com/dgryski/go-farm" "strings" + + "github.com/dgryski/go-farm" ) type DiceDBCmd struct { diff --git a/internal/eval/bloom_test.go b/internal/eval/bloom_test.go index 78f97fabeb..5a2577e70f 100644 --- a/internal/eval/bloom_test.go +++ b/internal/eval/bloom_test.go @@ -15,7 +15,7 @@ import ( ) func TestBloomFilter(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) // This test only contains some basic checks for all the bloom filter // operations like BFINIT, BFADD, BFEXISTS. It assumes that the // functions called in the main function are working correctly and @@ -112,7 +112,7 @@ func TestBloomFilter(t *testing.T) { } func TestGetOrCreateBloomFilter(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) // Create a key and default opts key := "bf" opts, _ := newBloomOpts([]string{}, true) diff --git a/internal/eval/commands.go b/internal/eval/commands.go index e0f18544b7..98e787b886 100644 --- a/internal/eval/commands.go +++ b/internal/eval/commands.go @@ -1051,16 +1051,6 @@ var ( Arity: -4, KeySpecs: KeySpecs{BeginIndex: 1}, } - getWatchCmdMeta = DiceCmdMeta{ - Name: "GET.WATCH", - Info: `GET.WATCH key - Returns the value of the key and starts watching for changes in the key's value. Note that some update - deliveries may be missed in case of high write rate on the given key. However, the values being delivered will - always be monotonically consistent.`, - Arity: 2, - KeySpecs: KeySpecs{BeginIndex: 1}, - CmdEquivalent: "GET", - } geoAddCmdMeta = DiceCmdMeta{ Name: "GEOADD", Info: `Adds one or more members to a geospatial index. The key is created if it doesn't exist.`, diff --git a/internal/eval/eval_test.go b/internal/eval/eval_test.go index c99455823b..a068a95810 100644 --- a/internal/eval/eval_test.go +++ b/internal/eval/eval_test.go @@ -42,7 +42,7 @@ func setupTest(store *dstore.Store) *dstore.Store { } func TestEval(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) testEvalMSET(t, store) testEvalECHO(t, store) @@ -1108,7 +1108,7 @@ func testEvalJSONOBJLEN(t *testing.T, store *dstore.Store) { func BenchmarkEvalJSONOBJLEN(b *testing.B) { sizes := []int{0, 10, 100, 1000, 10000, 100000} // Various sizes of JSON objects - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, size := range sizes { b.Run(fmt.Sprintf("JSONObjectSize_%d", size), func(b *testing.B) { @@ -2747,13 +2747,13 @@ func runEvalTests(t *testing.T, tests map[string]evalTestCase, evalFunc func([]s func BenchmarkEvalMSET(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) evalMSET([]string{"KEY", "VAL", "KEY2", "VAL2"}, store) } } func BenchmarkEvalHSET(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for i := 0; i < b.N; i++ { evalHSET([]string{"KEY", fmt.Sprintf("FIELD_%d", i), fmt.Sprintf("VALUE_%d", i)}, store) } @@ -2888,7 +2888,7 @@ func testEvalHKEYS(t *testing.T, store *dstore.Store) { } func BenchmarkEvalHKEYS(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for i := 0; i < b.N; i++ { evalHSET([]string{"KEY", fmt.Sprintf("FIELD_%d", i), fmt.Sprintf("VALUE_%d", i)}, store) @@ -2899,7 +2899,7 @@ func BenchmarkEvalHKEYS(b *testing.B) { } } func BenchmarkEvalPFCOUNT(b *testing.B) { - store := *dstore.NewStore(nil) + store := *dstore.NewStore(nil, nil) // Helper function to create and insert HLL objects createAndInsertHLL := func(key string, items []string) { @@ -3247,7 +3247,7 @@ func testEvalHLEN(t *testing.T, store *dstore.Store) { func BenchmarkEvalHLEN(b *testing.B) { sizes := []int{0, 10, 100, 1000, 10000, 100000} - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, size := range sizes { b.Run(fmt.Sprintf("HashSize_%d", size), func(b *testing.B) { @@ -3489,7 +3489,7 @@ func testEvalTYPE(t *testing.T, store *dstore.Store) { } func BenchmarkEvalTYPE(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) // Define different types of objects to benchmark objectTypes := map[string]func(){ @@ -3680,7 +3680,7 @@ func testEvalJSONOBJKEYS(t *testing.T, store *dstore.Store) { func BenchmarkEvalJSONOBJKEYS(b *testing.B) { sizes := []int{0, 10, 100, 1000, 10000, 100000} // Various sizes of JSON objects - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, size := range sizes { b.Run(fmt.Sprintf("JSONObjectSize_%d", size), func(b *testing.B) { @@ -3848,7 +3848,7 @@ func testEvalGETRANGE(t *testing.T, store *dstore.Store) { } func BenchmarkEvalGETRANGE(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) store.Put("BENCHMARK_KEY", store.NewObj("Hello World", maxExDuration, object.ObjTypeString, object.ObjEncodingRaw)) inputs := []struct { @@ -3873,7 +3873,7 @@ func BenchmarkEvalGETRANGE(b *testing.B) { } func BenchmarkEvalHSETNX(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for i := 0; i < b.N; i++ { evalHSETNX([]string{"KEY", fmt.Sprintf("FIELD_%d", i/2), fmt.Sprintf("VALUE_%d", i)}, store) } @@ -3935,7 +3935,7 @@ func testEvalHSETNX(t *testing.T, store *dstore.Store) { } func TestMSETConsistency(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) evalMSET([]string{"KEY", "VAL", "KEY2", "VAL2"}, store) assert.Equal(t, "VAL", store.Get("KEY").Value) @@ -3943,7 +3943,7 @@ func TestMSETConsistency(t *testing.T) { } func BenchmarkEvalHINCRBY(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) // creating new fields for i := 0; i < b.N; i++ { @@ -4193,7 +4193,7 @@ func testEvalSETEX(t *testing.T, store *dstore.Store) { } func BenchmarkEvalSETEX(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -4368,7 +4368,7 @@ func testEvalINCRBYFLOAT(t *testing.T, store *dstore.Store) { } func BenchmarkEvalINCRBYFLOAT(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) store.Put("key1", store.NewObj("1", maxExDuration, object.ObjTypeString, object.ObjEncodingEmbStr)) store.Put("key2", store.NewObj("1.2", maxExDuration, object.ObjTypeString, object.ObjEncodingEmbStr)) @@ -4483,7 +4483,7 @@ func testEvalBITOP(t *testing.T, store *dstore.Store) { } func BenchmarkEvalBITOP(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) // Setup initial data for benchmarking store.Put("key1", store.NewObj(&ByteArray{data: []byte{0x01, 0x02, 0xff}}, maxExDuration, object.ObjTypeByteArray, object.ObjEncodingByteArray)) @@ -4777,7 +4777,7 @@ func testEvalAPPEND(t *testing.T, store *dstore.Store) { } func BenchmarkEvalAPPEND(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for i := 0; i < b.N; i++ { evalAPPEND([]string{"key", fmt.Sprintf("val_%d", i)}, store) } @@ -5250,7 +5250,7 @@ func testEvalHINCRBYFLOAT(t *testing.T, store *dstore.Store) { } func BenchmarkEvalHINCRBYFLOAT(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) // Setting initial fields with some values store.Put("key1", store.NewObj(HashMap{"field1": "1.0", "field2": "1.2"}, maxExDuration, object.ObjTypeHashMap, object.ObjEncodingHashMap)) diff --git a/internal/eval/main_test.go b/internal/eval/main_test.go index 7b79a58dc9..04d622233f 100644 --- a/internal/eval/main_test.go +++ b/internal/eval/main_test.go @@ -13,7 +13,7 @@ func TestMain(m *testing.M) { l := logger.New(logger.Opts{WithTimestamp: false}) slog.SetDefault(l) - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) store.ResetStore() exitCode := m.Run() diff --git a/internal/server/resp/server.go b/internal/server/resp/server.go index 32254ff1ef..3bf4d06fd6 100644 --- a/internal/server/resp/server.go +++ b/internal/server/resp/server.go @@ -4,8 +4,6 @@ import ( "context" "errors" "fmt" - dstore "github.com/dicedb/dice/internal/store" - "github.com/dicedb/dice/internal/watchmanager" "log/slog" "net" "sync" @@ -13,6 +11,9 @@ import ( "syscall" "time" + dstore "github.com/dicedb/dice/internal/store" + "github.com/dicedb/dice/internal/watchmanager" + "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/clientio/iohandler/netconn" respparser "github.com/dicedb/dice/internal/clientio/requestparser/resp" diff --git a/internal/store/store.go b/internal/store/store.go index e564329eed..1b8b607fd3 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -331,7 +331,7 @@ func (store *Store) notifyQueryManager(k, operation string, obj object.Obj) { store.queryWatchChan <- QueryWatchEvent{k, operation, obj} } -func (store *Store) notifyWatchManager(cmd string, affectedKey string) { +func (store *Store) notifyWatchManager(cmd, affectedKey string) { store.cmdWatchChan <- CmdWatchEvent{cmd, affectedKey} } diff --git a/internal/watchmanager/watch_manager.go b/internal/watchmanager/watch_manager.go index 9bf6e90aa1..289dd47535 100644 --- a/internal/watchmanager/watch_manager.go +++ b/internal/watchmanager/watch_manager.go @@ -2,24 +2,25 @@ package watchmanager import ( "context" - "github.com/dicedb/dice/internal/cmd" - dstore "github.com/dicedb/dice/internal/store" "log/slog" "sync" + + "github.com/dicedb/dice/internal/cmd" + dstore "github.com/dicedb/dice/internal/store" ) type ( WatchSubscription struct { - Subscribe bool // Subscribe is true for subscribe, false for unsubscribe. Required. - AdhocReqChan chan *cmd.RedisCmd // AdhocReqChan is the channel to send adhoc requests to the worker. Required. - WatchCmd *cmd.RedisCmd // WatchCmd Represents a unique key for each watch artifact, only populated for subscriptions. - Fingerprint uint32 // Fingerprint is a unique identifier for each watch artifact, only populated for unsubscriptions. + Subscribe bool // Subscribe is true for subscribe, false for unsubscribe. Required. + AdhocReqChan chan *cmd.DiceDBCmd // AdhocReqChan is the channel to send adhoc requests to the worker. Required. + WatchCmd *cmd.DiceDBCmd // WatchCmd Represents a unique key for each watch artifact, only populated for subscriptions. + Fingerprint uint32 // Fingerprint is a unique identifier for each watch artifact, only populated for unsubscriptions. } Manager struct { - querySubscriptionMap map[string]map[uint32]struct{} // querySubscriptionMap is a map of Key -> [fingerprint1, fingerprint2, ...] - tcpSubscriptionMap map[uint32]map[chan *cmd.RedisCmd]struct{} // tcpSubscriptionMap is a map of fingerprint -> [client1Chan, client2Chan, ...] - fingerprintCmdMap map[uint32]*cmd.RedisCmd // fingerprintCmdMap is a map of fingerprint -> RedisCmd + querySubscriptionMap map[string]map[uint32]struct{} // querySubscriptionMap is a map of Key -> [fingerprint1, fingerprint2, ...] + tcpSubscriptionMap map[uint32]map[chan *cmd.DiceDBCmd]struct{} // tcpSubscriptionMap is a map of fingerprint -> [client1Chan, client2Chan, ...] + fingerprintCmdMap map[uint32]*cmd.DiceDBCmd // fingerprintCmdMap is a map of fingerprint -> DiceDBCmd logger *slog.Logger } ) @@ -36,8 +37,8 @@ func NewManager(logger *slog.Logger) *Manager { CmdWatchSubscriptionChan = make(chan WatchSubscription) return &Manager{ querySubscriptionMap: make(map[string]map[uint32]struct{}), - tcpSubscriptionMap: make(map[uint32]map[chan *cmd.RedisCmd]struct{}), - fingerprintCmdMap: make(map[uint32]*cmd.RedisCmd), + tcpSubscriptionMap: make(map[uint32]map[chan *cmd.DiceDBCmd]struct{}), + fingerprintCmdMap: make(map[uint32]*cmd.DiceDBCmd), logger: logger, } } @@ -85,12 +86,12 @@ func (m *Manager) handleSubscription(sub WatchSubscription) { } m.querySubscriptionMap[key][fingerprint] = struct{}{} - // Add RedisCmd to fingerprintCmdMap + // Add DiceDBCmd to fingerprintCmdMap m.fingerprintCmdMap[fingerprint] = sub.WatchCmd // Add client channel to tcpSubscriptionMap if _, exists := m.tcpSubscriptionMap[fingerprint]; !exists { - m.tcpSubscriptionMap[fingerprint] = make(map[chan *cmd.RedisCmd]struct{}) + m.tcpSubscriptionMap[fingerprint] = make(map[chan *cmd.DiceDBCmd]struct{}) } m.tcpSubscriptionMap[fingerprint][sub.AdhocReqChan] = struct{}{} } @@ -116,8 +117,8 @@ func (m *Manager) handleUnsubscription(sub WatchSubscription) { } // Remove fingerprint from querySubscriptionMap - if redisCmd, ok := m.fingerprintCmdMap[fingerprint]; ok { - key := redisCmd.GetKey() + if diceDBCmd, ok := m.fingerprintCmdMap[fingerprint]; ok { + key := diceDBCmd.GetKey() if fingerprints, ok := m.querySubscriptionMap[key]; ok { // Remove the fingerprint from the list of fingerprints listening to this key delete(fingerprints, fingerprint) @@ -161,7 +162,7 @@ func (m *Manager) handleWatchEvent(event dstore.CmdWatchEvent) { } // notifyClients sends cmd to all clients listening to this fingerprint, so that they can execute it. -func (m *Manager) notifyClients(fingerprint uint32, cmd *cmd.RedisCmd) { +func (m *Manager) notifyClients(fingerprint uint32, diceDBCmd *cmd.DiceDBCmd) { clients, exists := m.tcpSubscriptionMap[fingerprint] if !exists { m.logger.Warn("No clients found for fingerprint", @@ -170,6 +171,6 @@ func (m *Manager) notifyClients(fingerprint uint32, cmd *cmd.RedisCmd) { } for clientChan := range clients { - clientChan <- cmd + clientChan <- diceDBCmd } } diff --git a/internal/worker/worker.go b/internal/worker/worker.go index fc7ae0203d..bdb42d05da 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -4,12 +4,13 @@ import ( "context" "errors" "fmt" - "github.com/dicedb/dice/internal/watchmanager" "log/slog" "net" "syscall" "time" + "github.com/dicedb/dice/internal/watchmanager" + "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/auth" "github.com/dicedb/dice/internal/clientio" @@ -35,7 +36,7 @@ type BaseWorker struct { parser requestparser.Parser shardManager *shard.ShardManager respChan chan *ops.StoreResponse - adhocReqChan chan *cmd.RedisCmd + adhocReqChan chan *cmd.DiceDBCmd Session *auth.Session globalErrorChan chan error logger *slog.Logger @@ -54,7 +55,7 @@ func NewWorker(wid string, respChan chan *ops.StoreResponse, respChan: respChan, logger: logger, Session: auth.NewSession(), - adhocReqChan: make(chan *cmd.RedisCmd, 20), // assuming we wouldn't have more than 20 adhoc requests being sent at a time. + adhocReqChan: make(chan *cmd.DiceDBCmd, 20), // assuming we wouldn't have more than 20 adhoc requests being sent at a time. } } @@ -96,8 +97,12 @@ func (w *BaseWorker) Start(ctx context.Context) error { } return fmt.Errorf("error writing response: %w", err) case cmdReq := <-w.adhocReqChan: - // Handle adhoc requests of RedisCmd - w.executeCommandHandler(errChan, nil, ctx, []*cmd.RedisCmd{cmdReq}, true) + // Handle adhoc requests of DiceDBCmd + go func() { + execCtx, cancel := context.WithTimeout(ctx, 6*time.Second) // Timeout set to 6 seconds for integration tests + defer cancel() + w.executeCommandHandler(execCtx, errChan, []*cmd.DiceDBCmd{cmdReq}, true) + }() case data := <-dataChan: cmds, err := w.parser.Parse(data) if err != nil { @@ -138,17 +143,17 @@ func (w *BaseWorker) Start(ctx context.Context) error { func(errChan chan error) { execCtx, cancel := context.WithTimeout(ctx, 6*time.Second) // Timeout set to 6 seconds for integration tests defer cancel() - w.executeCommandHandler(errChan, err, execCtx, cmds, false) + w.executeCommandHandler(execCtx, errChan, cmds, false) }(errChan) - case err := <-errChan: + case err := <-readErrChan: w.logger.Debug("Read error, connection closed possibly", slog.String("workerID", w.id), slog.Any("error", err)) return err } } } -func (w *BaseWorker) executeCommandHandler(errChan chan error, err error, execCtx context.Context, cmds []*cmd.RedisCmd, isWatchNotification bool) { - err = w.executeCommand(execCtx, cmds[0], isWatchNotification) +func (w *BaseWorker) executeCommandHandler(execCtx context.Context, errChan chan error, cmds []*cmd.DiceDBCmd, isWatchNotification bool) { + err := w.executeCommand(execCtx, cmds[0], isWatchNotification) if err != nil { w.logger.Error("Error executing command", slog.String("workerID", w.id), slog.Any("error", err)) if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ETIMEDOUT) { diff --git a/main.go b/main.go index 0fae87dc84..5cd217b5c6 100644 --- a/main.go +++ b/main.go @@ -92,7 +92,7 @@ func main() { // Initialize the AsyncServer server // Find a port and bind it if !config.EnableMultiThreading { - asyncServer := server.NewAsyncServer(shardManager, queryWatchChan, cmdWatchChan, logr) + asyncServer := server.NewAsyncServer(shardManager, queryWatchChan, logr) if err := asyncServer.FindPortAndBind(); err != nil { cancel() logr.Error("Error finding and binding port", slog.Any("error", err)) @@ -155,7 +155,7 @@ func main() { } else { workerManager := worker.NewWorkerManager(config.DiceConfig.Server.MaxClients, shardManager) // Initialize the RESP Server - respServer := resp.NewServer(shardManager, workerManager, serverErrCh, logr) + respServer := resp.NewServer(shardManager, workerManager, cmdWatchChan, serverErrCh, logr) serverWg.Add(1) go func() { defer serverWg.Done()