diff --git a/integration_tests/commands/async/getex_test.go b/integration_tests/commands/async/getex_test.go index a02597dad..8266cc671 100644 --- a/integration_tests/commands/async/getex_test.go +++ b/integration_tests/commands/async/getex_test.go @@ -15,8 +15,8 @@ func TestGetEx(t *testing.T) { Etime10 := strconv.FormatInt(time.Now().Unix()+10, 10) testCases := []struct { - name string - commands []string + name string + commands []string expected []interface{} assertType []string delay []time.Duration diff --git a/integration_tests/commands/async/json_test.go b/integration_tests/commands/async/json_test.go index a92e60b65..59f5cbc85 100644 --- a/integration_tests/commands/async/json_test.go +++ b/integration_tests/commands/async/json_test.go @@ -862,8 +862,8 @@ func TestJsonNummultby(t *testing.T) { invalidArgMessage := "ERR wrong number of arguments for 'json.nummultby' command" testCases := []struct { - name string - commands []string + name string + commands []string expected []interface{} assertType []string }{ @@ -1021,9 +1021,9 @@ func TestJSONNumIncrBy(t *testing.T) { defer conn.Close() invalidArgMessage := "ERR wrong number of arguments for 'json.numincrby' command" testCases := []struct { - name string - setupData string - commands []string + name string + setupData string + commands []string expected []interface{} assertType []string cleanUp []string diff --git a/integration_tests/commands/async/qwatch_test.go b/integration_tests/commands/async/qwatch_test.go index 3b2aefda5..82ea6a878 100644 --- a/integration_tests/commands/async/qwatch_test.go +++ b/integration_tests/commands/async/qwatch_test.go @@ -84,7 +84,7 @@ func setupQWATCHTest(t *testing.T) (net.Conn, []net.Conn, func()) { subscribers := []net.Conn{getLocalConnection(), getLocalConnection(), getLocalConnection()} cleanup := func() { - cleanupKeys(publisher) + cleanupQWATCHKeys(publisher) if err := publisher.Close(); err != nil { t.Errorf("Error closing publisher connection: %v", err) } @@ -342,7 +342,7 @@ func verifyJSONUpdates(t *testing.T, rp *clientio.RESPParser, tc JSONTestCase) { } } -func cleanupKeys(publisher net.Conn) { +func cleanupQWATCHKeys(publisher net.Conn) { for _, tc := range qWatchTestCases { FireCommand(publisher, fmt.Sprintf("DEL %s:%d", tc.key, tc.userID)) } diff --git a/integration_tests/commands/async/set_data_cmd_test.go b/integration_tests/commands/async/set_data_cmd_test.go index d6ffaa4f6..cb03863d8 100644 --- a/integration_tests/commands/async/set_data_cmd_test.go +++ b/integration_tests/commands/async/set_data_cmd_test.go @@ -30,8 +30,8 @@ func TestSetDataCommand(t *testing.T) { defer conn.Close() testCases := []struct { - name string - cmd []string + name string + cmd []string expected []interface{} assertType []string delay []time.Duration diff --git a/integration_tests/commands/async/setup.go b/integration_tests/commands/async/setup.go index 939db0bf1..1d7d07c76 100644 --- a/integration_tests/commands/async/setup.go +++ b/integration_tests/commands/async/setup.go @@ -122,7 +122,7 @@ func RunTestServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerOption var err error watchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Server.WatchChanBufSize) gec := make(chan error) - shardManager := shard.NewShardManager(1, watchChan, gec, opt.Logger) + shardManager := shard.NewShardManager(1, watchChan, nil, gec, opt.Logger) // Initialize the AsyncServer testServer := server.NewAsyncServer(shardManager, watchChan, opt.Logger) diff --git a/integration_tests/commands/async/touch_test.go b/integration_tests/commands/async/touch_test.go index b1c7c98d9..bf51f5088 100644 --- a/integration_tests/commands/async/touch_test.go +++ b/integration_tests/commands/async/touch_test.go @@ -12,8 +12,8 @@ func TestTouch(t *testing.T) { defer conn.Close() testCases := []struct { - name string - commands []string + name string + commands []string expected []interface{} assertType []string delay []time.Duration diff --git a/integration_tests/commands/http/getex_test.go b/integration_tests/commands/http/getex_test.go index 9bbd42fee..e93552755 100644 --- a/integration_tests/commands/http/getex_test.go +++ b/integration_tests/commands/http/getex_test.go @@ -14,8 +14,8 @@ func TestGetEx(t *testing.T) { Etime10 := strconv.FormatInt(time.Now().Unix()+10, 10) testCases := []struct { - name string - commands []HTTPCommand + name string + commands []HTTPCommand expected []interface{} assertType []string delay []time.Duration diff --git a/integration_tests/commands/http/setup.go b/integration_tests/commands/http/setup.go index e0ff654a3..0fc1a8deb 100644 --- a/integration_tests/commands/http/setup.go +++ b/integration_tests/commands/http/setup.go @@ -103,7 +103,7 @@ func RunHTTPServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerOption globalErrChannel := make(chan error) watchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Server.WatchChanBufSize) - shardManager := shard.NewShardManager(1, watchChan, globalErrChannel, opt.Logger) + shardManager := shard.NewShardManager(1, watchChan, nil, globalErrChannel, opt.Logger) queryWatcherLocal := querymanager.NewQueryManager(opt.Logger) config.HTTPPort = opt.Port // Initialize the HTTPServer diff --git a/integration_tests/commands/resp/getwatch_test.go b/integration_tests/commands/resp/getwatch_test.go new file mode 100644 index 000000000..b5d5e6aef --- /dev/null +++ b/integration_tests/commands/resp/getwatch_test.go @@ -0,0 +1,80 @@ +package resp + +import ( + "fmt" + "github.com/dicedb/dice/internal/clientio" + "gotest.tools/v3/assert" + "net" + "testing" + "time" +) + +const getWatchKey = "getwatchkey" + +type getWatchTestCase struct { + key string + val string +} + +var getWatchTestCases = []getWatchTestCase{ + {getWatchKey, "value1"}, + {getWatchKey, "value2"}, + {getWatchKey, "value3"}, + {getWatchKey, "value4"}, +} + +func TestGETWATCH(t *testing.T) { + publisher := getLocalConnection() + subscribers := []net.Conn{getLocalConnection(), getLocalConnection(), getLocalConnection()} + + defer func() { + if err := publisher.Close(); err != nil { + t.Errorf("Error closing publisher connection: %v", err) + } + for _, sub := range subscribers { + //FireCommand(sub, fmt.Sprintf("GET.UNWATCH %s", fingerprint)) + time.Sleep(100 * time.Millisecond) + if err := sub.Close(); err != nil { + t.Errorf("Error closing subscriber connection: %v", err) + } + } + }() + + // Fire a SET command to set a key + res := FireCommand(publisher, fmt.Sprintf("SET %s %s", getWatchKey, "value")) + assert.Equal(t, "OK", res) + + respParsers := make([]*clientio.RESPParser, len(subscribers)) + for i, subscriber := range subscribers { + rp := fireCommandAndGetRESPParser(subscriber, fmt.Sprintf("GET.WATCH %s", getWatchKey)) + assert.Assert(t, rp != nil) + respParsers[i] = rp + + v, err := rp.DecodeOne() + assert.NilError(t, err) + castedValue, ok := v.([]interface{}) + if !ok { + t.Errorf("Type assertion to []interface{} failed for value: %v", v) + } + assert.Equal(t, 3, len(castedValue)) + } + + // Fire updates to the key using the publisher, then check if the subscribers receive the updates in the push-response form (i.e. array of three elements, with third element being the value) + for _, tc := range getWatchTestCases { + res := FireCommand(publisher, fmt.Sprintf("SET %s %s", tc.key, tc.val)) + assert.Equal(t, "OK", res) + + for _, rp := range respParsers { + v, err := rp.DecodeOne() + assert.NilError(t, err) + castedValue, ok := v.([]interface{}) + if !ok { + t.Errorf("Type assertion to []interface{} failed for value: %v", v) + } + assert.Equal(t, 3, len(castedValue)) + assert.Equal(t, "GET", castedValue[0]) + assert.Equal(t, "1768826704", castedValue[1]) + assert.Equal(t, tc.val, castedValue[2]) + } + } +} diff --git a/integration_tests/commands/resp/setup.go b/integration_tests/commands/resp/setup.go index f087ec939..db579d94a 100644 --- a/integration_tests/commands/resp/setup.go +++ b/integration_tests/commands/resp/setup.go @@ -122,12 +122,13 @@ func RunTestServer(wg *sync.WaitGroup, opt TestServerOptions) { config.DiceConfig.Server.Port = 9739 } - watchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Server.KeysLimit) + queryWatchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Server.KeysLimit) + cmdWatchChan := make(chan dstore.CmdWatchEvent, config.DiceConfig.Server.KeysLimit) gec := make(chan error) - shardManager := shard.NewShardManager(1, watchChan, gec, logr) + shardManager := shard.NewShardManager(1, queryWatchChan, cmdWatchChan, gec, logr) workerManager := worker.NewWorkerManager(20000, shardManager) - // Initialize the REST Server - testServer := resp.NewServer(shardManager, workerManager, gec, logr) + // Initialize the RESP Server + testServer := resp.NewServer(shardManager, workerManager, cmdWatchChan, gec, logr) ctx, cancel := context.WithCancel(context.Background()) fmt.Println("Starting the test server on port", config.DiceConfig.Server.Port) diff --git a/integration_tests/commands/websocket/setup.go b/integration_tests/commands/websocket/setup.go index d706a14c7..0947d9997 100644 --- a/integration_tests/commands/websocket/setup.go +++ b/integration_tests/commands/websocket/setup.go @@ -99,7 +99,7 @@ func RunWebsocketServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerO // Initialize the WebsocketServer globalErrChannel := make(chan error) watchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Server.WatchChanBufSize) - shardManager := shard.NewShardManager(1, watchChan, globalErrChannel, opt.Logger) + shardManager := shard.NewShardManager(1, watchChan, nil, globalErrChannel, opt.Logger) config.WebsocketPort = opt.Port testServer := server.NewWebSocketServer(shardManager, watchChan, opt.Logger) diff --git a/internal/clientio/push_response.go b/internal/clientio/push_response.go index cb0c1ab20..94593742f 100644 --- a/internal/clientio/push_response.go +++ b/internal/clientio/push_response.go @@ -1,16 +1,18 @@ package clientio -import ( - "github.com/dicedb/dice/internal/sql" +const ( + ResponseTypeRegular = iota + ResponseTypePush ) // CreatePushResponse creates a push response. Push responses refer to messages that the server sends to clients without // the client explicitly requesting them. These are typically seen in scenarios where the client has subscribed to some -// kind of event or data feed and is notified in real-time when changes occur -func CreatePushResponse(query *sql.DSQLQuery, result *[]sql.QueryResultRow) (response []interface{}) { +// kind of event or data feed and is notified in real-time when changes occur. +// `key` is the unique key that identifies the push response. +func CreatePushResponse[T any](cmd, key string, result T) (response []interface{}) { response = make([]interface{}, 3) - response[0] = sql.Qwatch - response[1] = query.String() - response[2] = *result + response[0] = cmd + response[1] = key + response[2] = result return } diff --git a/internal/cmd/cmds.go b/internal/cmd/cmds.go index 68309737c..423eb68e0 100644 --- a/internal/cmd/cmds.go +++ b/internal/cmd/cmds.go @@ -1,5 +1,12 @@ package cmd +import ( + "fmt" + "strings" + + "github.com/dgryski/go-farm" +) + type DiceDBCmd struct { RequestID uint32 Cmd string @@ -10,3 +17,17 @@ type RedisCmds struct { Cmds []*DiceDBCmd RequestID uint32 } + +// GetFingerprint returns a 32-bit fingerprint of the command and its arguments. +func (cmd *DiceDBCmd) GetFingerprint() uint32 { + return farm.Fingerprint32([]byte(fmt.Sprintf("%s-%s", cmd.Cmd, strings.Join(cmd.Args, " ")))) +} + +// GetKey Returns the key which the command operates on. +// +// TODO: This is a naive implementation which assumes that the first argument is the key. +// This is not true for all commands, however, for now this is only used by the watch manager, +// which as of now only supports a small subset of commands (all of which fit this implementation). +func (cmd *DiceDBCmd) GetKey() string { + return cmd.Args[0] +} diff --git a/internal/comm/client.go b/internal/comm/client.go index fe6240f32..c24736c77 100644 --- a/internal/comm/client.go +++ b/internal/comm/client.go @@ -8,6 +8,12 @@ import ( "github.com/dicedb/dice/internal/cmd" ) +type CmdWatchResponse struct { + ClientIdentifierID uint32 + Result interface{} + Error error +} + type QwatchResponse struct { ClientIdentifierID uint32 Result interface{} diff --git a/internal/eval/bloom_test.go b/internal/eval/bloom_test.go index 78f97fabe..5a2577e70 100644 --- a/internal/eval/bloom_test.go +++ b/internal/eval/bloom_test.go @@ -15,7 +15,7 @@ import ( ) func TestBloomFilter(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) // This test only contains some basic checks for all the bloom filter // operations like BFINIT, BFADD, BFEXISTS. It assumes that the // functions called in the main function are working correctly and @@ -112,7 +112,7 @@ func TestBloomFilter(t *testing.T) { } func TestGetOrCreateBloomFilter(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) // Create a key and default opts key := "bf" opts, _ := newBloomOpts([]string{}, true) diff --git a/internal/eval/eval.go b/internal/eval/eval.go index 820908f0c..f41d99330 100644 --- a/internal/eval/eval.go +++ b/internal/eval/eval.go @@ -2128,7 +2128,7 @@ func EvalQWATCH(args []string, httpOp bool, client *comm.Client, store *dstore.S } // TODO: We should return the list of all queries being watched by the client. - return clientio.Encode(clientio.CreatePushResponse(&query, queryResult.Result), false) + return clientio.Encode(clientio.CreatePushResponse(sql.Qwatch, query.String(), *queryResult.Result), false) } // EvalQUNWATCH removes the specified key from the watch list for the caller client. diff --git a/internal/eval/eval_test.go b/internal/eval/eval_test.go index 4b65a924e..2ff275a17 100644 --- a/internal/eval/eval_test.go +++ b/internal/eval/eval_test.go @@ -42,7 +42,7 @@ func setupTest(store *dstore.Store) *dstore.Store { } func TestEval(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) testEvalMSET(t, store) testEvalECHO(t, store) @@ -1109,7 +1109,7 @@ func testEvalJSONOBJLEN(t *testing.T, store *dstore.Store) { func BenchmarkEvalJSONOBJLEN(b *testing.B) { sizes := []int{0, 10, 100, 1000, 10000, 100000} // Various sizes of JSON objects - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, size := range sizes { b.Run(fmt.Sprintf("JSONObjectSize_%d", size), func(b *testing.B) { @@ -2833,13 +2833,13 @@ func runEvalTests(t *testing.T, tests map[string]evalTestCase, evalFunc func([]s func BenchmarkEvalMSET(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) evalMSET([]string{"KEY", "VAL", "KEY2", "VAL2"}, store) } } func BenchmarkEvalHSET(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for i := 0; i < b.N; i++ { evalHSET([]string{"KEY", fmt.Sprintf("FIELD_%d", i), fmt.Sprintf("VALUE_%d", i)}, store) } @@ -2974,7 +2974,7 @@ func testEvalHKEYS(t *testing.T, store *dstore.Store) { } func BenchmarkEvalHKEYS(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for i := 0; i < b.N; i++ { evalHSET([]string{"KEY", fmt.Sprintf("FIELD_%d", i), fmt.Sprintf("VALUE_%d", i)}, store) @@ -2985,7 +2985,7 @@ func BenchmarkEvalHKEYS(b *testing.B) { } } func BenchmarkEvalPFCOUNT(b *testing.B) { - store := *dstore.NewStore(nil) + store := *dstore.NewStore(nil, nil) // Helper function to create and insert HLL objects createAndInsertHLL := func(key string, items []string) { @@ -3333,7 +3333,7 @@ func testEvalHLEN(t *testing.T, store *dstore.Store) { func BenchmarkEvalHLEN(b *testing.B) { sizes := []int{0, 10, 100, 1000, 10000, 100000} - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, size := range sizes { b.Run(fmt.Sprintf("HashSize_%d", size), func(b *testing.B) { @@ -3575,7 +3575,7 @@ func testEvalTYPE(t *testing.T, store *dstore.Store) { } func BenchmarkEvalTYPE(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) // Define different types of objects to benchmark objectTypes := map[string]func(){ @@ -3766,7 +3766,7 @@ func testEvalJSONOBJKEYS(t *testing.T, store *dstore.Store) { func BenchmarkEvalJSONOBJKEYS(b *testing.B) { sizes := []int{0, 10, 100, 1000, 10000, 100000} // Various sizes of JSON objects - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, size := range sizes { b.Run(fmt.Sprintf("JSONObjectSize_%d", size), func(b *testing.B) { @@ -3934,7 +3934,7 @@ func testEvalGETRANGE(t *testing.T, store *dstore.Store) { } func BenchmarkEvalGETRANGE(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) store.Put("BENCHMARK_KEY", store.NewObj("Hello World", maxExDuration, object.ObjTypeString, object.ObjEncodingRaw)) inputs := []struct { @@ -3959,7 +3959,7 @@ func BenchmarkEvalGETRANGE(b *testing.B) { } func BenchmarkEvalHSETNX(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for i := 0; i < b.N; i++ { evalHSETNX([]string{"KEY", fmt.Sprintf("FIELD_%d", i/2), fmt.Sprintf("VALUE_%d", i)}, store) } @@ -4021,7 +4021,7 @@ func testEvalHSETNX(t *testing.T, store *dstore.Store) { } func TestMSETConsistency(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) evalMSET([]string{"KEY", "VAL", "KEY2", "VAL2"}, store) assert.Equal(t, "VAL", store.Get("KEY").Value) @@ -4029,7 +4029,7 @@ func TestMSETConsistency(t *testing.T) { } func BenchmarkEvalHINCRBY(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) // creating new fields for i := 0; i < b.N; i++ { @@ -4279,7 +4279,7 @@ func testEvalSETEX(t *testing.T, store *dstore.Store) { } func BenchmarkEvalSETEX(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -4454,7 +4454,7 @@ func testEvalINCRBYFLOAT(t *testing.T, store *dstore.Store) { } func BenchmarkEvalINCRBYFLOAT(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) store.Put("key1", store.NewObj("1", maxExDuration, object.ObjTypeString, object.ObjEncodingEmbStr)) store.Put("key2", store.NewObj("1.2", maxExDuration, object.ObjTypeString, object.ObjEncodingEmbStr)) @@ -4569,7 +4569,7 @@ func testEvalBITOP(t *testing.T, store *dstore.Store) { } func BenchmarkEvalBITOP(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) // Setup initial data for benchmarking store.Put("key1", store.NewObj(&ByteArray{data: []byte{0x01, 0x02, 0xff}}, maxExDuration, object.ObjTypeByteArray, object.ObjEncodingByteArray)) @@ -4863,7 +4863,7 @@ func testEvalAPPEND(t *testing.T, store *dstore.Store) { } func BenchmarkEvalAPPEND(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for i := 0; i < b.N; i++ { evalAPPEND([]string{"key", fmt.Sprintf("val_%d", i)}, store) } @@ -5336,7 +5336,7 @@ func testEvalHINCRBYFLOAT(t *testing.T, store *dstore.Store) { } func BenchmarkEvalHINCRBYFLOAT(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) // Setting initial fields with some values store.Put("key1", store.NewObj(HashMap{"field1": "1.0", "field2": "1.2"}, maxExDuration, object.ObjTypeHashMap, object.ObjEncodingHashMap)) diff --git a/internal/eval/hmap_test.go b/internal/eval/hmap_test.go index 07bfea205..cba318717 100644 --- a/internal/eval/hmap_test.go +++ b/internal/eval/hmap_test.go @@ -71,7 +71,7 @@ func TestHashMapIncrementValue(t *testing.T) { } func TestGetValueFromHashMap(t *testing.T) { - store := store.NewStore(nil) + store := store.NewStore(nil, nil) key := "key1" field := "field1" value := "value1" diff --git a/internal/eval/main_test.go b/internal/eval/main_test.go index 7b79a58dc..04d622233 100644 --- a/internal/eval/main_test.go +++ b/internal/eval/main_test.go @@ -13,7 +13,7 @@ func TestMain(m *testing.M) { l := logger.New(logger.Opts{WithTimestamp: false}) slog.SetDefault(l) - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) store.ResetStore() exitCode := m.Run() diff --git a/internal/querymanager/query_manager.go b/internal/querymanager/query_manager.go index 700b1ac84..02d0e8975 100644 --- a/internal/querymanager/query_manager.go +++ b/internal/querymanager/query_manager.go @@ -236,7 +236,7 @@ func (m *Manager) updateQueryCache(queryFingerprint string, event dstore.QueryWa } func (m *Manager) notifyClients(query *sql.DSQLQuery, clients *sync.Map, queryResult *[]sql.QueryResultRow) { - encodedResult := clientio.Encode(clientio.CreatePushResponse(query, queryResult), false) + encodedResult := clientio.Encode(clientio.CreatePushResponse(sql.Qwatch, query.String(), *queryResult), false) clients.Range(func(clientKey, clientVal interface{}) bool { // Identify the type of client and respond accordingly diff --git a/internal/server/resp/server.go b/internal/server/resp/server.go index e4e617d08..3bf4d06fd 100644 --- a/internal/server/resp/server.go +++ b/internal/server/resp/server.go @@ -11,6 +11,9 @@ import ( "syscall" "time" + dstore "github.com/dicedb/dice/internal/store" + "github.com/dicedb/dice/internal/watchmanager" + "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/clientio/iohandler/netconn" respparser "github.com/dicedb/dice/internal/clientio/requestparser/resp" @@ -37,20 +40,24 @@ type Server struct { Port int serverFD int connBacklogSize int - wm *worker.WorkerManager - sm *shard.ShardManager + workerManager *worker.WorkerManager + shardManager *shard.ShardManager + watchManager *watchmanager.Manager + cmdWatchChan chan dstore.CmdWatchEvent globalErrorChan chan error logger *slog.Logger } -func NewServer(sm *shard.ShardManager, wm *worker.WorkerManager, gec chan error, l *slog.Logger) *Server { +func NewServer(shardManager *shard.ShardManager, workerManager *worker.WorkerManager, cmdWatchChan chan dstore.CmdWatchEvent, globalErrChan chan error, l *slog.Logger) *Server { return &Server{ Host: config.DiceConfig.Server.Addr, Port: config.DiceConfig.Server.Port, connBacklogSize: DefaultConnBacklogSize, - wm: wm, - sm: sm, - globalErrorChan: gec, + workerManager: workerManager, + shardManager: shardManager, + watchManager: watchmanager.NewManager(l), + cmdWatchChan: cmdWatchChan, + globalErrorChan: globalErrChan, logger: l, } } @@ -67,7 +74,13 @@ func (s *Server) Run(ctx context.Context) (err error) { // Start a go routine to accept connections errChan := make(chan error, 1) wg := &sync.WaitGroup{} - wg.Add(1) + + wg.Add(2) + go func() { + defer wg.Done() + s.watchManager.Run(ctx, s.cmdWatchChan) + }() + go func(wg *sync.WaitGroup) { defer wg.Done() if err := s.AcceptConnectionRequests(ctx, wg); err != nil { @@ -178,14 +191,14 @@ func (s *Server) AcceptConnectionRequests(ctx context.Context, wg *sync.WaitGrou parser := respparser.NewParser(s.logger) respChan := make(chan *ops.StoreResponse) wID := GenerateUniqueWorkerID() - w := worker.NewWorker(wID, respChan, ioHandler, parser, s.sm, s.globalErrorChan, s.logger) + w := worker.NewWorker(wID, respChan, ioHandler, parser, s.shardManager, s.globalErrorChan, s.logger) if err != nil { s.logger.Error("Failed to create new worker for clientFD", slog.Int("client-fd", clientFD), slog.Any("error", err)) return err } // Register the worker with the worker manager - err = s.wm.RegisterWorker(w) + err = s.workerManager.RegisterWorker(w) if err != nil { return err } @@ -198,7 +211,7 @@ func (s *Server) AcceptConnectionRequests(ctx context.Context, wg *sync.WaitGrou if err != nil { s.logger.Warn("Failed to unregister worker", slog.String("worker-id", wID), slog.Any("error", err)) } - }(s.wm, wID) + }(s.workerManager, wID) wctx, cwctx := context.WithCancel(ctx) defer cwctx() err := w.Start(wctx) diff --git a/internal/server/server.go b/internal/server/server.go index af5d14191..8ee6a0453 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -37,12 +37,12 @@ type AsyncServer struct { queryWatcher *querymanager.Manager shardManager *shard.ShardManager ioChan chan *ops.StoreResponse // The server acts like a worker today, this behavior will change once IOThreads are introduced and each client gets its own worker. - watchChan chan dstore.QueryWatchEvent // This is needed to co-ordinate between the store and the query watcher. + queryWatchChan chan dstore.QueryWatchEvent // This is needed to co-ordinate between the store and the query watcher. logger *slog.Logger // logger is the logger for the server } // NewAsyncServer initializes a new AsyncServer -func NewAsyncServer(shardManager *shard.ShardManager, watchChan chan dstore.QueryWatchEvent, logger *slog.Logger) *AsyncServer { +func NewAsyncServer(shardManager *shard.ShardManager, queryWatchChan chan dstore.QueryWatchEvent, logger *slog.Logger) *AsyncServer { return &AsyncServer{ maxClients: config.DiceConfig.Server.MaxClients, connectedClients: make(map[int]*comm.Client), @@ -50,7 +50,7 @@ func NewAsyncServer(shardManager *shard.ShardManager, watchChan chan dstore.Quer queryWatcher: querymanager.NewQueryManager(logger), multiplexerPollTimeout: config.DiceConfig.Server.MultiplexerPollTimeout, ioChan: make(chan *ops.StoreResponse, 1000), - watchChan: watchChan, + queryWatchChan: queryWatchChan, logger: logger, } } @@ -151,7 +151,7 @@ func (s *AsyncServer) Run(ctx context.Context) error { wg.Add(1) go func() { defer wg.Done() - s.queryWatcher.Run(watchCtx, s.watchChan) + s.queryWatcher.Run(watchCtx, s.queryWatchChan) }() s.shardManager.RegisterWorker("server", s.ioChan) diff --git a/internal/shard/shard_manager.go b/internal/shard/shard_manager.go index d5010bcbe..6feede9d4 100644 --- a/internal/shard/shard_manager.go +++ b/internal/shard/shard_manager.go @@ -26,14 +26,14 @@ type ShardManager struct { } // NewShardManager creates a new ShardManager instance with the given number of Shards and a parent context. -func NewShardManager(shardCount uint8, watchChan chan dstore.QueryWatchEvent, globalErrorChan chan error, logger *slog.Logger) *ShardManager { +func NewShardManager(shardCount uint8, queryWatchChan chan dstore.QueryWatchEvent, cmdWatchChan chan dstore.CmdWatchEvent, globalErrorChan chan error, logger *slog.Logger) *ShardManager { shards := make([]*ShardThread, shardCount) shardReqMap := make(map[ShardID]chan *ops.StoreOp) shardErrorChan := make(chan *ShardError) for i := uint8(0); i < shardCount; i++ { // Shards are numbered from 0 to shardCount-1 - shard := NewShardThread(i, globalErrorChan, shardErrorChan, watchChan, logger) + shard := NewShardThread(i, globalErrorChan, shardErrorChan, queryWatchChan, cmdWatchChan, logger) shards[i] = shard shardReqMap[i] = shard.ReqChan } diff --git a/internal/shard/shard_thread.go b/internal/shard/shard_thread.go index 144111ba1..29d078e62 100644 --- a/internal/shard/shard_thread.go +++ b/internal/shard/shard_thread.go @@ -37,10 +37,10 @@ type ShardThread struct { } // NewShardThread creates a new ShardThread instance with the given shard id and error channel. -func NewShardThread(id ShardID, gec chan error, sec chan *ShardError, watchChan chan dstore.QueryWatchEvent, logger *slog.Logger) *ShardThread { +func NewShardThread(id ShardID, gec chan error, sec chan *ShardError, queryWatchChan chan dstore.QueryWatchEvent, cmdWatchChan chan dstore.CmdWatchEvent, logger *slog.Logger) *ShardThread { return &ShardThread{ id: id, - store: dstore.NewStore(watchChan), + store: dstore.NewStore(queryWatchChan, cmdWatchChan), ReqChan: make(chan *ops.StoreOp, 1000), workerMap: make(map[string]chan *ops.StoreResponse), globalErrorChan: gec, diff --git a/internal/sql/executerbechmark_test.go b/internal/sql/executerbechmark_test.go index e7a86388c..b2a935d48 100644 --- a/internal/sql/executerbechmark_test.go +++ b/internal/sql/executerbechmark_test.go @@ -35,7 +35,7 @@ func generateBenchmarkData(count int, store *dstore.Store) { } func BenchmarkExecuteQueryOrderBykey(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -60,7 +60,7 @@ func BenchmarkExecuteQueryOrderBykey(b *testing.B) { } func BenchmarkExecuteQueryBasicOrderByValue(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -83,7 +83,7 @@ func BenchmarkExecuteQueryBasicOrderByValue(b *testing.B) { } func BenchmarkExecuteQueryLimit(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -106,7 +106,7 @@ func BenchmarkExecuteQueryLimit(b *testing.B) { } func BenchmarkExecuteQueryNoMatch(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -129,7 +129,7 @@ func BenchmarkExecuteQueryNoMatch(b *testing.B) { } func BenchmarkExecuteQueryWithBasicWhere(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -152,7 +152,7 @@ func BenchmarkExecuteQueryWithBasicWhere(b *testing.B) { } func BenchmarkExecuteQueryWithComplexWhere(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -175,7 +175,7 @@ func BenchmarkExecuteQueryWithComplexWhere(b *testing.B) { } func BenchmarkExecuteQueryWithCompareWhereKeyandValue(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -198,7 +198,7 @@ func BenchmarkExecuteQueryWithCompareWhereKeyandValue(b *testing.B) { } func BenchmarkExecuteQueryWithBasicWhereNoMatch(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -221,7 +221,7 @@ func BenchmarkExecuteQueryWithBasicWhereNoMatch(b *testing.B) { } func BenchmarkExecuteQueryWithCaseSesnsitivity(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -243,7 +243,7 @@ func BenchmarkExecuteQueryWithCaseSesnsitivity(b *testing.B) { } func BenchmarkExecuteQueryWithClauseOnKey(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -266,7 +266,7 @@ func BenchmarkExecuteQueryWithClauseOnKey(b *testing.B) { } func BenchmarkExecuteQueryWithAllMatchingKeyRegex(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -308,7 +308,7 @@ func generateBenchmarkJSONData(b *testing.B, count int, json string, store *dsto } func BenchmarkExecuteQueryWithJSON(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizesJSON { for jsonSize, json := range jsonList { generateBenchmarkJSONData(b, v, json, store) @@ -333,7 +333,7 @@ func BenchmarkExecuteQueryWithJSON(b *testing.B) { } func BenchmarkExecuteQueryWithNestedJSON(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizesJSON { for jsonSize, json := range jsonList { generateBenchmarkJSONData(b, v, json, store) @@ -358,7 +358,7 @@ func BenchmarkExecuteQueryWithNestedJSON(b *testing.B) { } func BenchmarkExecuteQueryWithJsonInLeftAndRightExpressions(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizesJSON { for jsonSize, json := range jsonList { generateBenchmarkJSONData(b, v, json, store) @@ -384,7 +384,7 @@ func BenchmarkExecuteQueryWithJsonInLeftAndRightExpressions(b *testing.B) { func BenchmarkExecuteQueryWithJsonNoMatch(b *testing.B) { for _, v := range benchmarkDataSizesJSON { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for jsonSize, json := range jsonList { generateBenchmarkJSONData(b, v, json, store) diff --git a/internal/sql/executor_test.go b/internal/sql/executor_test.go index fbee39bc3..7feb9f4d3 100644 --- a/internal/sql/executor_test.go +++ b/internal/sql/executor_test.go @@ -42,7 +42,7 @@ func setup(store *dstore.Store, dataset []keyValue) { } func TestExecuteQueryOrderBykey(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setup(store, simpleKVDataset) queryString := "SELECT $key, $value WHERE $key like 'k*' ORDER BY $key ASC" @@ -69,7 +69,7 @@ func TestExecuteQueryOrderBykey(t *testing.T) { } func TestExecuteQueryBasicOrderByValue(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setup(store, simpleKVDataset) queryStr := "SELECT $key, $value WHERE $key like 'k*' ORDER BY $value ASC" @@ -96,7 +96,7 @@ func TestExecuteQueryBasicOrderByValue(t *testing.T) { } func TestExecuteQueryLimit(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setup(store, simpleKVDataset) queryStr := "SELECT $value WHERE $key like 'k*' ORDER BY $key ASC LIMIT 3" @@ -123,7 +123,7 @@ func TestExecuteQueryLimit(t *testing.T) { } func TestExecuteQueryNoMatch(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setup(store, simpleKVDataset) queryStr := "SELECT $key, $value WHERE $key like 'x*'" @@ -137,7 +137,7 @@ func TestExecuteQueryNoMatch(t *testing.T) { } func TestExecuteQueryWithWhere(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setup(store, simpleKVDataset) t.Run("BasicWhereClause", func(t *testing.T) { queryStr := "SELECT $key, $value WHERE $value = 'v3' AND $key like 'k*'" @@ -190,7 +190,7 @@ func TestExecuteQueryWithWhere(t *testing.T) { } func TestExecuteQueryWithIncompatibleTypes(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setup(store, simpleKVDataset) t.Run("ComparingStrWithInt", func(t *testing.T) { @@ -205,7 +205,7 @@ func TestExecuteQueryWithIncompatibleTypes(t *testing.T) { } func TestExecuteQueryWithEdgeCases(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setup(store, simpleKVDataset) t.Run("CaseSensitivity", func(t *testing.T) { @@ -285,7 +285,7 @@ func setupJSON(t *testing.T, store *dstore.Store, dataset []keyValue) { } func TestExecuteQueryWithJsonExpressionInWhere(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setupJSON(t, store, jsonWhereClauseDataset) t.Run("BasicWhereClauseWithJSON", func(t *testing.T) { @@ -394,7 +394,7 @@ var jsonOrderDataset = []keyValue{ } func TestExecuteQueryWithJsonOrderBy(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setupJSON(t, store, jsonOrderDataset) t.Run("OrderBySimpleJSONField", func(t *testing.T) { @@ -557,7 +557,7 @@ var stringComparisonDataset = []keyValue{ } func TestExecuteQueryWithLikeStringComparisons(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setup(store, stringComparisonDataset) testCases := []struct { @@ -642,7 +642,7 @@ func TestExecuteQueryWithLikeStringComparisons(t *testing.T) { } func TestExecuteQueryWithStringNotLikeComparisons(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setup(store, stringComparisonDataset) testCases := []struct { diff --git a/internal/store/expire_test.go b/internal/store/expire_test.go index afc9982ae..bace26856 100644 --- a/internal/store/expire_test.go +++ b/internal/store/expire_test.go @@ -7,7 +7,7 @@ import ( ) func TestDelExpiry(t *testing.T) { - store := NewStore(nil) + store := NewStore(nil, nil) // Initialize the test environment store.store = NewStoreMap() store.expires = NewExpireMap() diff --git a/internal/store/lfu_eviction_test.go b/internal/store/lfu_eviction_test.go index 9ccd90930..79d079364 100644 --- a/internal/store/lfu_eviction_test.go +++ b/internal/store/lfu_eviction_test.go @@ -11,7 +11,7 @@ import ( func TestLFUEviction(t *testing.T) { originalEvictionPolicy := config.DiceConfig.Server.EvictionPolicy - store := NewStore(nil) + store := NewStore(nil, nil) config.DiceConfig.Server.EvictionPolicy = config.EvictAllKeysLFU // Define test cases diff --git a/internal/store/store.go b/internal/store/store.go index 9497727b7..1b8b607fd 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -42,18 +42,25 @@ type QueryWatchEvent struct { Value object.Obj } +type CmdWatchEvent struct { + Cmd string + AffectedKey string +} + type Store struct { - store common.ITable[string, *object.Obj] - expires common.ITable[*object.Obj, uint64] // Does not need to be thread-safe as it is only accessed by a single thread. - numKeys int - watchChan chan QueryWatchEvent + store common.ITable[string, *object.Obj] + expires common.ITable[*object.Obj, uint64] // Does not need to be thread-safe as it is only accessed by a single thread. + numKeys int + queryWatchChan chan QueryWatchEvent + cmdWatchChan chan CmdWatchEvent } -func NewStore(watchChan chan QueryWatchEvent) *Store { +func NewStore(queryWatchChan chan QueryWatchEvent, cmdWatchChan chan CmdWatchEvent) *Store { return &Store{ - store: NewStoreRegMap(), - expires: NewExpireRegMap(), - watchChan: watchChan, + store: NewStoreRegMap(), + expires: NewExpireRegMap(), + queryWatchChan: queryWatchChan, + cmdWatchChan: cmdWatchChan, } } @@ -149,9 +156,12 @@ func (store *Store) putHelper(k string, obj *object.Obj, opts ...PutOption) { } store.store.Put(k, obj) - if store.watchChan != nil { + if store.queryWatchChan != nil { store.notifyQueryManager(k, Set, *obj) } + if store.cmdWatchChan != nil { + store.notifyWatchManager("SET", k) + } } // getHelper is a helper function to get the object from the store. It also updates the last accessed time if touch is true. @@ -249,9 +259,12 @@ func (store *Store) Rename(sourceKey, destKey string) bool { store.numKeys-- // Notify watchers about the deletion of the source key - if store.watchChan != nil { + if store.queryWatchChan != nil { store.notifyQueryManager(sourceKey, Del, *sourceObj) } + if store.cmdWatchChan != nil { + store.notifyWatchManager("DEL", sourceKey) + } return true } @@ -292,9 +305,12 @@ func (store *Store) deleteKey(k string, obj *object.Obj) bool { store.expires.Delete(obj) store.numKeys-- - if store.watchChan != nil { + if store.queryWatchChan != nil { store.notifyQueryManager(k, Del, *obj) } + if store.cmdWatchChan != nil { + store.notifyWatchManager("DEL", k) + } return true } @@ -312,7 +328,11 @@ func (store *Store) delByPtr(ptr string) bool { // notifyQueryManager notifies the query manager about a key change, so that it can update the query cache if needed. func (store *Store) notifyQueryManager(k, operation string, obj object.Obj) { - store.watchChan <- QueryWatchEvent{k, operation, obj} + store.queryWatchChan <- QueryWatchEvent{k, operation, obj} +} + +func (store *Store) notifyWatchManager(cmd, affectedKey string) { + store.cmdWatchChan <- CmdWatchEvent{cmd, affectedKey} } func (store *Store) GetStore() common.ITable[string, *object.Obj] { diff --git a/internal/watchmanager/watch_manager.go b/internal/watchmanager/watch_manager.go new file mode 100644 index 000000000..289dd4753 --- /dev/null +++ b/internal/watchmanager/watch_manager.go @@ -0,0 +1,176 @@ +package watchmanager + +import ( + "context" + "log/slog" + "sync" + + "github.com/dicedb/dice/internal/cmd" + dstore "github.com/dicedb/dice/internal/store" +) + +type ( + WatchSubscription struct { + Subscribe bool // Subscribe is true for subscribe, false for unsubscribe. Required. + AdhocReqChan chan *cmd.DiceDBCmd // AdhocReqChan is the channel to send adhoc requests to the worker. Required. + WatchCmd *cmd.DiceDBCmd // WatchCmd Represents a unique key for each watch artifact, only populated for subscriptions. + Fingerprint uint32 // Fingerprint is a unique identifier for each watch artifact, only populated for unsubscriptions. + } + + Manager struct { + querySubscriptionMap map[string]map[uint32]struct{} // querySubscriptionMap is a map of Key -> [fingerprint1, fingerprint2, ...] + tcpSubscriptionMap map[uint32]map[chan *cmd.DiceDBCmd]struct{} // tcpSubscriptionMap is a map of fingerprint -> [client1Chan, client2Chan, ...] + fingerprintCmdMap map[uint32]*cmd.DiceDBCmd // fingerprintCmdMap is a map of fingerprint -> DiceDBCmd + logger *slog.Logger + } +) + +var ( + CmdWatchSubscriptionChan chan WatchSubscription + affectedCmdMap = map[string]map[string]struct{}{ + "SET": {"GET": struct{}{}}, + "DEL": {"GET": struct{}{}}, + } +) + +func NewManager(logger *slog.Logger) *Manager { + CmdWatchSubscriptionChan = make(chan WatchSubscription) + return &Manager{ + querySubscriptionMap: make(map[string]map[uint32]struct{}), + tcpSubscriptionMap: make(map[uint32]map[chan *cmd.DiceDBCmd]struct{}), + fingerprintCmdMap: make(map[uint32]*cmd.DiceDBCmd), + logger: logger, + } +} + +// Run starts the watch manager, listening for subscription requests and events +func (m *Manager) Run(ctx context.Context, cmdWatchChan chan dstore.CmdWatchEvent) { + var wg sync.WaitGroup + + wg.Add(1) + + go func() { + defer wg.Done() + m.listenForEvents(ctx, cmdWatchChan) + }() + + <-ctx.Done() + wg.Wait() +} + +func (m *Manager) listenForEvents(ctx context.Context, cmdWatchChan chan dstore.CmdWatchEvent) { + for { + select { + case <-ctx.Done(): + return + case sub := <-CmdWatchSubscriptionChan: + if sub.Subscribe { + m.handleSubscription(sub) + } else { + m.handleUnsubscription(sub) + } + case watchEvent := <-cmdWatchChan: + m.handleWatchEvent(watchEvent) + } + } +} + +// handleSubscription processes a new subscription request +func (m *Manager) handleSubscription(sub WatchSubscription) { + fingerprint := sub.WatchCmd.GetFingerprint() + key := sub.WatchCmd.GetKey() + + // Add fingerprint to querySubscriptionMap + if _, exists := m.querySubscriptionMap[key]; !exists { + m.querySubscriptionMap[key] = make(map[uint32]struct{}) + } + m.querySubscriptionMap[key][fingerprint] = struct{}{} + + // Add DiceDBCmd to fingerprintCmdMap + m.fingerprintCmdMap[fingerprint] = sub.WatchCmd + + // Add client channel to tcpSubscriptionMap + if _, exists := m.tcpSubscriptionMap[fingerprint]; !exists { + m.tcpSubscriptionMap[fingerprint] = make(map[chan *cmd.DiceDBCmd]struct{}) + } + m.tcpSubscriptionMap[fingerprint][sub.AdhocReqChan] = struct{}{} +} + +// handleUnsubscription processes an unsubscription request +func (m *Manager) handleUnsubscription(sub WatchSubscription) { + fingerprint := sub.Fingerprint + + // Remove clientID from tcpSubscriptionMap + if clients, ok := m.tcpSubscriptionMap[fingerprint]; ok { + delete(clients, sub.AdhocReqChan) + // If there are no more clients listening to this fingerprint, remove it from the map + if len(clients) == 0 { + // Remove the fingerprint from tcpSubscriptionMap + delete(m.tcpSubscriptionMap, fingerprint) + // Also remove the fingerprint from fingerprintCmdMap + delete(m.fingerprintCmdMap, fingerprint) + } else { + // Update the map with the new set of clients + // TODO: Is this actually required? + m.tcpSubscriptionMap[fingerprint] = clients + } + } + + // Remove fingerprint from querySubscriptionMap + if diceDBCmd, ok := m.fingerprintCmdMap[fingerprint]; ok { + key := diceDBCmd.GetKey() + if fingerprints, ok := m.querySubscriptionMap[key]; ok { + // Remove the fingerprint from the list of fingerprints listening to this key + delete(fingerprints, fingerprint) + // If there are no more fingerprints listening to this key, remove it from the map + if len(fingerprints) == 0 { + delete(m.querySubscriptionMap, key) + } else { + // Update the map with the new set of fingerprints. + // TODO: Is this actually required? + m.querySubscriptionMap[key] = fingerprints + } + } + } +} + +func (m *Manager) handleWatchEvent(event dstore.CmdWatchEvent) { + // Check if any watch commands are listening to updates on this key. + fingerprints, exists := m.querySubscriptionMap[event.AffectedKey] + if !exists { + return + } + + affectedCommands, cmdExists := affectedCmdMap[event.Cmd] + if !cmdExists { + m.logger.Error("Received a watch event for an unknown command type", + slog.String("cmd", event.Cmd)) + return + } + + // iterate through all command fingerprints that are listening to this key + for fingerprint := range fingerprints { + cmdToExecute := m.fingerprintCmdMap[fingerprint] + // Check if the command associated with this fingerprint actually needs to be executed for this event. + // For instance, if the event is a SET, only GET commands need to be executed. This also + // helps us handle cases where a key might get updated by an unrelated command which makes it + // incompatible with the watched command. + if _, affected := affectedCommands[cmdToExecute.Cmd]; affected { + m.notifyClients(fingerprint, cmdToExecute) + } + } +} + +// notifyClients sends cmd to all clients listening to this fingerprint, so that they can execute it. +func (m *Manager) notifyClients(fingerprint uint32, diceDBCmd *cmd.DiceDBCmd) { + clients, exists := m.tcpSubscriptionMap[fingerprint] + if !exists { + m.logger.Warn("No clients found for fingerprint", + slog.Uint64("fingerprint", uint64(fingerprint))) + return + } + + for clientChan := range clients { + clientChan <- diceDBCmd + } +} diff --git a/internal/worker/cmd_meta.go b/internal/worker/cmd_meta.go index 3b47c344e..84ab08018 100644 --- a/internal/worker/cmd_meta.go +++ b/internal/worker/cmd_meta.go @@ -15,6 +15,7 @@ const ( SingleShard MultiShard Custom + Watch ) // Global commands @@ -26,9 +27,10 @@ const ( // Single-shard commands. const ( - CmdSet = "SET" - CmdGet = "GET" - CmdGetSet = "GETSET" + CmdSet = "SET" + CmdGet = "GET" + CmdGetSet = "GETSET" + CmdGetWatch = "GET.WATCH" ) type CmdMeta struct { @@ -69,6 +71,9 @@ var CommandsMeta = map[string]CmdMeta{ CmdGetSet: { CmdType: SingleShard, }, + CmdGetWatch: { + CmdType: Watch, + }, } func init() { @@ -92,7 +97,7 @@ func validateCmdMeta(c string, meta CmdMeta) error { if meta.decomposeCommand == nil || meta.composeResponse == nil { return fmt.Errorf("multi-shard command %s must have both decomposeCommand and composeResponse implemented", c) } - case SingleShard, Custom: + case SingleShard, Watch, Custom: // No specific validations for these types currently default: return fmt.Errorf("unknown command type for %s", c) diff --git a/internal/worker/worker.go b/internal/worker/worker.go index dd379a187..3d3e04be1 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -9,6 +9,8 @@ import ( "syscall" "time" + "github.com/dicedb/dice/internal/watchmanager" + "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/auth" "github.com/dicedb/dice/internal/clientio" @@ -34,6 +36,7 @@ type BaseWorker struct { parser requestparser.Parser shardManager *shard.ShardManager respChan chan *ops.StoreResponse + adhocReqChan chan *cmd.DiceDBCmd Session *auth.Session globalErrorChan chan error logger *slog.Logger @@ -52,6 +55,7 @@ func NewWorker(wid string, respChan chan *ops.StoreResponse, respChan: respChan, logger: logger, Session: auth.NewSession(), + adhocReqChan: make(chan *cmd.DiceDBCmd, 20), // assuming we wouldn't have more than 20 adhoc requests being sent at a time. } } @@ -61,6 +65,21 @@ func (w *BaseWorker) ID() string { func (w *BaseWorker) Start(ctx context.Context) error { errChan := make(chan error, 1) + + dataChan := make(chan []byte) + readErrChan := make(chan error) + + go func() { + for { + data, err := w.ioHandler.Read(ctx) + if err != nil { + readErrChan <- err + return + } + dataChan <- data + } + }() + for { select { case <-ctx.Done(): @@ -77,12 +96,14 @@ func (w *BaseWorker) Start(ctx context.Context) error { } } return fmt.Errorf("error writing response: %w", err) - default: - data, err := w.ioHandler.Read(ctx) - if err != nil { - w.logger.Debug("Read error, connection closed possibly", slog.String("workerID", w.id), slog.Any("error", err)) - return err - } + case cmdReq := <-w.adhocReqChan: + // Handle adhoc requests of DiceDBCmd + func() { + execCtx, cancel := context.WithTimeout(ctx, 6*time.Second) // Timeout set to 6 seconds for integration tests + defer cancel() + w.executeCommandHandler(execCtx, errChan, []*cmd.DiceDBCmd{cmdReq}, true) + }() + case data := <-dataChan: cmds, err := w.parser.Parse(data) if err != nil { err = w.ioHandler.Write(ctx, err) @@ -120,22 +141,30 @@ func (w *BaseWorker) Start(ctx context.Context) error { } // executeCommand executes the command and return the response back to the client func(errChan chan error) { - execctx, cancel := context.WithTimeout(ctx, 6*time.Second) // Timeout set to 6 seconds for integration tests + execCtx, cancel := context.WithTimeout(ctx, 6*time.Second) // Timeout set to 6 seconds for integration tests defer cancel() - err = w.executeCommand(execctx, cmds[0]) - if err != nil { - w.logger.Error("Error executing command", slog.String("workerID", w.id), slog.Any("error", err)) - if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ETIMEDOUT) { - w.logger.Debug("Connection closed for worker", slog.String("workerID", w.id), slog.Any("error", err)) - errChan <- err - } - } + w.executeCommandHandler(execCtx, errChan, cmds, false) }(errChan) + case err := <-readErrChan: + w.logger.Debug("Read error, connection closed possibly", slog.String("workerID", w.id), slog.Any("error", err)) + return err + } + } +} + +func (w *BaseWorker) executeCommandHandler(execCtx context.Context, errChan chan error, cmds []*cmd.DiceDBCmd, isWatchNotification bool) { + err := w.executeCommand(execCtx, cmds[0], isWatchNotification) + if err != nil { + w.logger.Error("Error executing command", slog.String("workerID", w.id), slog.Any("error", err)) + if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ETIMEDOUT) { + w.logger.Debug("Connection closed for worker", slog.String("workerID", w.id), slog.Any("error", err)) + errChan <- err } } } -func (w *BaseWorker) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCmd) error { +func (w *BaseWorker) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, isWatchNotification bool) error { + responseType := clientio.ResponseTypeRegular // Break down the single command into multiple commands if multisharding is supported. // The length of cmdList helps determine how many shards to wait for responses. cmdList := make([]*cmd.DiceDBCmd, 0) @@ -157,10 +186,39 @@ func (w *BaseWorker) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCm case SingleShard: // For single-shard or custom commands, process them without breaking up. cmdList = append(cmdList, diceDBCmd) + responseType = clientio.ResponseTypeRegular case MultiShard: // If the command supports multisharding, break it down into multiple commands. cmdList = meta.decomposeCommand(diceDBCmd) + responseType = clientio.ResponseTypeRegular + case Watch: + // Generate the Cmd being watched. All we need to do is remove the .WATCH suffix from the command and pass + // it along as is. + watchCmd := &cmd.DiceDBCmd{ + Cmd: diceDBCmd.Cmd[:len(diceDBCmd.Cmd)-6], // Remove the .WATCH suffix + Args: diceDBCmd.Args, + } + + cmdList = append(cmdList, watchCmd) + + // Execute the command (scatter and gather) + if err := w.scatter(ctx, cmdList); err != nil { + return err + } + + if err := w.gather(ctx, diceDBCmd, len(cmdList), clientio.ResponseTypePush); err != nil { + return err + } + + // Proceed to subscribe after successful execution + watchmanager.CmdWatchSubscriptionChan <- watchmanager.WatchSubscription{ + Subscribe: true, + WatchCmd: watchCmd, + AdhocReqChan: w.adhocReqChan, + } + + return nil case Custom: switch diceDBCmd.Cmd { case CmdAuth: @@ -184,18 +242,18 @@ func (w *BaseWorker) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCm } // Scatter the broken-down commands to the appropriate shards. - err := w.scatter(ctx, cmdList) - if err != nil { + if err := w.scatter(ctx, cmdList); err != nil { return err } - // Gather the responses from the shards and write them to the buffer. - err = w.gather(ctx, diceDBCmd.Cmd, len(cmdList), meta.CmdType) - if err != nil { - return err + // For watch notifications, we need to set the responseType to push + if isWatchNotification { + responseType = clientio.ResponseTypePush } - return nil + // Gather the responses from the shards and write them to the buffer. + err := w.gather(ctx, diceDBCmd, len(cmdList), responseType) + return err } // scatter distributes the DiceDB commands to the respective shards based on the key. @@ -235,8 +293,8 @@ func (w *BaseWorker) scatter(ctx context.Context, cmds []*cmd.DiceDBCmd) error { // gather collects the responses from multiple shards and writes the results into the provided buffer. // It first waits for responses from all the shards and then processes the result based on the command type (SingleShard, Custom, or Multishard). -func (w *BaseWorker) gather(ctx context.Context, c string, numCmds int, ct CmdType) error { - // Loop to wait for messages from numberof shards +func (w *BaseWorker) gather(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, numCmds, responseType int) error { + // Loop to wait for messages from number of shards var evalResp []eval.EvalResponse for numCmds != 0 { select { @@ -255,38 +313,38 @@ func (w *BaseWorker) gather(ctx context.Context, c string, numCmds int, ct CmdTy } } - // TODO: This is a temporary solution. In the future, all commands should be refactored to be multi-shard compatible. - // TODO: There are a few commands such as QWATCH, RENAME, MGET, MSET that wouldn't work in multi-shard mode without refactoring. - // TODO: These commands should be refactored to be multi-shard compatible before DICE-DB is completely multi-shard. - // Check if command is part of the new WorkerCommandsMeta map i.e. if the command has been refactored to be multi-shard compatible. - // If not found, treat it as a command that's not yet refactored, and write the response back to the client. - val, ok := CommandsMeta[c] - if !ok { - if evalResp[0].Error != nil { - err := w.ioHandler.Write(ctx, []byte(evalResp[0].Error.Error())) - if err != nil { - w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) - return err - } - } - - err := w.ioHandler.Write(ctx, evalResp[0].Result.([]byte)) + switch responseType { + case clientio.ResponseTypeRegular: + return w.handleRegularResponse(ctx, diceDBCmd, evalResp) + case clientio.ResponseTypePush: + return w.handlePushResponse(ctx, diceDBCmd.Cmd, fmt.Sprintf("%d", diceDBCmd.GetFingerprint()), evalResp) + default: + w.logger.Error("Unknown response type", slog.String("workerID", w.id), slog.Int("responseType", responseType)) + err := w.ioHandler.Write(ctx, diceerrors.ErrInternalServer) if err != nil { w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) return err } - return nil } +} - switch ct { - case SingleShard, Custom: +// handleRegularResponse handles the response for regular commands, i.e., responses for which are pushed from the server to the client. +func (w *BaseWorker) handleRegularResponse(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, evalResp []eval.EvalResponse) error { + // Check if the command is multi-shard capable + // TODO: This is a temporary solution. In the future, all commands should be refactored to be multi-shard compatible. + // TODO: There are a few commands such as QWATCH, RENAME, MGET, MSET that wouldn't work in multi-shard mode without refactoring. + // TODO: These commands should be refactored to be multi-shard compatible before DICE-DB is completely multi-shard. + // Check if command is part of the new WorkerCommandsMeta map i.e. if the command has been refactored to be multi-shard compatible. + // If not found, treat it as a command that's not yet refactored, and write the response back to the client. + val, ok := CommandsMeta[diceDBCmd.Cmd] + if !ok || val.CmdType == SingleShard || val.CmdType == Custom { + // Handle single-shard or custom commands if evalResp[0].Error != nil { err := w.ioHandler.Write(ctx, evalResp[0].Error) if err != nil { w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) } - return err } @@ -295,23 +353,39 @@ func (w *BaseWorker) gather(ctx context.Context, c string, numCmds int, ct CmdTy w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) return err } - - case MultiShard: + } else if val.CmdType == MultiShard { + // Handle multi-shard commands err := w.ioHandler.Write(ctx, val.composeResponse(evalResp...)) if err != nil { w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) return err } - - default: - w.logger.Error("Unknown command type", slog.String("workerID", w.id)) + } else { + w.logger.Error("Unknown command type", slog.String("workerID", w.id), slog.String("command", diceDBCmd.Cmd), slog.Any("evalResp", evalResp)) err := w.ioHandler.Write(ctx, diceerrors.ErrInternalServer) if err != nil { w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) return err } } + return nil +} + +// handlePushResponse handles the response for push commands, i.e., responses for which are pushed from the server to the client. +func (w *BaseWorker) handlePushResponse(ctx context.Context, cmdName, pushResponseKey string, evalResp []eval.EvalResponse) error { + if evalResp[0].Error != nil { + err := w.ioHandler.Write(ctx, clientio.CreatePushResponse(cmdName, pushResponseKey, evalResp[0].Error)) + if err != nil { + w.logger.Debug("Error sending push response to client", slog.String("workerID", w.id), slog.Any("error", err)) + } + return err + } + err := w.ioHandler.Write(ctx, clientio.CreatePushResponse(cmdName, pushResponseKey, evalResp[0].Result)) + if err != nil { + w.logger.Debug("Error sending push response to client", slog.String("workerID", w.id), slog.Any("error", err)) + return err + } return nil } diff --git a/main.go b/main.go index ed08af854..a66e824ad 100644 --- a/main.go +++ b/main.go @@ -50,7 +50,8 @@ func main() { sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT) - watchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Server.WatchChanBufSize) + queryWatchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Server.WatchChanBufSize) + cmdWatchChan := make(chan dstore.CmdWatchEvent, config.DiceConfig.Server.KeysLimit) var serverErrCh chan error // Get the number of available CPU cores on the machine using runtime.NumCPU(). @@ -76,7 +77,7 @@ func main() { runtime.GOMAXPROCS(numCores) // Initialize the ShardManager - shardManager := shard.NewShardManager(uint8(numCores), watchChan, serverErrCh, logr) + shardManager := shard.NewShardManager(uint8(numCores), queryWatchChan, cmdWatchChan, serverErrCh, logr) wg := sync.WaitGroup{} @@ -91,7 +92,7 @@ func main() { // Initialize the AsyncServer server // Find a port and bind it if !config.EnableMultiThreading { - asyncServer := server.NewAsyncServer(shardManager, watchChan, logr) + asyncServer := server.NewAsyncServer(shardManager, queryWatchChan, logr) if err := asyncServer.FindPortAndBind(); err != nil { cancel() logr.Error("Error finding and binding port", slog.Any("error", err)) @@ -154,7 +155,7 @@ func main() { } else { workerManager := worker.NewWorkerManager(config.DiceConfig.Server.MaxClients, shardManager) // Initialize the RESP Server - respServer := resp.NewServer(shardManager, workerManager, serverErrCh, logr) + respServer := resp.NewServer(shardManager, workerManager, cmdWatchChan, serverErrCh, logr) serverWg.Add(1) go func() { defer serverWg.Done() @@ -186,7 +187,7 @@ func main() { }() } - websocketServer := server.NewWebSocketServer(shardManager, watchChan, logr) + websocketServer := server.NewWebSocketServer(shardManager, queryWatchChan, logr) serverWg.Add(1) go func() { defer serverWg.Done()