From a565ef7eebf6417c22882aa115b93f0eb8f08b22 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh Date: Wed, 2 Oct 2024 21:51:21 +0530 Subject: [PATCH 01/11] initial commit --- .../commands/async/getex_test.go | 4 +- integration_tests/commands/async/json_test.go | 10 +- .../commands/async/set_data_cmd_test.go | 4 +- integration_tests/commands/async/set_test.go | 30 +-- .../commands/async/touch_test.go | 6 +- integration_tests/commands/http/getex_test.go | 4 +- integration_tests/commands/http/set_test.go | 4 +- internal/clientio/client_identifier.go | 13 ++ internal/cmd/cmds.go | 20 ++ internal/comm/client.go | 6 + internal/eval/store_eval.go | 6 +- internal/querymanager/query_manager.go | 26 +-- internal/store/store.go | 5 + internal/watchmanager/watch_manager.go | 182 ++++++++++++++++++ 14 files changed, 267 insertions(+), 53 deletions(-) create mode 100644 internal/clientio/client_identifier.go create mode 100644 internal/watchmanager/watch_manager.go diff --git a/integration_tests/commands/async/getex_test.go b/integration_tests/commands/async/getex_test.go index a02597dad..8266cc671 100644 --- a/integration_tests/commands/async/getex_test.go +++ b/integration_tests/commands/async/getex_test.go @@ -15,8 +15,8 @@ func TestGetEx(t *testing.T) { Etime10 := strconv.FormatInt(time.Now().Unix()+10, 10) testCases := []struct { - name string - commands []string + name string + commands []string expected []interface{} assertType []string delay []time.Duration diff --git a/integration_tests/commands/async/json_test.go b/integration_tests/commands/async/json_test.go index a92e60b65..59f5cbc85 100644 --- a/integration_tests/commands/async/json_test.go +++ b/integration_tests/commands/async/json_test.go @@ -862,8 +862,8 @@ func TestJsonNummultby(t *testing.T) { invalidArgMessage := "ERR wrong number of arguments for 'json.nummultby' command" testCases := []struct { - name string - commands []string + name string + commands []string expected []interface{} assertType []string }{ @@ -1021,9 +1021,9 @@ func TestJSONNumIncrBy(t *testing.T) { defer conn.Close() invalidArgMessage := "ERR wrong number of arguments for 'json.numincrby' command" testCases := []struct { - name string - setupData string - commands []string + name string + setupData string + commands []string expected []interface{} assertType []string cleanUp []string diff --git a/integration_tests/commands/async/set_data_cmd_test.go b/integration_tests/commands/async/set_data_cmd_test.go index d6ffaa4f6..cb03863d8 100644 --- a/integration_tests/commands/async/set_data_cmd_test.go +++ b/integration_tests/commands/async/set_data_cmd_test.go @@ -30,8 +30,8 @@ func TestSetDataCommand(t *testing.T) { defer conn.Close() testCases := []struct { - name string - cmd []string + name string + cmd []string expected []interface{} assertType []string delay []time.Duration diff --git a/integration_tests/commands/async/set_test.go b/integration_tests/commands/async/set_test.go index 611b937c1..98f1b5082 100644 --- a/integration_tests/commands/async/set_test.go +++ b/integration_tests/commands/async/set_test.go @@ -20,12 +20,12 @@ func TestSet(t *testing.T) { testCases := []TestCase{ { - name: "Set and Get Simple Value", + name: "Set and Get Simple Cmd", commands: []string{"SET k v", "GET k"}, expected: []interface{}{"OK", "v"}, }, { - name: "Set and Get Integer Value", + name: "Set and Get Integer Cmd", commands: []string{"SET k 123456789", "GET k"}, expected: []interface{}{"OK", int64(123456789)}, }, @@ -146,31 +146,31 @@ func TestSetWithExat(t *testing.T) { func(t *testing.T) { // deleteTestKeys([]string{"k"}, store) FireCommand(conn, "DEL k") - assert.Equal(t, "OK", FireCommand(conn, "SET k v EXAT "+Etime), "Value mismatch for cmd SET k v EXAT "+Etime) - assert.Equal(t, "v", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") - assert.Assert(t, FireCommand(conn, "TTL k").(int64) <= 5, "Value mismatch for cmd TTL k") + assert.Equal(t, "OK", FireCommand(conn, "SET k v EXAT "+Etime), "Cmd mismatch for cmd SET k v EXAT "+Etime) + assert.Equal(t, "v", FireCommand(conn, "GET k"), "Cmd mismatch for cmd GET k") + assert.Assert(t, FireCommand(conn, "TTL k").(int64) <= 5, "Cmd mismatch for cmd TTL k") time.Sleep(3 * time.Second) - assert.Assert(t, FireCommand(conn, "TTL k").(int64) <= 3, "Value mismatch for cmd TTL k") + assert.Assert(t, FireCommand(conn, "TTL k").(int64) <= 3, "Cmd mismatch for cmd TTL k") time.Sleep(3 * time.Second) - assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") - assert.Equal(t, int64(-2), FireCommand(conn, "TTL k"), "Value mismatch for cmd TTL k") + assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Cmd mismatch for cmd GET k") + assert.Equal(t, int64(-2), FireCommand(conn, "TTL k"), "Cmd mismatch for cmd TTL k") }) t.Run("SET with invalid EXAT expires key immediately", func(t *testing.T) { // deleteTestKeys([]string{"k"}, store) FireCommand(conn, "DEL k") - assert.Equal(t, "OK", FireCommand(conn, "SET k v EXAT "+BadTime), "Value mismatch for cmd SET k v EXAT "+BadTime) - assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") - assert.Equal(t, int64(-2), FireCommand(conn, "TTL k"), "Value mismatch for cmd TTL k") + assert.Equal(t, "OK", FireCommand(conn, "SET k v EXAT "+BadTime), "Cmd mismatch for cmd SET k v EXAT "+BadTime) + assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Cmd mismatch for cmd GET k") + assert.Equal(t, int64(-2), FireCommand(conn, "TTL k"), "Cmd mismatch for cmd TTL k") }) t.Run("SET with EXAT and PXAT returns syntax error", func(t *testing.T) { // deleteTestKeys([]string{"k"}, store) FireCommand(conn, "DEL k") - assert.Equal(t, "ERR syntax error", FireCommand(conn, "SET k v PXAT "+Etime+" EXAT "+Etime), "Value mismatch for cmd SET k v PXAT "+Etime+" EXAT "+Etime) - assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") + assert.Equal(t, "ERR syntax error", FireCommand(conn, "SET k v PXAT "+Etime+" EXAT "+Etime), "Cmd mismatch for cmd SET k v PXAT "+Etime+" EXAT "+Etime) + assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Cmd mismatch for cmd GET k") }) } @@ -187,7 +187,7 @@ func TestWithKeepTTLFlag(t *testing.T) { for i := 0; i < len(tcase.commands); i++ { cmd := tcase.commands[i] out := tcase.expected[i] - assert.Equal(t, out, FireCommand(conn, cmd), "Value mismatch for cmd %s\n.", cmd) + assert.Equal(t, out, FireCommand(conn, cmd), "Cmd mismatch for cmd %s\n.", cmd) } } @@ -196,5 +196,5 @@ func TestWithKeepTTLFlag(t *testing.T) { cmd := "GET k" out := "(nil)" - assert.Equal(t, out, FireCommand(conn, cmd), "Value mismatch for cmd %s\n.", cmd) + assert.Equal(t, out, FireCommand(conn, cmd), "Cmd mismatch for cmd %s\n.", cmd) } diff --git a/integration_tests/commands/async/touch_test.go b/integration_tests/commands/async/touch_test.go index b1c7c98d9..963cb20cd 100644 --- a/integration_tests/commands/async/touch_test.go +++ b/integration_tests/commands/async/touch_test.go @@ -12,14 +12,14 @@ func TestTouch(t *testing.T) { defer conn.Close() testCases := []struct { - name string - commands []string + name string + commands []string expected []interface{} assertType []string delay []time.Duration }{ { - name: "Touch Simple Value", + name: "Touch Simple Cmd", commands: []string{"SET foo bar", "OBJECT IDLETIME foo", "TOUCH foo", "OBJECT IDLETIME foo"}, expected: []interface{}{"OK", int64(2), int64(1), int64(0)}, assertType: []string{"equal", "assert", "equal", "assert"}, diff --git a/integration_tests/commands/http/getex_test.go b/integration_tests/commands/http/getex_test.go index 9bbd42fee..e93552755 100644 --- a/integration_tests/commands/http/getex_test.go +++ b/integration_tests/commands/http/getex_test.go @@ -14,8 +14,8 @@ func TestGetEx(t *testing.T) { Etime10 := strconv.FormatInt(time.Now().Unix()+10, 10) testCases := []struct { - name string - commands []HTTPCommand + name string + commands []HTTPCommand expected []interface{} assertType []string delay []time.Duration diff --git a/integration_tests/commands/http/set_test.go b/integration_tests/commands/http/set_test.go index bf5d8d6d0..ed367a1fd 100644 --- a/integration_tests/commands/http/set_test.go +++ b/integration_tests/commands/http/set_test.go @@ -20,7 +20,7 @@ func TestSet(t *testing.T) { testCases := []TestCase{ { - name: "Set and Get Simple Value", + name: "Set and Get Simple Cmd", commands: []HTTPCommand{ {Command: "SET", Body: map[string]interface{}{"key": "k", "value": "v"}}, {Command: "GET", Body: map[string]interface{}{"key": "k"}}, @@ -28,7 +28,7 @@ func TestSet(t *testing.T) { expected: []interface{}{"OK", "v"}, }, { - name: "Set and Get Integer Value", + name: "Set and Get Integer Cmd", commands: []HTTPCommand{ {Command: "SET", Body: map[string]interface{}{"key": "k", "value": 123456789}}, {Command: "GET", Body: map[string]interface{}{"key": "k"}}, diff --git a/internal/clientio/client_identifier.go b/internal/clientio/client_identifier.go new file mode 100644 index 000000000..91ea3bce0 --- /dev/null +++ b/internal/clientio/client_identifier.go @@ -0,0 +1,13 @@ +package clientio + +type ClientIdentifier struct { + ClientIdentifierID int + IsHTTPClient bool +} + +func NewClientIdentifier(clientIdentifierID int, isHTTPClient bool) ClientIdentifier { + return ClientIdentifier{ + ClientIdentifierID: clientIdentifierID, + IsHTTPClient: isHTTPClient, + } +} diff --git a/internal/cmd/cmds.go b/internal/cmd/cmds.go index d0dfae6fa..d677e685f 100644 --- a/internal/cmd/cmds.go +++ b/internal/cmd/cmds.go @@ -1,5 +1,11 @@ package cmd +import ( + "fmt" + "github.com/dgryski/go-farm" + "strings" +) + type RedisCmd struct { RequestID uint32 Cmd string @@ -10,3 +16,17 @@ type RedisCmds struct { Cmds []*RedisCmd RequestID uint32 } + +// GetFingerprint returns a 32-bit fingerprint of the command and its arguments. +func (cmd *RedisCmd) GetFingerprint() uint32 { + return farm.Fingerprint32([]byte(fmt.Sprintf("%s-%s", cmd.Cmd, strings.Join(cmd.Args, " ")))) +} + +// GetKey Returns the key which the command operates on. +// +// TODO: This is a naive implementation which assumes that the first argument is the key. +// This is not true for all commands, however, for now this is only used by the watch manager, +// which as of now only supports a small subset of commands (all of which fit this implementation). +func (cmd *RedisCmd) GetKey() string { + return cmd.Args[0] +} diff --git a/internal/comm/client.go b/internal/comm/client.go index 33601ca67..1cadc533e 100644 --- a/internal/comm/client.go +++ b/internal/comm/client.go @@ -8,6 +8,12 @@ import ( "github.com/dicedb/dice/internal/cmd" ) +type CmdWatchResponse struct { + ClientIdentifierID uint32 + Result interface{} + Error error +} + type QwatchResponse struct { ClientIdentifierID uint32 Result interface{} diff --git a/internal/eval/store_eval.go b/internal/eval/store_eval.go index c9c02d4d8..6240c64eb 100644 --- a/internal/eval/store_eval.go +++ b/internal/eval/store_eval.go @@ -154,7 +154,7 @@ func evalGET(args []string, store *dstore.Store) EvalResponse { // Decode and return the value based on its encoding switch _, oEnc := object.ExtractTypeEncoding(obj); oEnc { case object.ObjEncodingInt: - // Value is stored as an int64, so use type assertion + // Cmd is stored as an int64, so use type assertion if val, ok := obj.Value.(int64); ok { return EvalResponse{Result: clientio.Encode(val, false), Error: nil} } @@ -162,7 +162,7 @@ func evalGET(args []string, store *dstore.Store) EvalResponse { Error: errors.New(string(diceerrors.NewErrWithFormattedMessage("expected int64 but got another type: %s", obj.Value)))} case object.ObjEncodingEmbStr, object.ObjEncodingRaw: - // Value is stored as a string, use type assertion + // Cmd is stored as a string, use type assertion if val, ok := obj.Value.(string); ok { return EvalResponse{Result: clientio.Encode(val, false), Error: nil} } @@ -170,7 +170,7 @@ func evalGET(args []string, store *dstore.Store) EvalResponse { Error: errors.New(string(diceerrors.NewErrWithMessage("expected string but got another type")))} case object.ObjEncodingByteArray: - // Value is stored as a bytearray, use type assertion + // Cmd is stored as a bytearray, use type assertion if val, ok := obj.Value.(*ByteArray); ok { return EvalResponse{Result: clientio.Encode(string(val.data), false), Error: nil} } diff --git a/internal/querymanager/query_manager.go b/internal/querymanager/query_manager.go index a1c61b246..f91fa3f81 100644 --- a/internal/querymanager/query_manager.go +++ b/internal/querymanager/query_manager.go @@ -64,11 +64,6 @@ type ( Query string `json:"query"` Data []any `json:"data"` } - - ClientIdentifier struct { - ClientIdentifierID int - IsHTTPClient bool - } ) var ( @@ -79,13 +74,6 @@ var ( AdhocQueryChan chan AdhocQuery ) -func NewClientIdentifier(clientIdentifierID int, isHTTPClient bool) ClientIdentifier { - return ClientIdentifier{ - ClientIdentifierID: clientIdentifierID, - IsHTTPClient: isHTTPClient, - } -} - // NewQueryManager initializes a new Manager. func NewQueryManager(logger *slog.Logger) *Manager { QuerySubscriptionChan = make(chan QuerySubscription) @@ -130,11 +118,11 @@ func (m *Manager) listenForSubscriptions(ctx context.Context) { for { select { case event := <-QuerySubscriptionChan: - var client ClientIdentifier + var client clientio.ClientIdentifier if event.QwatchClientChan != nil { - client = NewClientIdentifier(int(event.ClientIdentifierID), true) + client = clientio.NewClientIdentifier(int(event.ClientIdentifierID), true) } else { - client = NewClientIdentifier(event.ClientFD, false) + client = clientio.NewClientIdentifier(event.ClientFD, false) } if event.Subscribe { @@ -224,7 +212,7 @@ func (m *Manager) notifyClients(query *sql.DSQLQuery, clients *sync.Map, queryRe clients.Range(func(clientKey, clientVal interface{}) bool { // Identify the type of client and respond accordingly - switch clientIdentifier := clientKey.(ClientIdentifier); { + switch clientIdentifier := clientKey.(clientio.ClientIdentifier); { case clientIdentifier.IsHTTPClient: qwatchClientResponseChannel := clientVal.(chan comm.QwatchResponse) qwatchClientResponseChannel <- comm.QwatchResponse{ @@ -274,7 +262,7 @@ func (m *Manager) sendWithRetry(query *sql.DSQLQuery, clientFD int, data []byte) slog.Int("client", clientFD), slog.Any("error", err), ) - m.removeWatcher(query, NewClientIdentifier(clientFD, false), nil) + m.removeWatcher(query, clientio.NewClientIdentifier(clientFD, false), nil) return } } @@ -297,7 +285,7 @@ func (m *Manager) serveAdhocQueries(ctx context.Context) { } // addWatcher adds a client as a watcher to a query. -func (m *Manager) addWatcher(query *sql.DSQLQuery, clientIdentifier ClientIdentifier, +func (m *Manager) addWatcher(query *sql.DSQLQuery, clientIdentifier clientio.ClientIdentifier, qwatchClientChan chan comm.QwatchResponse, cacheChan chan *[]struct { Key string Value *object.Obj @@ -327,7 +315,7 @@ func (m *Manager) addWatcher(query *sql.DSQLQuery, clientIdentifier ClientIdenti } // removeWatcher removes a client from the watchlist for a query. -func (m *Manager) removeWatcher(query *sql.DSQLQuery, clientIdentifier ClientIdentifier, +func (m *Manager) removeWatcher(query *sql.DSQLQuery, clientIdentifier clientio.ClientIdentifier, qwatchClientChan chan comm.QwatchResponse) { queryString := query.String() if clients, ok := m.WatchList.Load(queryString); ok { diff --git a/internal/store/store.go b/internal/store/store.go index 46fbe2bbb..bb0119dd3 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -22,6 +22,11 @@ type QueryWatchEvent struct { Value object.Obj } +type CmdWatchEvent struct { + Cmd string + AffectedKey string +} + type Store struct { store *swiss.Map[string, *object.Obj] expires *swiss.Map[*object.Obj, uint64] // Does not need to be thread-safe as it is only accessed by a single thread. diff --git a/internal/watchmanager/watch_manager.go b/internal/watchmanager/watch_manager.go new file mode 100644 index 000000000..322cbf6d8 --- /dev/null +++ b/internal/watchmanager/watch_manager.go @@ -0,0 +1,182 @@ +package watchmanager + +import ( + "context" + "github.com/dicedb/dice/internal/clientio" + "github.com/dicedb/dice/internal/cmd" + "github.com/dicedb/dice/internal/comm" + dstore "github.com/dicedb/dice/internal/store" + "log/slog" + "sync" +) + +type ( + WatchSubscription struct { + Subscribe bool // Subscribe is true for subscribe, false for unsubscribe + WatchCmd cmd.RedisCmd // WatchCmd Represents a unique key for each watch artifact, only populated for subscriptions. + Fingerprint uint32 // Fingerprint is a unique identifier for each watch artifact, only populated for unsubscriptions. + ClientFD int // ClientFD is the file descriptor of the client connection + CmdWatchClientChan chan comm.CmdWatchResponse // CmdWatchClientChan is the generic channel for HTTP/Websockets etc. + ClientIdentifierID uint32 // ClientIdentifierID Helps identify CmdWatch client on httpserver side + } + + Manager struct { + querySubscriptionMap map[string]map[uint32]bool // querySubscriptionMap is a map of Key -> [fingerprint1, fingerprint2, ...] + tcpSubscriptionMap map[uint32]map[clientio.ClientIdentifier]bool // tcpSubscriptionMap is a map of fingerprint -> [client1, client2, ...] + fingerprintCmdMap map[uint32]cmd.RedisCmd // fingerprintCmdMap is a map of fingerprint -> RedisCmd + mu sync.RWMutex + logger *slog.Logger + } +) + +var ( + CmdWatchSubscriptionChan chan WatchSubscription +) + +func NewManager(logger *slog.Logger) *Manager { + CmdWatchSubscriptionChan = make(chan WatchSubscription) + return &Manager{ + querySubscriptionMap: make(map[string]map[uint32]bool), + tcpSubscriptionMap: make(map[uint32]map[clientio.ClientIdentifier]bool), + fingerprintCmdMap: make(map[uint32]cmd.RedisCmd), + logger: logger, + } +} + +// Run starts the watch manager, listening for subscription requests and events +func (m *Manager) Run(ctx context.Context, eventChan chan dstore.CmdWatchEvent) { + var wg sync.WaitGroup + + wg.Add(2) + go func() { + defer wg.Done() + m.listenForSubscriptions(ctx) + }() + + go func() { + defer wg.Done() + m.listenForEvents(ctx, eventChan) + }() + + wg.Wait() +} + +// listenForSubscriptions handles incoming subscription requests +func (m *Manager) listenForSubscriptions(ctx context.Context) { + for { + select { + case sub := <-CmdWatchSubscriptionChan: + if sub.Subscribe { + m.handleSubscription(sub) + } else { + m.handleUnsubscription(sub) + } + case <-ctx.Done(): + return + } + } +} + +// handleSubscription processes a new subscription request +func (m *Manager) handleSubscription(sub WatchSubscription) { + fingerprint := sub.WatchCmd.GetFingerprint() + key := sub.WatchCmd.GetKey() + + client := clientio.NewClientIdentifier(sub.ClientFD, false) + + m.mu.Lock() + defer m.mu.Unlock() + + // Add fingerprint to querySubscriptionMap + if m.querySubscriptionMap[key] == nil { + m.querySubscriptionMap[key] = make(map[uint32]bool) + } + m.querySubscriptionMap[key][fingerprint] = true + + // Add RedisCmd to fingerprintCmdMap + m.fingerprintCmdMap[fingerprint] = sub.WatchCmd + + // Add clientID to tcpSubscriptionMap + if m.tcpSubscriptionMap[fingerprint] == nil { + m.tcpSubscriptionMap[fingerprint] = make(map[clientio.ClientIdentifier]bool) + } + m.tcpSubscriptionMap[fingerprint][client] = true +} + +// handleUnsubscription processes an unsubscription request +func (m *Manager) handleUnsubscription(sub WatchSubscription) { + fingerprint := sub.Fingerprint + client := clientio.NewClientIdentifier(sub.ClientFD, false) + + m.mu.Lock() + defer m.mu.Unlock() + + // Remove clientID from tcpSubscriptionMap + if clients, ok := m.tcpSubscriptionMap[fingerprint]; ok { + delete(clients, client) + // If there are no more clients listening to this fingerprint, remove it from the map + if len(clients) == 0 { + // Remove the fingerprint from tcpSubscriptionMap + delete(m.tcpSubscriptionMap, fingerprint) + // Also remove the fingerprint from fingerprintCmdMap + delete(m.fingerprintCmdMap, fingerprint) + } else { + // Update the map with the new set of clients + m.tcpSubscriptionMap[fingerprint] = clients + } + } + + // Remove fingerprint from querySubscriptionMap + if redisCmd, ok := m.fingerprintCmdMap[fingerprint]; ok { + key := redisCmd.GetKey() + if fingerprints, ok := m.querySubscriptionMap[key]; ok { + // Remove the fingerprint from the list of fingerprints listening to this key + delete(fingerprints, fingerprint) + // If there are no more fingerprints listening to this key, remove it from the map + if len(fingerprints) == 0 { + delete(m.querySubscriptionMap, key) + } else { + // Update the map with the new set of fingerprints + m.querySubscriptionMap[key] = fingerprints + } + } + } +} + +func (m *Manager) listenForEvents(ctx context.Context, eventChan chan dstore.CmdWatchEvent) { + affectedCmdMap := map[string]map[string]bool{"SET": {"GET": true}} + for { + select { + case <-ctx.Done(): + return + case event := <-eventChan: + m.mu.RLock() + + // Check if any watch commands are listening to updates on this key. + if _, ok := m.querySubscriptionMap[event.AffectedKey]; ok { + // iterate through all command fingerprints that are listening to this key + for fingerprint := range m.querySubscriptionMap[event.AffectedKey] { + // Check if the command associated with this fingerprint actually needs to be executed for this event. + // For instance, if the event is a SET, only execute GET commands need to be executed. This also + // helps us handle cases where a key might get updated by an unrelated command which makes it + // incompatible with the watched command. + if affectedCommands, ok := affectedCmdMap[event.Cmd]; ok { + if _, ok := affectedCommands[m.fingerprintCmdMap[fingerprint].Cmd]; ok { + // TODO: execute the command, store the result, send to clients + if clients, ok := m.tcpSubscriptionMap[fingerprint]; ok { + for client := range clients { + notifyClient(client, result) + } + } + } + } else { + m.logger.Error("Received a watch event for an unknown command type", + slog.String("cmd", event.Cmd)) + } + } + } + + m.mu.RUnlock() + } + } +} From f20957beab7050883d499f96be776129423d1998 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh Date: Sun, 6 Oct 2024 16:58:03 +0530 Subject: [PATCH 02/11] initial GET.WATCH implementation --- .../commands/async/qwatch_test.go | 4 +- .../commands/resp/getwatch_test.go | 79 ++++++++++ integration_tests/commands/resp/setup.go | 9 +- internal/clientio/push_response.go | 9 +- internal/eval/commands.go | 56 ++++--- internal/eval/eval.go | 2 +- internal/querymanager/query_manager.go | 6 +- internal/server/resp/server.go | 32 ++-- internal/server/server.go | 8 +- internal/shard/shard_manager.go | 4 +- internal/shard/shard_thread.go | 4 +- internal/store/store.go | 39 +++-- internal/watchmanager/watch_manager.go | 147 +++++++++--------- internal/worker/cmd_meta.go | 13 +- internal/worker/worker.go | 103 ++++++++++-- main.go | 9 +- 16 files changed, 357 insertions(+), 167 deletions(-) create mode 100644 integration_tests/commands/resp/getwatch_test.go diff --git a/integration_tests/commands/async/qwatch_test.go b/integration_tests/commands/async/qwatch_test.go index 3b2aefda5..82ea6a878 100644 --- a/integration_tests/commands/async/qwatch_test.go +++ b/integration_tests/commands/async/qwatch_test.go @@ -84,7 +84,7 @@ func setupQWATCHTest(t *testing.T) (net.Conn, []net.Conn, func()) { subscribers := []net.Conn{getLocalConnection(), getLocalConnection(), getLocalConnection()} cleanup := func() { - cleanupKeys(publisher) + cleanupQWATCHKeys(publisher) if err := publisher.Close(); err != nil { t.Errorf("Error closing publisher connection: %v", err) } @@ -342,7 +342,7 @@ func verifyJSONUpdates(t *testing.T, rp *clientio.RESPParser, tc JSONTestCase) { } } -func cleanupKeys(publisher net.Conn) { +func cleanupQWATCHKeys(publisher net.Conn) { for _, tc := range qWatchTestCases { FireCommand(publisher, fmt.Sprintf("DEL %s:%d", tc.key, tc.userID)) } diff --git a/integration_tests/commands/resp/getwatch_test.go b/integration_tests/commands/resp/getwatch_test.go new file mode 100644 index 000000000..30c32d796 --- /dev/null +++ b/integration_tests/commands/resp/getwatch_test.go @@ -0,0 +1,79 @@ +package resp + +import ( + "fmt" + "github.com/dicedb/dice/internal/clientio" + "gotest.tools/v3/assert" + "net" + "testing" + "time" +) + +const getWatchKey = "getwatchkey" + +type getWatchTestCase struct { + key string + val string +} + +var getWatchTestCases = []getWatchTestCase{ + {getWatchKey, "value1"}, + {getWatchKey, "value2"}, + {getWatchKey, "value3"}, + {getWatchKey, "value4"}, +} + +func TestGETWATCH(t *testing.T) { + publisher := getLocalConnection() + subscribers := []net.Conn{getLocalConnection(), getLocalConnection(), getLocalConnection()} + + defer func() { + if err := publisher.Close(); err != nil { + t.Errorf("Error closing publisher connection: %v", err) + } + for _, sub := range subscribers { + //FireCommand(sub, fmt.Sprintf("GET.UNWATCH %s", fingerprint)) + time.Sleep(100 * time.Millisecond) + if err := sub.Close(); err != nil { + t.Errorf("Error closing subscriber connection: %v", err) + } + } + }() + + // Fire a SET command to set a key + res := FireCommand(publisher, fmt.Sprintf("SET %s %s", getWatchKey, "value")) + assert.Equal(t, "OK", res) + + respParsers := make([]*clientio.RESPParser, len(subscribers)) + for i, subscriber := range subscribers { + rp := fireCommandAndGetRESPParser(subscriber, fmt.Sprintf("GET.WATCH %s", getWatchKey)) + assert.Assert(t, rp != nil) + respParsers[i] = rp + + v, err := rp.DecodeOne() + assert.NilError(t, err) + castedValue, ok := v.([]interface{}) + if !ok { + t.Errorf("Type assertion to []interface{} failed for value: %v", v) + } + assert.Equal(t, 3, len(castedValue)) + } + + // Fire updates to the key using the publisher, then check if the subscribers receive the updates in the push-response form (i.e. array of three elements, with third element being the value) + for _, tc := range getWatchTestCases { + res := FireCommand(publisher, fmt.Sprintf("SET %s %s", tc.key, tc.val)) + assert.Equal(t, "OK", res) + + for _, rp := range respParsers { + v, err := rp.DecodeOne() + assert.NilError(t, err) + castedValue, ok := v.([]interface{}) + if !ok { + t.Errorf("Type assertion to []interface{} failed for value: %v", v) + } + assert.Equal(t, 3, len(castedValue)) + assert.Equal(t, "GET.WATCH", castedValue[1]) + assert.Equal(t, tc.val, castedValue[2]) + } + } +} diff --git a/integration_tests/commands/resp/setup.go b/integration_tests/commands/resp/setup.go index f087ec939..db579d94a 100644 --- a/integration_tests/commands/resp/setup.go +++ b/integration_tests/commands/resp/setup.go @@ -122,12 +122,13 @@ func RunTestServer(wg *sync.WaitGroup, opt TestServerOptions) { config.DiceConfig.Server.Port = 9739 } - watchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Server.KeysLimit) + queryWatchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Server.KeysLimit) + cmdWatchChan := make(chan dstore.CmdWatchEvent, config.DiceConfig.Server.KeysLimit) gec := make(chan error) - shardManager := shard.NewShardManager(1, watchChan, gec, logr) + shardManager := shard.NewShardManager(1, queryWatchChan, cmdWatchChan, gec, logr) workerManager := worker.NewWorkerManager(20000, shardManager) - // Initialize the REST Server - testServer := resp.NewServer(shardManager, workerManager, gec, logr) + // Initialize the RESP Server + testServer := resp.NewServer(shardManager, workerManager, cmdWatchChan, gec, logr) ctx, cancel := context.WithCancel(context.Background()) fmt.Println("Starting the test server on port", config.DiceConfig.Server.Port) diff --git a/internal/clientio/push_response.go b/internal/clientio/push_response.go index cb0c1ab20..fc544a86a 100644 --- a/internal/clientio/push_response.go +++ b/internal/clientio/push_response.go @@ -6,11 +6,12 @@ import ( // CreatePushResponse creates a push response. Push responses refer to messages that the server sends to clients without // the client explicitly requesting them. These are typically seen in scenarios where the client has subscribed to some -// kind of event or data feed and is notified in real-time when changes occur -func CreatePushResponse(query *sql.DSQLQuery, result *[]sql.QueryResultRow) (response []interface{}) { +// kind of event or data feed and is notified in real-time when changes occur. +// `key` is the unique key that identifies the push response. +func CreatePushResponse[T any](key string, result T) (response []interface{}) { response = make([]interface{}, 3) response[0] = sql.Qwatch - response[1] = query.String() - response[2] = *result + response[1] = key + response[2] = result return } diff --git a/internal/eval/commands.go b/internal/eval/commands.go index bb2d4a29d..0c31dd6d2 100644 --- a/internal/eval/commands.go +++ b/internal/eval/commands.go @@ -25,6 +25,10 @@ type DiceCmdMeta struct { // will utilize this function for evaluation, allowing for better handling of // complex command execution scenarios and improved response consistency. NewEval func([]string, *dstore.Store) *EvalResponse + + // CmdEquivalent refers to the regular version of a watch command. For instance, "GET" is the regular version of "GET.WATCH". + // This field is only populated for watch commands. + CmdEquivalent string } type KeySpecs struct { @@ -621,10 +625,10 @@ var ( KeySpecs: KeySpecs{BeginIndex: 1}, } hkeysCmdMeta = DiceCmdMeta{ - Name: "HKEYS", - Info: `HKEYS command is used to retrieve all the keys(or field names) within a hash. Complexity is O(n) where n is the size of the hash.`, - Eval: evalHKEYS, - Arity: 1, + Name: "HKEYS", + Info: `HKEYS command is used to retrieve all the keys(or field names) within a hash. Complexity is O(n) where n is the size of the hash.`, + Eval: evalHKEYS, + Arity: 1, KeySpecs: KeySpecs{BeginIndex: 1}, } hsetnxCmdMeta = DiceCmdMeta{ @@ -911,27 +915,27 @@ var ( Arity: 3, KeySpecs: KeySpecs{BeginIndex: 1}, } - dumpkeyCMmdMeta=DiceCmdMeta{ - Name: "DUMP", - Info: `Serialize the value stored at key in a Redis-specific format and return it to the user. + dumpkeyCMmdMeta = DiceCmdMeta{ + Name: "DUMP", + Info: `Serialize the value stored at key in a Redis-specific format and return it to the user. The returned value can be synthesized back into a Redis key using the RESTORE command.`, - Eval: evalDUMP, - Arity: 1, - KeySpecs: KeySpecs{BeginIndex: 1}, + Eval: evalDUMP, + Arity: 1, + KeySpecs: KeySpecs{BeginIndex: 1}, } - restorekeyCmdMeta=DiceCmdMeta{ - Name: "RESTORE", - Info: `Serialize the value stored at key in a Redis-specific format and return it to the user. + restorekeyCmdMeta = DiceCmdMeta{ + Name: "RESTORE", + Info: `Serialize the value stored at key in a Redis-specific format and return it to the user. The returned value can be synthesized back into a Redis key using the RESTORE command.`, - Eval: evalRestore, - Arity: 2, + Eval: evalRestore, + Arity: 2, KeySpecs: KeySpecs{BeginIndex: 1}, } typeCmdMeta = DiceCmdMeta{ - Name: "TYPE", - Info: `Returns the string representation of the type of the value stored at key. The different types that can be returned are: string, list, set, zset, hash and stream.`, - Eval: evalTYPE, - Arity: 1, + Name: "TYPE", + Info: `Returns the string representation of the type of the value stored at key. The different types that can be returned are: string, list, set, zset, hash and stream.`, + Eval: evalTYPE, + Arity: 1, KeySpecs: KeySpecs{BeginIndex: 1}, } @@ -1016,14 +1020,24 @@ var ( Arity: -4, KeySpecs: KeySpecs{BeginIndex: 1}, } + getWatchCmdMeta = DiceCmdMeta{ + Name: "GET.WATCH", + Info: `GET.WATCH key + Returns the value of the key and starts watching for changes in the key's value. Note that some update + deliveries may be missed in case of high write rate on the given key. However, the values being delivered will + always be monotonically consistent.`, + Arity: 2, + KeySpecs: KeySpecs{BeginIndex: 1}, + CmdEquivalent: "GET", + } ) func init() { DiceCmds["PING"] = pingCmdMeta DiceCmds["ECHO"] = echoCmdMeta DiceCmds["AUTH"] = authCmdMeta - DiceCmds["DUMP"]=dumpkeyCMmdMeta - DiceCmds["RESTORE"]=restorekeyCmdMeta + DiceCmds["DUMP"] = dumpkeyCMmdMeta + DiceCmds["RESTORE"] = restorekeyCmdMeta DiceCmds["SET"] = setCmdMeta DiceCmds["GET"] = getCmdMeta DiceCmds["MSET"] = msetCmdMeta diff --git a/internal/eval/eval.go b/internal/eval/eval.go index 2b334f9e1..7f99fabe3 100644 --- a/internal/eval/eval.go +++ b/internal/eval/eval.go @@ -2119,7 +2119,7 @@ func EvalQWATCH(args []string, httpOp bool, client *comm.Client, store *dstore.S } // TODO: We should return the list of all queries being watched by the client. - return clientio.Encode(clientio.CreatePushResponse(&query, queryResult.Result), false) + return clientio.Encode(clientio.CreatePushResponse(query.String(), *queryResult.Result), false) } // EvalQUNWATCH removes the specified key from the watch list for the caller client. diff --git a/internal/querymanager/query_manager.go b/internal/querymanager/query_manager.go index a392078f6..b7fb9a5ce 100644 --- a/internal/querymanager/query_manager.go +++ b/internal/querymanager/query_manager.go @@ -74,8 +74,8 @@ var ( AdhocQueryChan chan AdhocQuery ) -func NewClientIdentifier(clientIdentifierID int, isHTTPClient bool) ClientIdentifier { - return ClientIdentifier{ +func NewClientIdentifier(clientIdentifierID int, isHTTPClient bool) clientio.ClientIdentifier { + return clientio.ClientIdentifier{ ClientIdentifierID: clientIdentifierID, IsHTTPClient: isHTTPClient, } @@ -231,7 +231,7 @@ func (m *Manager) updateQueryCache(queryFingerprint string, event dstore.QueryWa } func (m *Manager) notifyClients(query *sql.DSQLQuery, clients *sync.Map, queryResult *[]sql.QueryResultRow) { - encodedResult := clientio.Encode(clientio.CreatePushResponse(query, queryResult), false) + encodedResult := clientio.Encode(clientio.CreatePushResponse(query.String(), *queryResult), false) clients.Range(func(clientKey, clientVal interface{}) bool { // Identify the type of client and respond accordingly diff --git a/internal/server/resp/server.go b/internal/server/resp/server.go index e4e617d08..32254ff1e 100644 --- a/internal/server/resp/server.go +++ b/internal/server/resp/server.go @@ -4,6 +4,8 @@ import ( "context" "errors" "fmt" + dstore "github.com/dicedb/dice/internal/store" + "github.com/dicedb/dice/internal/watchmanager" "log/slog" "net" "sync" @@ -37,20 +39,24 @@ type Server struct { Port int serverFD int connBacklogSize int - wm *worker.WorkerManager - sm *shard.ShardManager + workerManager *worker.WorkerManager + shardManager *shard.ShardManager + watchManager *watchmanager.Manager + cmdWatchChan chan dstore.CmdWatchEvent globalErrorChan chan error logger *slog.Logger } -func NewServer(sm *shard.ShardManager, wm *worker.WorkerManager, gec chan error, l *slog.Logger) *Server { +func NewServer(shardManager *shard.ShardManager, workerManager *worker.WorkerManager, cmdWatchChan chan dstore.CmdWatchEvent, globalErrChan chan error, l *slog.Logger) *Server { return &Server{ Host: config.DiceConfig.Server.Addr, Port: config.DiceConfig.Server.Port, connBacklogSize: DefaultConnBacklogSize, - wm: wm, - sm: sm, - globalErrorChan: gec, + workerManager: workerManager, + shardManager: shardManager, + watchManager: watchmanager.NewManager(l), + cmdWatchChan: cmdWatchChan, + globalErrorChan: globalErrChan, logger: l, } } @@ -67,7 +73,13 @@ func (s *Server) Run(ctx context.Context) (err error) { // Start a go routine to accept connections errChan := make(chan error, 1) wg := &sync.WaitGroup{} - wg.Add(1) + + wg.Add(2) + go func() { + defer wg.Done() + s.watchManager.Run(ctx, s.cmdWatchChan) + }() + go func(wg *sync.WaitGroup) { defer wg.Done() if err := s.AcceptConnectionRequests(ctx, wg); err != nil { @@ -178,14 +190,14 @@ func (s *Server) AcceptConnectionRequests(ctx context.Context, wg *sync.WaitGrou parser := respparser.NewParser(s.logger) respChan := make(chan *ops.StoreResponse) wID := GenerateUniqueWorkerID() - w := worker.NewWorker(wID, respChan, ioHandler, parser, s.sm, s.globalErrorChan, s.logger) + w := worker.NewWorker(wID, respChan, ioHandler, parser, s.shardManager, s.globalErrorChan, s.logger) if err != nil { s.logger.Error("Failed to create new worker for clientFD", slog.Int("client-fd", clientFD), slog.Any("error", err)) return err } // Register the worker with the worker manager - err = s.wm.RegisterWorker(w) + err = s.workerManager.RegisterWorker(w) if err != nil { return err } @@ -198,7 +210,7 @@ func (s *Server) AcceptConnectionRequests(ctx context.Context, wg *sync.WaitGrou if err != nil { s.logger.Warn("Failed to unregister worker", slog.String("worker-id", wID), slog.Any("error", err)) } - }(s.wm, wID) + }(s.workerManager, wID) wctx, cwctx := context.WithCancel(ctx) defer cwctx() err := w.Start(wctx) diff --git a/internal/server/server.go b/internal/server/server.go index 7194cad54..273b0823a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -37,12 +37,12 @@ type AsyncServer struct { queryWatcher *querymanager.Manager shardManager *shard.ShardManager ioChan chan *ops.StoreResponse // The server acts like a worker today, this behavior will change once IOThreads are introduced and each client gets its own worker. - watchChan chan dstore.QueryWatchEvent // This is needed to co-ordinate between the store and the query watcher. + queryWatchChan chan dstore.QueryWatchEvent // This is needed to co-ordinate between the store and the query watcher. logger *slog.Logger // logger is the logger for the server } // NewAsyncServer initializes a new AsyncServer -func NewAsyncServer(shardManager *shard.ShardManager, watchChan chan dstore.QueryWatchEvent, logger *slog.Logger) *AsyncServer { +func NewAsyncServer(shardManager *shard.ShardManager, queryWatchChan chan dstore.QueryWatchEvent, logger *slog.Logger) *AsyncServer { return &AsyncServer{ maxClients: config.DiceConfig.Server.MaxClients, connectedClients: make(map[int]*comm.Client), @@ -50,7 +50,7 @@ func NewAsyncServer(shardManager *shard.ShardManager, watchChan chan dstore.Quer queryWatcher: querymanager.NewQueryManager(logger), multiplexerPollTimeout: config.DiceConfig.Server.MultiplexerPollTimeout, ioChan: make(chan *ops.StoreResponse, 1000), - watchChan: watchChan, + queryWatchChan: queryWatchChan, logger: logger, } } @@ -151,7 +151,7 @@ func (s *AsyncServer) Run(ctx context.Context) error { wg.Add(1) go func() { defer wg.Done() - s.queryWatcher.Run(watchCtx, s.watchChan) + s.queryWatcher.Run(watchCtx, s.queryWatchChan) }() s.shardManager.RegisterWorker("server", s.ioChan) diff --git a/internal/shard/shard_manager.go b/internal/shard/shard_manager.go index d5010bcbe..6feede9d4 100644 --- a/internal/shard/shard_manager.go +++ b/internal/shard/shard_manager.go @@ -26,14 +26,14 @@ type ShardManager struct { } // NewShardManager creates a new ShardManager instance with the given number of Shards and a parent context. -func NewShardManager(shardCount uint8, watchChan chan dstore.QueryWatchEvent, globalErrorChan chan error, logger *slog.Logger) *ShardManager { +func NewShardManager(shardCount uint8, queryWatchChan chan dstore.QueryWatchEvent, cmdWatchChan chan dstore.CmdWatchEvent, globalErrorChan chan error, logger *slog.Logger) *ShardManager { shards := make([]*ShardThread, shardCount) shardReqMap := make(map[ShardID]chan *ops.StoreOp) shardErrorChan := make(chan *ShardError) for i := uint8(0); i < shardCount; i++ { // Shards are numbered from 0 to shardCount-1 - shard := NewShardThread(i, globalErrorChan, shardErrorChan, watchChan, logger) + shard := NewShardThread(i, globalErrorChan, shardErrorChan, queryWatchChan, cmdWatchChan, logger) shards[i] = shard shardReqMap[i] = shard.ReqChan } diff --git a/internal/shard/shard_thread.go b/internal/shard/shard_thread.go index 144111ba1..29d078e62 100644 --- a/internal/shard/shard_thread.go +++ b/internal/shard/shard_thread.go @@ -37,10 +37,10 @@ type ShardThread struct { } // NewShardThread creates a new ShardThread instance with the given shard id and error channel. -func NewShardThread(id ShardID, gec chan error, sec chan *ShardError, watchChan chan dstore.QueryWatchEvent, logger *slog.Logger) *ShardThread { +func NewShardThread(id ShardID, gec chan error, sec chan *ShardError, queryWatchChan chan dstore.QueryWatchEvent, cmdWatchChan chan dstore.CmdWatchEvent, logger *slog.Logger) *ShardThread { return &ShardThread{ id: id, - store: dstore.NewStore(watchChan), + store: dstore.NewStore(queryWatchChan, cmdWatchChan), ReqChan: make(chan *ops.StoreOp, 1000), workerMap: make(map[string]chan *ops.StoreResponse), globalErrorChan: gec, diff --git a/internal/store/store.go b/internal/store/store.go index 75f018fcf..e564329ee 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -48,17 +48,19 @@ type CmdWatchEvent struct { } type Store struct { - store common.ITable[string, *object.Obj] - expires common.ITable[*object.Obj, uint64] // Does not need to be thread-safe as it is only accessed by a single thread. - numKeys int - watchChan chan QueryWatchEvent + store common.ITable[string, *object.Obj] + expires common.ITable[*object.Obj, uint64] // Does not need to be thread-safe as it is only accessed by a single thread. + numKeys int + queryWatchChan chan QueryWatchEvent + cmdWatchChan chan CmdWatchEvent } -func NewStore(watchChan chan QueryWatchEvent) *Store { +func NewStore(queryWatchChan chan QueryWatchEvent, cmdWatchChan chan CmdWatchEvent) *Store { return &Store{ - store: NewStoreRegMap(), - expires: NewExpireRegMap(), - watchChan: watchChan, + store: NewStoreRegMap(), + expires: NewExpireRegMap(), + queryWatchChan: queryWatchChan, + cmdWatchChan: cmdWatchChan, } } @@ -154,9 +156,12 @@ func (store *Store) putHelper(k string, obj *object.Obj, opts ...PutOption) { } store.store.Put(k, obj) - if store.watchChan != nil { + if store.queryWatchChan != nil { store.notifyQueryManager(k, Set, *obj) } + if store.cmdWatchChan != nil { + store.notifyWatchManager("SET", k) + } } // getHelper is a helper function to get the object from the store. It also updates the last accessed time if touch is true. @@ -254,9 +259,12 @@ func (store *Store) Rename(sourceKey, destKey string) bool { store.numKeys-- // Notify watchers about the deletion of the source key - if store.watchChan != nil { + if store.queryWatchChan != nil { store.notifyQueryManager(sourceKey, Del, *sourceObj) } + if store.cmdWatchChan != nil { + store.notifyWatchManager("DEL", sourceKey) + } return true } @@ -297,9 +305,12 @@ func (store *Store) deleteKey(k string, obj *object.Obj) bool { store.expires.Delete(obj) store.numKeys-- - if store.watchChan != nil { + if store.queryWatchChan != nil { store.notifyQueryManager(k, Del, *obj) } + if store.cmdWatchChan != nil { + store.notifyWatchManager("DEL", k) + } return true } @@ -317,7 +328,11 @@ func (store *Store) delByPtr(ptr string) bool { // notifyQueryManager notifies the query manager about a key change, so that it can update the query cache if needed. func (store *Store) notifyQueryManager(k, operation string, obj object.Obj) { - store.watchChan <- QueryWatchEvent{k, operation, obj} + store.queryWatchChan <- QueryWatchEvent{k, operation, obj} +} + +func (store *Store) notifyWatchManager(cmd string, affectedKey string) { + store.cmdWatchChan <- CmdWatchEvent{cmd, affectedKey} } func (store *Store) GetStore() common.ITable[string, *object.Obj] { diff --git a/internal/watchmanager/watch_manager.go b/internal/watchmanager/watch_manager.go index 322cbf6d8..9bf6e90aa 100644 --- a/internal/watchmanager/watch_manager.go +++ b/internal/watchmanager/watch_manager.go @@ -2,9 +2,7 @@ package watchmanager import ( "context" - "github.com/dicedb/dice/internal/clientio" "github.com/dicedb/dice/internal/cmd" - "github.com/dicedb/dice/internal/comm" dstore "github.com/dicedb/dice/internal/store" "log/slog" "sync" @@ -12,67 +10,66 @@ import ( type ( WatchSubscription struct { - Subscribe bool // Subscribe is true for subscribe, false for unsubscribe - WatchCmd cmd.RedisCmd // WatchCmd Represents a unique key for each watch artifact, only populated for subscriptions. - Fingerprint uint32 // Fingerprint is a unique identifier for each watch artifact, only populated for unsubscriptions. - ClientFD int // ClientFD is the file descriptor of the client connection - CmdWatchClientChan chan comm.CmdWatchResponse // CmdWatchClientChan is the generic channel for HTTP/Websockets etc. - ClientIdentifierID uint32 // ClientIdentifierID Helps identify CmdWatch client on httpserver side + Subscribe bool // Subscribe is true for subscribe, false for unsubscribe. Required. + AdhocReqChan chan *cmd.RedisCmd // AdhocReqChan is the channel to send adhoc requests to the worker. Required. + WatchCmd *cmd.RedisCmd // WatchCmd Represents a unique key for each watch artifact, only populated for subscriptions. + Fingerprint uint32 // Fingerprint is a unique identifier for each watch artifact, only populated for unsubscriptions. } Manager struct { - querySubscriptionMap map[string]map[uint32]bool // querySubscriptionMap is a map of Key -> [fingerprint1, fingerprint2, ...] - tcpSubscriptionMap map[uint32]map[clientio.ClientIdentifier]bool // tcpSubscriptionMap is a map of fingerprint -> [client1, client2, ...] - fingerprintCmdMap map[uint32]cmd.RedisCmd // fingerprintCmdMap is a map of fingerprint -> RedisCmd - mu sync.RWMutex + querySubscriptionMap map[string]map[uint32]struct{} // querySubscriptionMap is a map of Key -> [fingerprint1, fingerprint2, ...] + tcpSubscriptionMap map[uint32]map[chan *cmd.RedisCmd]struct{} // tcpSubscriptionMap is a map of fingerprint -> [client1Chan, client2Chan, ...] + fingerprintCmdMap map[uint32]*cmd.RedisCmd // fingerprintCmdMap is a map of fingerprint -> RedisCmd logger *slog.Logger } ) var ( CmdWatchSubscriptionChan chan WatchSubscription + affectedCmdMap = map[string]map[string]struct{}{ + "SET": {"GET": struct{}{}}, + "DEL": {"GET": struct{}{}}, + } ) func NewManager(logger *slog.Logger) *Manager { CmdWatchSubscriptionChan = make(chan WatchSubscription) return &Manager{ - querySubscriptionMap: make(map[string]map[uint32]bool), - tcpSubscriptionMap: make(map[uint32]map[clientio.ClientIdentifier]bool), - fingerprintCmdMap: make(map[uint32]cmd.RedisCmd), + querySubscriptionMap: make(map[string]map[uint32]struct{}), + tcpSubscriptionMap: make(map[uint32]map[chan *cmd.RedisCmd]struct{}), + fingerprintCmdMap: make(map[uint32]*cmd.RedisCmd), logger: logger, } } // Run starts the watch manager, listening for subscription requests and events -func (m *Manager) Run(ctx context.Context, eventChan chan dstore.CmdWatchEvent) { +func (m *Manager) Run(ctx context.Context, cmdWatchChan chan dstore.CmdWatchEvent) { var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - m.listenForSubscriptions(ctx) - }() + wg.Add(1) go func() { defer wg.Done() - m.listenForEvents(ctx, eventChan) + m.listenForEvents(ctx, cmdWatchChan) }() + <-ctx.Done() wg.Wait() } -// listenForSubscriptions handles incoming subscription requests -func (m *Manager) listenForSubscriptions(ctx context.Context) { +func (m *Manager) listenForEvents(ctx context.Context, cmdWatchChan chan dstore.CmdWatchEvent) { for { select { + case <-ctx.Done(): + return case sub := <-CmdWatchSubscriptionChan: if sub.Subscribe { m.handleSubscription(sub) } else { m.handleUnsubscription(sub) } - case <-ctx.Done(): - return + case watchEvent := <-cmdWatchChan: + m.handleWatchEvent(watchEvent) } } } @@ -82,38 +79,29 @@ func (m *Manager) handleSubscription(sub WatchSubscription) { fingerprint := sub.WatchCmd.GetFingerprint() key := sub.WatchCmd.GetKey() - client := clientio.NewClientIdentifier(sub.ClientFD, false) - - m.mu.Lock() - defer m.mu.Unlock() - // Add fingerprint to querySubscriptionMap - if m.querySubscriptionMap[key] == nil { - m.querySubscriptionMap[key] = make(map[uint32]bool) + if _, exists := m.querySubscriptionMap[key]; !exists { + m.querySubscriptionMap[key] = make(map[uint32]struct{}) } - m.querySubscriptionMap[key][fingerprint] = true + m.querySubscriptionMap[key][fingerprint] = struct{}{} // Add RedisCmd to fingerprintCmdMap m.fingerprintCmdMap[fingerprint] = sub.WatchCmd - // Add clientID to tcpSubscriptionMap - if m.tcpSubscriptionMap[fingerprint] == nil { - m.tcpSubscriptionMap[fingerprint] = make(map[clientio.ClientIdentifier]bool) + // Add client channel to tcpSubscriptionMap + if _, exists := m.tcpSubscriptionMap[fingerprint]; !exists { + m.tcpSubscriptionMap[fingerprint] = make(map[chan *cmd.RedisCmd]struct{}) } - m.tcpSubscriptionMap[fingerprint][client] = true + m.tcpSubscriptionMap[fingerprint][sub.AdhocReqChan] = struct{}{} } // handleUnsubscription processes an unsubscription request func (m *Manager) handleUnsubscription(sub WatchSubscription) { fingerprint := sub.Fingerprint - client := clientio.NewClientIdentifier(sub.ClientFD, false) - - m.mu.Lock() - defer m.mu.Unlock() // Remove clientID from tcpSubscriptionMap if clients, ok := m.tcpSubscriptionMap[fingerprint]; ok { - delete(clients, client) + delete(clients, sub.AdhocReqChan) // If there are no more clients listening to this fingerprint, remove it from the map if len(clients) == 0 { // Remove the fingerprint from tcpSubscriptionMap @@ -122,6 +110,7 @@ func (m *Manager) handleUnsubscription(sub WatchSubscription) { delete(m.fingerprintCmdMap, fingerprint) } else { // Update the map with the new set of clients + // TODO: Is this actually required? m.tcpSubscriptionMap[fingerprint] = clients } } @@ -136,47 +125,51 @@ func (m *Manager) handleUnsubscription(sub WatchSubscription) { if len(fingerprints) == 0 { delete(m.querySubscriptionMap, key) } else { - // Update the map with the new set of fingerprints + // Update the map with the new set of fingerprints. + // TODO: Is this actually required? m.querySubscriptionMap[key] = fingerprints } } } } -func (m *Manager) listenForEvents(ctx context.Context, eventChan chan dstore.CmdWatchEvent) { - affectedCmdMap := map[string]map[string]bool{"SET": {"GET": true}} - for { - select { - case <-ctx.Done(): - return - case event := <-eventChan: - m.mu.RLock() - - // Check if any watch commands are listening to updates on this key. - if _, ok := m.querySubscriptionMap[event.AffectedKey]; ok { - // iterate through all command fingerprints that are listening to this key - for fingerprint := range m.querySubscriptionMap[event.AffectedKey] { - // Check if the command associated with this fingerprint actually needs to be executed for this event. - // For instance, if the event is a SET, only execute GET commands need to be executed. This also - // helps us handle cases where a key might get updated by an unrelated command which makes it - // incompatible with the watched command. - if affectedCommands, ok := affectedCmdMap[event.Cmd]; ok { - if _, ok := affectedCommands[m.fingerprintCmdMap[fingerprint].Cmd]; ok { - // TODO: execute the command, store the result, send to clients - if clients, ok := m.tcpSubscriptionMap[fingerprint]; ok { - for client := range clients { - notifyClient(client, result) - } - } - } - } else { - m.logger.Error("Received a watch event for an unknown command type", - slog.String("cmd", event.Cmd)) - } - } - } +func (m *Manager) handleWatchEvent(event dstore.CmdWatchEvent) { + // Check if any watch commands are listening to updates on this key. + fingerprints, exists := m.querySubscriptionMap[event.AffectedKey] + if !exists { + return + } + + affectedCommands, cmdExists := affectedCmdMap[event.Cmd] + if !cmdExists { + m.logger.Error("Received a watch event for an unknown command type", + slog.String("cmd", event.Cmd)) + return + } - m.mu.RUnlock() + // iterate through all command fingerprints that are listening to this key + for fingerprint := range fingerprints { + cmdToExecute := m.fingerprintCmdMap[fingerprint] + // Check if the command associated with this fingerprint actually needs to be executed for this event. + // For instance, if the event is a SET, only GET commands need to be executed. This also + // helps us handle cases where a key might get updated by an unrelated command which makes it + // incompatible with the watched command. + if _, affected := affectedCommands[cmdToExecute.Cmd]; affected { + m.notifyClients(fingerprint, cmdToExecute) } } } + +// notifyClients sends cmd to all clients listening to this fingerprint, so that they can execute it. +func (m *Manager) notifyClients(fingerprint uint32, cmd *cmd.RedisCmd) { + clients, exists := m.tcpSubscriptionMap[fingerprint] + if !exists { + m.logger.Warn("No clients found for fingerprint", + slog.Uint64("fingerprint", uint64(fingerprint))) + return + } + + for clientChan := range clients { + clientChan <- cmd + } +} diff --git a/internal/worker/cmd_meta.go b/internal/worker/cmd_meta.go index 76ab8fe85..8fe6c8ae7 100644 --- a/internal/worker/cmd_meta.go +++ b/internal/worker/cmd_meta.go @@ -15,6 +15,7 @@ const ( SingleShard MultiShard Custom + Watch ) // Global commands @@ -26,9 +27,10 @@ const ( // Single-shard commands. const ( - CmdSet = "SET" - CmdGet = "GET" - CmdGetSet = "GETSET" + CmdSet = "SET" + CmdGet = "GET" + CmdGetSet = "GETSET" + CmdGetWatch = "GET.WATCH" ) type CmdMeta struct { @@ -69,6 +71,9 @@ var CommandsMeta = map[string]CmdMeta{ CmdGetSet: { CmdType: SingleShard, }, + CmdGetWatch: { + CmdType: Watch, + }, } func init() { @@ -92,7 +97,7 @@ func validateCmdMeta(c string, meta CmdMeta) error { if meta.decomposeCommand == nil || meta.composeResponse == nil { return fmt.Errorf("multi-shard command %s must have both decomposeCommand and composeResponse implemented", c) } - case SingleShard, Custom: + case SingleShard, Watch, Custom: // No specific validations for these types currently default: return fmt.Errorf("unknown command type for %s", c) diff --git a/internal/worker/worker.go b/internal/worker/worker.go index 392e35f29..2c6471160 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/dicedb/dice/internal/watchmanager" "log/slog" "net" "syscall" @@ -34,6 +35,7 @@ type BaseWorker struct { parser requestparser.Parser shardManager *shard.ShardManager respChan chan *ops.StoreResponse + adhocReqChan chan *cmd.RedisCmd Session *auth.Session globalErrorChan chan error logger *slog.Logger @@ -52,6 +54,7 @@ func NewWorker(wid string, respChan chan *ops.StoreResponse, respChan: respChan, logger: logger, Session: auth.NewSession(), + adhocReqChan: make(chan *cmd.RedisCmd, 20), // assuming we wouldn't have more than 20 adhoc requests being sent at a time. } } @@ -61,6 +64,21 @@ func (w *BaseWorker) ID() string { func (w *BaseWorker) Start(ctx context.Context) error { errChan := make(chan error, 1) + + dataChan := make(chan []byte) + readErrChan := make(chan error) + + go func() { + for { + data, err := w.ioHandler.Read(ctx) + if err != nil { + readErrChan <- err + return + } + dataChan <- data + } + }() + for { select { case <-ctx.Done(): @@ -77,12 +95,10 @@ func (w *BaseWorker) Start(ctx context.Context) error { } } return fmt.Errorf("error writing response: %w", err) - default: - data, err := w.ioHandler.Read(ctx) - if err != nil { - w.logger.Debug("Read error, connection closed possibly", slog.String("workerID", w.id), slog.Any("error", err)) - return err - } + case cmdReq := <-w.adhocReqChan: + // Handle adhoc requests of RedisCmd + w.executeCommandHandler(errChan, nil, ctx, []*cmd.RedisCmd{cmdReq}, true) + case data := <-dataChan: cmds, err := w.parser.Parse(data) if err != nil { err = w.ioHandler.Write(ctx, err) @@ -120,26 +136,35 @@ func (w *BaseWorker) Start(ctx context.Context) error { } // executeCommand executes the command and return the response back to the client func(errChan chan error) { - execctx, cancel := context.WithTimeout(ctx, 6*time.Second) // Timeout set to 6 seconds for integration tests + execCtx, cancel := context.WithTimeout(ctx, 6*time.Second) // Timeout set to 6 seconds for integration tests defer cancel() - err = w.executeCommand(execctx, cmds[0]) - if err != nil { - w.logger.Error("Error executing command", slog.String("workerID", w.id), slog.Any("error", err)) - if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ETIMEDOUT) { - w.logger.Debug("Connection closed for worker", slog.String("workerID", w.id), slog.Any("error", err)) - errChan <- err - } - } + w.executeCommandHandler(errChan, err, execCtx, cmds, false) }(errChan) + case err := <-errChan: + w.logger.Debug("Read error, connection closed possibly", slog.String("workerID", w.id), slog.Any("error", err)) + return err + } + } +} + +func (w *BaseWorker) executeCommandHandler(errChan chan error, err error, execCtx context.Context, cmds []*cmd.RedisCmd, isWatchNotification bool) { + err = w.executeCommand(execCtx, cmds[0], isWatchNotification) + if err != nil { + w.logger.Error("Error executing command", slog.String("workerID", w.id), slog.Any("error", err)) + if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ETIMEDOUT) { + w.logger.Debug("Connection closed for worker", slog.String("workerID", w.id), slog.Any("error", err)) + errChan <- err } } } -func (w *BaseWorker) executeCommand(ctx context.Context, redisCmd *cmd.RedisCmd) error { +func (w *BaseWorker) executeCommand(ctx context.Context, redisCmd *cmd.RedisCmd, isWatchNotification bool) error { // Break down the single command into multiple commands if multisharding is supported. // The length of cmdList helps determine how many shards to wait for responses. cmdList := make([]*cmd.RedisCmd, 0) + localErrChan := make(chan error, 1) + // Retrieve metadata for the command to determine if multisharding is supported. meta, ok := CommandsMeta[redisCmd.Cmd] if !ok { @@ -180,21 +205,50 @@ func (w *BaseWorker) executeCommand(ctx context.Context, redisCmd *cmd.RedisCmd) default: cmdList = append(cmdList, redisCmd) } + case Watch: + // Generate the Cmd being watched. All we need to do is remove the .WATCH suffix from the command and pass + // it along as is. + watchCmd := &cmd.RedisCmd{ + Cmd: redisCmd.Cmd[:len(redisCmd.Cmd)-6], // Remove the .WATCH suffix + Args: redisCmd.Args, + } + + cmdList = append(cmdList, watchCmd) + + go func() { + err := <-localErrChan + if err != nil { + return + } + watchmanager.CmdWatchSubscriptionChan <- watchmanager.WatchSubscription{ + Subscribe: true, + WatchCmd: watchCmd, + AdhocReqChan: w.adhocReqChan, + } + }() } } // Scatter the broken-down commands to the appropriate shards. err := w.scatter(ctx, cmdList) if err != nil { + localErrChan <- err return err } + cmdType := meta.CmdType + if isWatchNotification { + cmdType = Watch + } + // Gather the responses from the shards and write them to the buffer. - err = w.gather(ctx, redisCmd.Cmd, len(cmdList), meta.CmdType) + err = w.gather(ctx, redisCmd.Cmd, len(cmdList), cmdType) if err != nil { + localErrChan <- err return err } + localErrChan <- nil return nil } @@ -302,6 +356,21 @@ func (w *BaseWorker) gather(ctx context.Context, c string, numCmds int, ct CmdTy w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) return err } + case Watch: + if evalResp[0].Error != nil { + err := w.ioHandler.Write(ctx, clientio.CreatePushResponse("GET.WATCH", evalResp[0].Error)) + if err != nil { + w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) + } + + return err + } + + err := w.ioHandler.Write(ctx, clientio.CreatePushResponse("GET.WATCH", evalResp[0].Result)) + if err != nil { + w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) + return err + } default: w.logger.Error("Unknown command type", slog.String("workerID", w.id)) diff --git a/main.go b/main.go index d68485092..d4a5fe612 100644 --- a/main.go +++ b/main.go @@ -47,7 +47,8 @@ func main() { sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT) - watchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Server.KeysLimit) + queryWatchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Server.KeysLimit) + cmdWatchChan := make(chan dstore.CmdWatchEvent, config.DiceConfig.Server.KeysLimit) var serverErrCh chan error // Get the number of available CPU cores on the machine using runtime.NumCPU(). @@ -73,7 +74,7 @@ func main() { runtime.GOMAXPROCS(numCores) // Initialize the ShardManager - shardManager := shard.NewShardManager(uint8(numCores), watchChan, serverErrCh, logr) + shardManager := shard.NewShardManager(uint8(numCores), queryWatchChan, cmdWatchChan, serverErrCh, logr) wg := sync.WaitGroup{} @@ -88,7 +89,7 @@ func main() { // Initialize the AsyncServer server // Find a port and bind it if !config.EnableMultiThreading { - asyncServer := server.NewAsyncServer(shardManager, watchChan, logr) + asyncServer := server.NewAsyncServer(shardManager, queryWatchChan, cmdWatchChan, logr) if err := asyncServer.FindPortAndBind(); err != nil { cancel() logr.Error("Error finding and binding port", slog.Any("error", err)) @@ -183,7 +184,7 @@ func main() { }() } - websocketServer := server.NewWebSocketServer(shardManager, watchChan, logr) + websocketServer := server.NewWebSocketServer(shardManager, queryWatchChan, logr) serverWg.Add(1) go func() { defer serverWg.Done() From c2b48af88005ec8b9b9f81415ba1887a44c10faa Mon Sep 17 00:00:00 2001 From: Jyotinder Singh Date: Sun, 6 Oct 2024 17:07:12 +0530 Subject: [PATCH 03/11] linter --- integration_tests/commands/async/setup.go | 2 +- .../commands/async/touch_test.go | 2 +- integration_tests/commands/http/set_test.go | 6 ++-- integration_tests/commands/http/setup.go | 2 +- integration_tests/commands/websocket/setup.go | 2 +- internal/cmd/cmds.go | 3 +- internal/eval/bloom_test.go | 4 +-- internal/eval/commands.go | 14 -------- internal/eval/eval_test.go | 36 +++++++++---------- internal/eval/hmap_test.go | 2 +- internal/eval/main_test.go | 2 +- internal/server/resp/server.go | 5 +-- internal/sql/executerbechmark_test.go | 30 ++++++++-------- internal/sql/executor_test.go | 22 ++++++------ internal/store/expire_test.go | 2 +- internal/store/lfu_eviction_test.go | 2 +- internal/store/store.go | 2 +- internal/watchmanager/watch_manager.go | 35 +++++++++--------- internal/worker/worker.go | 23 +++++++----- main.go | 4 +-- 20 files changed, 97 insertions(+), 103 deletions(-) diff --git a/integration_tests/commands/async/setup.go b/integration_tests/commands/async/setup.go index c922c3eaa..ec359f8f1 100644 --- a/integration_tests/commands/async/setup.go +++ b/integration_tests/commands/async/setup.go @@ -122,7 +122,7 @@ func RunTestServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerOption var err error watchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Server.KeysLimit) gec := make(chan error) - shardManager := shard.NewShardManager(1, watchChan, gec, opt.Logger) + shardManager := shard.NewShardManager(1, watchChan, nil, gec, opt.Logger) // Initialize the AsyncServer testServer := server.NewAsyncServer(shardManager, watchChan, opt.Logger) diff --git a/integration_tests/commands/async/touch_test.go b/integration_tests/commands/async/touch_test.go index 963cb20cd..bf51f5088 100644 --- a/integration_tests/commands/async/touch_test.go +++ b/integration_tests/commands/async/touch_test.go @@ -19,7 +19,7 @@ func TestTouch(t *testing.T) { delay []time.Duration }{ { - name: "Touch Simple Cmd", + name: "Touch Simple Value", commands: []string{"SET foo bar", "OBJECT IDLETIME foo", "TOUCH foo", "OBJECT IDLETIME foo"}, expected: []interface{}{"OK", int64(2), int64(1), int64(0)}, assertType: []string{"equal", "assert", "equal", "assert"}, diff --git a/integration_tests/commands/http/set_test.go b/integration_tests/commands/http/set_test.go index 42e170903..bf5d8d6d0 100644 --- a/integration_tests/commands/http/set_test.go +++ b/integration_tests/commands/http/set_test.go @@ -20,7 +20,7 @@ func TestSet(t *testing.T) { testCases := []TestCase{ { - name: "Set and Get Simple Cmd", + name: "Set and Get Simple Value", commands: []HTTPCommand{ {Command: "SET", Body: map[string]interface{}{"key": "k", "value": "v"}}, {Command: "GET", Body: map[string]interface{}{"key": "k"}}, @@ -28,12 +28,12 @@ func TestSet(t *testing.T) { expected: []interface{}{"OK", "v"}, }, { - name: "Set and Get Integer Cmd", + name: "Set and Get Integer Value", commands: []HTTPCommand{ {Command: "SET", Body: map[string]interface{}{"key": "k", "value": 123456789}}, {Command: "GET", Body: map[string]interface{}{"key": "k"}}, }, - expected: []interface{}{"OK", 1.23456789e+08}, + expected: []interface{}{"OK", "1.23456789e+08"}, }, { name: "Overwrite Existing Key", diff --git a/integration_tests/commands/http/setup.go b/integration_tests/commands/http/setup.go index b0c20bf28..a10526cb3 100644 --- a/integration_tests/commands/http/setup.go +++ b/integration_tests/commands/http/setup.go @@ -103,7 +103,7 @@ func RunHTTPServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerOption globalErrChannel := make(chan error) watchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Server.KeysLimit) - shardManager := shard.NewShardManager(1, watchChan, globalErrChannel, opt.Logger) + shardManager := shard.NewShardManager(1, watchChan, nil, globalErrChannel, opt.Logger) queryWatcherLocal := querymanager.NewQueryManager(opt.Logger) config.HTTPPort = opt.Port // Initialize the HTTPServer diff --git a/integration_tests/commands/websocket/setup.go b/integration_tests/commands/websocket/setup.go index 406f5fd52..10f718c77 100644 --- a/integration_tests/commands/websocket/setup.go +++ b/integration_tests/commands/websocket/setup.go @@ -99,7 +99,7 @@ func RunWebsocketServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerO // Initialize the WebsocketServer globalErrChannel := make(chan error) watchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Server.KeysLimit) - shardManager := shard.NewShardManager(1, watchChan, globalErrChannel, opt.Logger) + shardManager := shard.NewShardManager(1, watchChan, nil, globalErrChannel, opt.Logger) config.WebsocketPort = opt.Port testServer := server.NewWebSocketServer(shardManager, watchChan, opt.Logger) diff --git a/internal/cmd/cmds.go b/internal/cmd/cmds.go index fd830c71e..423eb68e0 100644 --- a/internal/cmd/cmds.go +++ b/internal/cmd/cmds.go @@ -2,8 +2,9 @@ package cmd import ( "fmt" - "github.com/dgryski/go-farm" "strings" + + "github.com/dgryski/go-farm" ) type DiceDBCmd struct { diff --git a/internal/eval/bloom_test.go b/internal/eval/bloom_test.go index 78f97fabe..5a2577e70 100644 --- a/internal/eval/bloom_test.go +++ b/internal/eval/bloom_test.go @@ -15,7 +15,7 @@ import ( ) func TestBloomFilter(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) // This test only contains some basic checks for all the bloom filter // operations like BFINIT, BFADD, BFEXISTS. It assumes that the // functions called in the main function are working correctly and @@ -112,7 +112,7 @@ func TestBloomFilter(t *testing.T) { } func TestGetOrCreateBloomFilter(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) // Create a key and default opts key := "bf" opts, _ := newBloomOpts([]string{}, true) diff --git a/internal/eval/commands.go b/internal/eval/commands.go index e0f18544b..4c512b757 100644 --- a/internal/eval/commands.go +++ b/internal/eval/commands.go @@ -25,10 +25,6 @@ type DiceCmdMeta struct { // will utilize this function for evaluation, allowing for better handling of // complex command execution scenarios and improved response consistency. NewEval func([]string, *dstore.Store) *EvalResponse - - // CmdEquivalent refers to the regular version of a watch command. For instance, "GET" is the regular version of "GET.WATCH". - // This field is only populated for watch commands. - CmdEquivalent string } type KeySpecs struct { @@ -1051,16 +1047,6 @@ var ( Arity: -4, KeySpecs: KeySpecs{BeginIndex: 1}, } - getWatchCmdMeta = DiceCmdMeta{ - Name: "GET.WATCH", - Info: `GET.WATCH key - Returns the value of the key and starts watching for changes in the key's value. Note that some update - deliveries may be missed in case of high write rate on the given key. However, the values being delivered will - always be monotonically consistent.`, - Arity: 2, - KeySpecs: KeySpecs{BeginIndex: 1}, - CmdEquivalent: "GET", - } geoAddCmdMeta = DiceCmdMeta{ Name: "GEOADD", Info: `Adds one or more members to a geospatial index. The key is created if it doesn't exist.`, diff --git a/internal/eval/eval_test.go b/internal/eval/eval_test.go index c99455823..a068a9581 100644 --- a/internal/eval/eval_test.go +++ b/internal/eval/eval_test.go @@ -42,7 +42,7 @@ func setupTest(store *dstore.Store) *dstore.Store { } func TestEval(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) testEvalMSET(t, store) testEvalECHO(t, store) @@ -1108,7 +1108,7 @@ func testEvalJSONOBJLEN(t *testing.T, store *dstore.Store) { func BenchmarkEvalJSONOBJLEN(b *testing.B) { sizes := []int{0, 10, 100, 1000, 10000, 100000} // Various sizes of JSON objects - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, size := range sizes { b.Run(fmt.Sprintf("JSONObjectSize_%d", size), func(b *testing.B) { @@ -2747,13 +2747,13 @@ func runEvalTests(t *testing.T, tests map[string]evalTestCase, evalFunc func([]s func BenchmarkEvalMSET(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) evalMSET([]string{"KEY", "VAL", "KEY2", "VAL2"}, store) } } func BenchmarkEvalHSET(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for i := 0; i < b.N; i++ { evalHSET([]string{"KEY", fmt.Sprintf("FIELD_%d", i), fmt.Sprintf("VALUE_%d", i)}, store) } @@ -2888,7 +2888,7 @@ func testEvalHKEYS(t *testing.T, store *dstore.Store) { } func BenchmarkEvalHKEYS(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for i := 0; i < b.N; i++ { evalHSET([]string{"KEY", fmt.Sprintf("FIELD_%d", i), fmt.Sprintf("VALUE_%d", i)}, store) @@ -2899,7 +2899,7 @@ func BenchmarkEvalHKEYS(b *testing.B) { } } func BenchmarkEvalPFCOUNT(b *testing.B) { - store := *dstore.NewStore(nil) + store := *dstore.NewStore(nil, nil) // Helper function to create and insert HLL objects createAndInsertHLL := func(key string, items []string) { @@ -3247,7 +3247,7 @@ func testEvalHLEN(t *testing.T, store *dstore.Store) { func BenchmarkEvalHLEN(b *testing.B) { sizes := []int{0, 10, 100, 1000, 10000, 100000} - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, size := range sizes { b.Run(fmt.Sprintf("HashSize_%d", size), func(b *testing.B) { @@ -3489,7 +3489,7 @@ func testEvalTYPE(t *testing.T, store *dstore.Store) { } func BenchmarkEvalTYPE(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) // Define different types of objects to benchmark objectTypes := map[string]func(){ @@ -3680,7 +3680,7 @@ func testEvalJSONOBJKEYS(t *testing.T, store *dstore.Store) { func BenchmarkEvalJSONOBJKEYS(b *testing.B) { sizes := []int{0, 10, 100, 1000, 10000, 100000} // Various sizes of JSON objects - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, size := range sizes { b.Run(fmt.Sprintf("JSONObjectSize_%d", size), func(b *testing.B) { @@ -3848,7 +3848,7 @@ func testEvalGETRANGE(t *testing.T, store *dstore.Store) { } func BenchmarkEvalGETRANGE(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) store.Put("BENCHMARK_KEY", store.NewObj("Hello World", maxExDuration, object.ObjTypeString, object.ObjEncodingRaw)) inputs := []struct { @@ -3873,7 +3873,7 @@ func BenchmarkEvalGETRANGE(b *testing.B) { } func BenchmarkEvalHSETNX(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for i := 0; i < b.N; i++ { evalHSETNX([]string{"KEY", fmt.Sprintf("FIELD_%d", i/2), fmt.Sprintf("VALUE_%d", i)}, store) } @@ -3935,7 +3935,7 @@ func testEvalHSETNX(t *testing.T, store *dstore.Store) { } func TestMSETConsistency(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) evalMSET([]string{"KEY", "VAL", "KEY2", "VAL2"}, store) assert.Equal(t, "VAL", store.Get("KEY").Value) @@ -3943,7 +3943,7 @@ func TestMSETConsistency(t *testing.T) { } func BenchmarkEvalHINCRBY(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) // creating new fields for i := 0; i < b.N; i++ { @@ -4193,7 +4193,7 @@ func testEvalSETEX(t *testing.T, store *dstore.Store) { } func BenchmarkEvalSETEX(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -4368,7 +4368,7 @@ func testEvalINCRBYFLOAT(t *testing.T, store *dstore.Store) { } func BenchmarkEvalINCRBYFLOAT(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) store.Put("key1", store.NewObj("1", maxExDuration, object.ObjTypeString, object.ObjEncodingEmbStr)) store.Put("key2", store.NewObj("1.2", maxExDuration, object.ObjTypeString, object.ObjEncodingEmbStr)) @@ -4483,7 +4483,7 @@ func testEvalBITOP(t *testing.T, store *dstore.Store) { } func BenchmarkEvalBITOP(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) // Setup initial data for benchmarking store.Put("key1", store.NewObj(&ByteArray{data: []byte{0x01, 0x02, 0xff}}, maxExDuration, object.ObjTypeByteArray, object.ObjEncodingByteArray)) @@ -4777,7 +4777,7 @@ func testEvalAPPEND(t *testing.T, store *dstore.Store) { } func BenchmarkEvalAPPEND(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for i := 0; i < b.N; i++ { evalAPPEND([]string{"key", fmt.Sprintf("val_%d", i)}, store) } @@ -5250,7 +5250,7 @@ func testEvalHINCRBYFLOAT(t *testing.T, store *dstore.Store) { } func BenchmarkEvalHINCRBYFLOAT(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) // Setting initial fields with some values store.Put("key1", store.NewObj(HashMap{"field1": "1.0", "field2": "1.2"}, maxExDuration, object.ObjTypeHashMap, object.ObjEncodingHashMap)) diff --git a/internal/eval/hmap_test.go b/internal/eval/hmap_test.go index 07bfea205..cba318717 100644 --- a/internal/eval/hmap_test.go +++ b/internal/eval/hmap_test.go @@ -71,7 +71,7 @@ func TestHashMapIncrementValue(t *testing.T) { } func TestGetValueFromHashMap(t *testing.T) { - store := store.NewStore(nil) + store := store.NewStore(nil, nil) key := "key1" field := "field1" value := "value1" diff --git a/internal/eval/main_test.go b/internal/eval/main_test.go index 7b79a58dc..04d622233 100644 --- a/internal/eval/main_test.go +++ b/internal/eval/main_test.go @@ -13,7 +13,7 @@ func TestMain(m *testing.M) { l := logger.New(logger.Opts{WithTimestamp: false}) slog.SetDefault(l) - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) store.ResetStore() exitCode := m.Run() diff --git a/internal/server/resp/server.go b/internal/server/resp/server.go index 32254ff1e..3bf4d06fd 100644 --- a/internal/server/resp/server.go +++ b/internal/server/resp/server.go @@ -4,8 +4,6 @@ import ( "context" "errors" "fmt" - dstore "github.com/dicedb/dice/internal/store" - "github.com/dicedb/dice/internal/watchmanager" "log/slog" "net" "sync" @@ -13,6 +11,9 @@ import ( "syscall" "time" + dstore "github.com/dicedb/dice/internal/store" + "github.com/dicedb/dice/internal/watchmanager" + "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/clientio/iohandler/netconn" respparser "github.com/dicedb/dice/internal/clientio/requestparser/resp" diff --git a/internal/sql/executerbechmark_test.go b/internal/sql/executerbechmark_test.go index e7a86388c..b2a935d48 100644 --- a/internal/sql/executerbechmark_test.go +++ b/internal/sql/executerbechmark_test.go @@ -35,7 +35,7 @@ func generateBenchmarkData(count int, store *dstore.Store) { } func BenchmarkExecuteQueryOrderBykey(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -60,7 +60,7 @@ func BenchmarkExecuteQueryOrderBykey(b *testing.B) { } func BenchmarkExecuteQueryBasicOrderByValue(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -83,7 +83,7 @@ func BenchmarkExecuteQueryBasicOrderByValue(b *testing.B) { } func BenchmarkExecuteQueryLimit(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -106,7 +106,7 @@ func BenchmarkExecuteQueryLimit(b *testing.B) { } func BenchmarkExecuteQueryNoMatch(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -129,7 +129,7 @@ func BenchmarkExecuteQueryNoMatch(b *testing.B) { } func BenchmarkExecuteQueryWithBasicWhere(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -152,7 +152,7 @@ func BenchmarkExecuteQueryWithBasicWhere(b *testing.B) { } func BenchmarkExecuteQueryWithComplexWhere(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -175,7 +175,7 @@ func BenchmarkExecuteQueryWithComplexWhere(b *testing.B) { } func BenchmarkExecuteQueryWithCompareWhereKeyandValue(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -198,7 +198,7 @@ func BenchmarkExecuteQueryWithCompareWhereKeyandValue(b *testing.B) { } func BenchmarkExecuteQueryWithBasicWhereNoMatch(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -221,7 +221,7 @@ func BenchmarkExecuteQueryWithBasicWhereNoMatch(b *testing.B) { } func BenchmarkExecuteQueryWithCaseSesnsitivity(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -243,7 +243,7 @@ func BenchmarkExecuteQueryWithCaseSesnsitivity(b *testing.B) { } func BenchmarkExecuteQueryWithClauseOnKey(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -266,7 +266,7 @@ func BenchmarkExecuteQueryWithClauseOnKey(b *testing.B) { } func BenchmarkExecuteQueryWithAllMatchingKeyRegex(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizes { generateBenchmarkData(v, store) @@ -308,7 +308,7 @@ func generateBenchmarkJSONData(b *testing.B, count int, json string, store *dsto } func BenchmarkExecuteQueryWithJSON(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizesJSON { for jsonSize, json := range jsonList { generateBenchmarkJSONData(b, v, json, store) @@ -333,7 +333,7 @@ func BenchmarkExecuteQueryWithJSON(b *testing.B) { } func BenchmarkExecuteQueryWithNestedJSON(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizesJSON { for jsonSize, json := range jsonList { generateBenchmarkJSONData(b, v, json, store) @@ -358,7 +358,7 @@ func BenchmarkExecuteQueryWithNestedJSON(b *testing.B) { } func BenchmarkExecuteQueryWithJsonInLeftAndRightExpressions(b *testing.B) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for _, v := range benchmarkDataSizesJSON { for jsonSize, json := range jsonList { generateBenchmarkJSONData(b, v, json, store) @@ -384,7 +384,7 @@ func BenchmarkExecuteQueryWithJsonInLeftAndRightExpressions(b *testing.B) { func BenchmarkExecuteQueryWithJsonNoMatch(b *testing.B) { for _, v := range benchmarkDataSizesJSON { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) for jsonSize, json := range jsonList { generateBenchmarkJSONData(b, v, json, store) diff --git a/internal/sql/executor_test.go b/internal/sql/executor_test.go index fbee39bc3..7feb9f4d3 100644 --- a/internal/sql/executor_test.go +++ b/internal/sql/executor_test.go @@ -42,7 +42,7 @@ func setup(store *dstore.Store, dataset []keyValue) { } func TestExecuteQueryOrderBykey(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setup(store, simpleKVDataset) queryString := "SELECT $key, $value WHERE $key like 'k*' ORDER BY $key ASC" @@ -69,7 +69,7 @@ func TestExecuteQueryOrderBykey(t *testing.T) { } func TestExecuteQueryBasicOrderByValue(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setup(store, simpleKVDataset) queryStr := "SELECT $key, $value WHERE $key like 'k*' ORDER BY $value ASC" @@ -96,7 +96,7 @@ func TestExecuteQueryBasicOrderByValue(t *testing.T) { } func TestExecuteQueryLimit(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setup(store, simpleKVDataset) queryStr := "SELECT $value WHERE $key like 'k*' ORDER BY $key ASC LIMIT 3" @@ -123,7 +123,7 @@ func TestExecuteQueryLimit(t *testing.T) { } func TestExecuteQueryNoMatch(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setup(store, simpleKVDataset) queryStr := "SELECT $key, $value WHERE $key like 'x*'" @@ -137,7 +137,7 @@ func TestExecuteQueryNoMatch(t *testing.T) { } func TestExecuteQueryWithWhere(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setup(store, simpleKVDataset) t.Run("BasicWhereClause", func(t *testing.T) { queryStr := "SELECT $key, $value WHERE $value = 'v3' AND $key like 'k*'" @@ -190,7 +190,7 @@ func TestExecuteQueryWithWhere(t *testing.T) { } func TestExecuteQueryWithIncompatibleTypes(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setup(store, simpleKVDataset) t.Run("ComparingStrWithInt", func(t *testing.T) { @@ -205,7 +205,7 @@ func TestExecuteQueryWithIncompatibleTypes(t *testing.T) { } func TestExecuteQueryWithEdgeCases(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setup(store, simpleKVDataset) t.Run("CaseSensitivity", func(t *testing.T) { @@ -285,7 +285,7 @@ func setupJSON(t *testing.T, store *dstore.Store, dataset []keyValue) { } func TestExecuteQueryWithJsonExpressionInWhere(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setupJSON(t, store, jsonWhereClauseDataset) t.Run("BasicWhereClauseWithJSON", func(t *testing.T) { @@ -394,7 +394,7 @@ var jsonOrderDataset = []keyValue{ } func TestExecuteQueryWithJsonOrderBy(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setupJSON(t, store, jsonOrderDataset) t.Run("OrderBySimpleJSONField", func(t *testing.T) { @@ -557,7 +557,7 @@ var stringComparisonDataset = []keyValue{ } func TestExecuteQueryWithLikeStringComparisons(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setup(store, stringComparisonDataset) testCases := []struct { @@ -642,7 +642,7 @@ func TestExecuteQueryWithLikeStringComparisons(t *testing.T) { } func TestExecuteQueryWithStringNotLikeComparisons(t *testing.T) { - store := dstore.NewStore(nil) + store := dstore.NewStore(nil, nil) setup(store, stringComparisonDataset) testCases := []struct { diff --git a/internal/store/expire_test.go b/internal/store/expire_test.go index afc9982ae..bace26856 100644 --- a/internal/store/expire_test.go +++ b/internal/store/expire_test.go @@ -7,7 +7,7 @@ import ( ) func TestDelExpiry(t *testing.T) { - store := NewStore(nil) + store := NewStore(nil, nil) // Initialize the test environment store.store = NewStoreMap() store.expires = NewExpireMap() diff --git a/internal/store/lfu_eviction_test.go b/internal/store/lfu_eviction_test.go index 9ccd90930..79d079364 100644 --- a/internal/store/lfu_eviction_test.go +++ b/internal/store/lfu_eviction_test.go @@ -11,7 +11,7 @@ import ( func TestLFUEviction(t *testing.T) { originalEvictionPolicy := config.DiceConfig.Server.EvictionPolicy - store := NewStore(nil) + store := NewStore(nil, nil) config.DiceConfig.Server.EvictionPolicy = config.EvictAllKeysLFU // Define test cases diff --git a/internal/store/store.go b/internal/store/store.go index e564329ee..1b8b607fd 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -331,7 +331,7 @@ func (store *Store) notifyQueryManager(k, operation string, obj object.Obj) { store.queryWatchChan <- QueryWatchEvent{k, operation, obj} } -func (store *Store) notifyWatchManager(cmd string, affectedKey string) { +func (store *Store) notifyWatchManager(cmd, affectedKey string) { store.cmdWatchChan <- CmdWatchEvent{cmd, affectedKey} } diff --git a/internal/watchmanager/watch_manager.go b/internal/watchmanager/watch_manager.go index 9bf6e90aa..289dd4753 100644 --- a/internal/watchmanager/watch_manager.go +++ b/internal/watchmanager/watch_manager.go @@ -2,24 +2,25 @@ package watchmanager import ( "context" - "github.com/dicedb/dice/internal/cmd" - dstore "github.com/dicedb/dice/internal/store" "log/slog" "sync" + + "github.com/dicedb/dice/internal/cmd" + dstore "github.com/dicedb/dice/internal/store" ) type ( WatchSubscription struct { - Subscribe bool // Subscribe is true for subscribe, false for unsubscribe. Required. - AdhocReqChan chan *cmd.RedisCmd // AdhocReqChan is the channel to send adhoc requests to the worker. Required. - WatchCmd *cmd.RedisCmd // WatchCmd Represents a unique key for each watch artifact, only populated for subscriptions. - Fingerprint uint32 // Fingerprint is a unique identifier for each watch artifact, only populated for unsubscriptions. + Subscribe bool // Subscribe is true for subscribe, false for unsubscribe. Required. + AdhocReqChan chan *cmd.DiceDBCmd // AdhocReqChan is the channel to send adhoc requests to the worker. Required. + WatchCmd *cmd.DiceDBCmd // WatchCmd Represents a unique key for each watch artifact, only populated for subscriptions. + Fingerprint uint32 // Fingerprint is a unique identifier for each watch artifact, only populated for unsubscriptions. } Manager struct { - querySubscriptionMap map[string]map[uint32]struct{} // querySubscriptionMap is a map of Key -> [fingerprint1, fingerprint2, ...] - tcpSubscriptionMap map[uint32]map[chan *cmd.RedisCmd]struct{} // tcpSubscriptionMap is a map of fingerprint -> [client1Chan, client2Chan, ...] - fingerprintCmdMap map[uint32]*cmd.RedisCmd // fingerprintCmdMap is a map of fingerprint -> RedisCmd + querySubscriptionMap map[string]map[uint32]struct{} // querySubscriptionMap is a map of Key -> [fingerprint1, fingerprint2, ...] + tcpSubscriptionMap map[uint32]map[chan *cmd.DiceDBCmd]struct{} // tcpSubscriptionMap is a map of fingerprint -> [client1Chan, client2Chan, ...] + fingerprintCmdMap map[uint32]*cmd.DiceDBCmd // fingerprintCmdMap is a map of fingerprint -> DiceDBCmd logger *slog.Logger } ) @@ -36,8 +37,8 @@ func NewManager(logger *slog.Logger) *Manager { CmdWatchSubscriptionChan = make(chan WatchSubscription) return &Manager{ querySubscriptionMap: make(map[string]map[uint32]struct{}), - tcpSubscriptionMap: make(map[uint32]map[chan *cmd.RedisCmd]struct{}), - fingerprintCmdMap: make(map[uint32]*cmd.RedisCmd), + tcpSubscriptionMap: make(map[uint32]map[chan *cmd.DiceDBCmd]struct{}), + fingerprintCmdMap: make(map[uint32]*cmd.DiceDBCmd), logger: logger, } } @@ -85,12 +86,12 @@ func (m *Manager) handleSubscription(sub WatchSubscription) { } m.querySubscriptionMap[key][fingerprint] = struct{}{} - // Add RedisCmd to fingerprintCmdMap + // Add DiceDBCmd to fingerprintCmdMap m.fingerprintCmdMap[fingerprint] = sub.WatchCmd // Add client channel to tcpSubscriptionMap if _, exists := m.tcpSubscriptionMap[fingerprint]; !exists { - m.tcpSubscriptionMap[fingerprint] = make(map[chan *cmd.RedisCmd]struct{}) + m.tcpSubscriptionMap[fingerprint] = make(map[chan *cmd.DiceDBCmd]struct{}) } m.tcpSubscriptionMap[fingerprint][sub.AdhocReqChan] = struct{}{} } @@ -116,8 +117,8 @@ func (m *Manager) handleUnsubscription(sub WatchSubscription) { } // Remove fingerprint from querySubscriptionMap - if redisCmd, ok := m.fingerprintCmdMap[fingerprint]; ok { - key := redisCmd.GetKey() + if diceDBCmd, ok := m.fingerprintCmdMap[fingerprint]; ok { + key := diceDBCmd.GetKey() if fingerprints, ok := m.querySubscriptionMap[key]; ok { // Remove the fingerprint from the list of fingerprints listening to this key delete(fingerprints, fingerprint) @@ -161,7 +162,7 @@ func (m *Manager) handleWatchEvent(event dstore.CmdWatchEvent) { } // notifyClients sends cmd to all clients listening to this fingerprint, so that they can execute it. -func (m *Manager) notifyClients(fingerprint uint32, cmd *cmd.RedisCmd) { +func (m *Manager) notifyClients(fingerprint uint32, diceDBCmd *cmd.DiceDBCmd) { clients, exists := m.tcpSubscriptionMap[fingerprint] if !exists { m.logger.Warn("No clients found for fingerprint", @@ -170,6 +171,6 @@ func (m *Manager) notifyClients(fingerprint uint32, cmd *cmd.RedisCmd) { } for clientChan := range clients { - clientChan <- cmd + clientChan <- diceDBCmd } } diff --git a/internal/worker/worker.go b/internal/worker/worker.go index fc7ae0203..705b2149c 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -4,12 +4,13 @@ import ( "context" "errors" "fmt" - "github.com/dicedb/dice/internal/watchmanager" "log/slog" "net" "syscall" "time" + "github.com/dicedb/dice/internal/watchmanager" + "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/auth" "github.com/dicedb/dice/internal/clientio" @@ -35,7 +36,7 @@ type BaseWorker struct { parser requestparser.Parser shardManager *shard.ShardManager respChan chan *ops.StoreResponse - adhocReqChan chan *cmd.RedisCmd + adhocReqChan chan *cmd.DiceDBCmd Session *auth.Session globalErrorChan chan error logger *slog.Logger @@ -54,7 +55,7 @@ func NewWorker(wid string, respChan chan *ops.StoreResponse, respChan: respChan, logger: logger, Session: auth.NewSession(), - adhocReqChan: make(chan *cmd.RedisCmd, 20), // assuming we wouldn't have more than 20 adhoc requests being sent at a time. + adhocReqChan: make(chan *cmd.DiceDBCmd, 20), // assuming we wouldn't have more than 20 adhoc requests being sent at a time. } } @@ -96,8 +97,12 @@ func (w *BaseWorker) Start(ctx context.Context) error { } return fmt.Errorf("error writing response: %w", err) case cmdReq := <-w.adhocReqChan: - // Handle adhoc requests of RedisCmd - w.executeCommandHandler(errChan, nil, ctx, []*cmd.RedisCmd{cmdReq}, true) + // Handle adhoc requests of DiceDBCmd + func() { + execCtx, cancel := context.WithTimeout(ctx, 6*time.Second) // Timeout set to 6 seconds for integration tests + defer cancel() + w.executeCommandHandler(execCtx, errChan, []*cmd.DiceDBCmd{cmdReq}, true) + }() case data := <-dataChan: cmds, err := w.parser.Parse(data) if err != nil { @@ -138,17 +143,17 @@ func (w *BaseWorker) Start(ctx context.Context) error { func(errChan chan error) { execCtx, cancel := context.WithTimeout(ctx, 6*time.Second) // Timeout set to 6 seconds for integration tests defer cancel() - w.executeCommandHandler(errChan, err, execCtx, cmds, false) + w.executeCommandHandler(execCtx, errChan, cmds, false) }(errChan) - case err := <-errChan: + case err := <-readErrChan: w.logger.Debug("Read error, connection closed possibly", slog.String("workerID", w.id), slog.Any("error", err)) return err } } } -func (w *BaseWorker) executeCommandHandler(errChan chan error, err error, execCtx context.Context, cmds []*cmd.RedisCmd, isWatchNotification bool) { - err = w.executeCommand(execCtx, cmds[0], isWatchNotification) +func (w *BaseWorker) executeCommandHandler(execCtx context.Context, errChan chan error, cmds []*cmd.DiceDBCmd, isWatchNotification bool) { + err := w.executeCommand(execCtx, cmds[0], isWatchNotification) if err != nil { w.logger.Error("Error executing command", slog.String("workerID", w.id), slog.Any("error", err)) if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ETIMEDOUT) { diff --git a/main.go b/main.go index 0fae87dc8..5cd217b5c 100644 --- a/main.go +++ b/main.go @@ -92,7 +92,7 @@ func main() { // Initialize the AsyncServer server // Find a port and bind it if !config.EnableMultiThreading { - asyncServer := server.NewAsyncServer(shardManager, queryWatchChan, cmdWatchChan, logr) + asyncServer := server.NewAsyncServer(shardManager, queryWatchChan, logr) if err := asyncServer.FindPortAndBind(); err != nil { cancel() logr.Error("Error finding and binding port", slog.Any("error", err)) @@ -155,7 +155,7 @@ func main() { } else { workerManager := worker.NewWorkerManager(config.DiceConfig.Server.MaxClients, shardManager) // Initialize the RESP Server - respServer := resp.NewServer(shardManager, workerManager, serverErrCh, logr) + respServer := resp.NewServer(shardManager, workerManager, cmdWatchChan, serverErrCh, logr) serverWg.Add(1) go func() { defer serverWg.Done() From 8f40506d7dc4fc0d55b52e37f087556efa58d420 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh Date: Sun, 6 Oct 2024 18:18:25 +0530 Subject: [PATCH 04/11] fixes --- integration_tests/commands/http/set_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_tests/commands/http/set_test.go b/integration_tests/commands/http/set_test.go index bf5d8d6d0..a35c92042 100644 --- a/integration_tests/commands/http/set_test.go +++ b/integration_tests/commands/http/set_test.go @@ -33,7 +33,7 @@ func TestSet(t *testing.T) { {Command: "SET", Body: map[string]interface{}{"key": "k", "value": 123456789}}, {Command: "GET", Body: map[string]interface{}{"key": "k"}}, }, - expected: []interface{}{"OK", "1.23456789e+08"}, + expected: []interface{}{"OK", 1.23456789e+08}, }, { name: "Overwrite Existing Key", From 720eb778391ba9525be97ac4fe3eb9c69340d8ef Mon Sep 17 00:00:00 2001 From: Jyotinder Singh Date: Sun, 6 Oct 2024 18:23:43 +0530 Subject: [PATCH 05/11] moar fixes --- integration_tests/commands/async/set_test.go | 30 ++++++++++---------- internal/eval/store_eval.go | 6 ++-- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/integration_tests/commands/async/set_test.go b/integration_tests/commands/async/set_test.go index 98f1b5082..611b937c1 100644 --- a/integration_tests/commands/async/set_test.go +++ b/integration_tests/commands/async/set_test.go @@ -20,12 +20,12 @@ func TestSet(t *testing.T) { testCases := []TestCase{ { - name: "Set and Get Simple Cmd", + name: "Set and Get Simple Value", commands: []string{"SET k v", "GET k"}, expected: []interface{}{"OK", "v"}, }, { - name: "Set and Get Integer Cmd", + name: "Set and Get Integer Value", commands: []string{"SET k 123456789", "GET k"}, expected: []interface{}{"OK", int64(123456789)}, }, @@ -146,31 +146,31 @@ func TestSetWithExat(t *testing.T) { func(t *testing.T) { // deleteTestKeys([]string{"k"}, store) FireCommand(conn, "DEL k") - assert.Equal(t, "OK", FireCommand(conn, "SET k v EXAT "+Etime), "Cmd mismatch for cmd SET k v EXAT "+Etime) - assert.Equal(t, "v", FireCommand(conn, "GET k"), "Cmd mismatch for cmd GET k") - assert.Assert(t, FireCommand(conn, "TTL k").(int64) <= 5, "Cmd mismatch for cmd TTL k") + assert.Equal(t, "OK", FireCommand(conn, "SET k v EXAT "+Etime), "Value mismatch for cmd SET k v EXAT "+Etime) + assert.Equal(t, "v", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") + assert.Assert(t, FireCommand(conn, "TTL k").(int64) <= 5, "Value mismatch for cmd TTL k") time.Sleep(3 * time.Second) - assert.Assert(t, FireCommand(conn, "TTL k").(int64) <= 3, "Cmd mismatch for cmd TTL k") + assert.Assert(t, FireCommand(conn, "TTL k").(int64) <= 3, "Value mismatch for cmd TTL k") time.Sleep(3 * time.Second) - assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Cmd mismatch for cmd GET k") - assert.Equal(t, int64(-2), FireCommand(conn, "TTL k"), "Cmd mismatch for cmd TTL k") + assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") + assert.Equal(t, int64(-2), FireCommand(conn, "TTL k"), "Value mismatch for cmd TTL k") }) t.Run("SET with invalid EXAT expires key immediately", func(t *testing.T) { // deleteTestKeys([]string{"k"}, store) FireCommand(conn, "DEL k") - assert.Equal(t, "OK", FireCommand(conn, "SET k v EXAT "+BadTime), "Cmd mismatch for cmd SET k v EXAT "+BadTime) - assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Cmd mismatch for cmd GET k") - assert.Equal(t, int64(-2), FireCommand(conn, "TTL k"), "Cmd mismatch for cmd TTL k") + assert.Equal(t, "OK", FireCommand(conn, "SET k v EXAT "+BadTime), "Value mismatch for cmd SET k v EXAT "+BadTime) + assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") + assert.Equal(t, int64(-2), FireCommand(conn, "TTL k"), "Value mismatch for cmd TTL k") }) t.Run("SET with EXAT and PXAT returns syntax error", func(t *testing.T) { // deleteTestKeys([]string{"k"}, store) FireCommand(conn, "DEL k") - assert.Equal(t, "ERR syntax error", FireCommand(conn, "SET k v PXAT "+Etime+" EXAT "+Etime), "Cmd mismatch for cmd SET k v PXAT "+Etime+" EXAT "+Etime) - assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Cmd mismatch for cmd GET k") + assert.Equal(t, "ERR syntax error", FireCommand(conn, "SET k v PXAT "+Etime+" EXAT "+Etime), "Value mismatch for cmd SET k v PXAT "+Etime+" EXAT "+Etime) + assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") }) } @@ -187,7 +187,7 @@ func TestWithKeepTTLFlag(t *testing.T) { for i := 0; i < len(tcase.commands); i++ { cmd := tcase.commands[i] out := tcase.expected[i] - assert.Equal(t, out, FireCommand(conn, cmd), "Cmd mismatch for cmd %s\n.", cmd) + assert.Equal(t, out, FireCommand(conn, cmd), "Value mismatch for cmd %s\n.", cmd) } } @@ -196,5 +196,5 @@ func TestWithKeepTTLFlag(t *testing.T) { cmd := "GET k" out := "(nil)" - assert.Equal(t, out, FireCommand(conn, cmd), "Cmd mismatch for cmd %s\n.", cmd) + assert.Equal(t, out, FireCommand(conn, cmd), "Value mismatch for cmd %s\n.", cmd) } diff --git a/internal/eval/store_eval.go b/internal/eval/store_eval.go index 07a858f44..d55659b1e 100644 --- a/internal/eval/store_eval.go +++ b/internal/eval/store_eval.go @@ -201,7 +201,7 @@ func evalGET(args []string, store *dstore.Store) *EvalResponse { // Decode and return the value based on its encoding switch _, oEnc := object.ExtractTypeEncoding(obj); oEnc { case object.ObjEncodingInt: - // Cmd is stored as an int64, so use type assertion + // Value is stored as an int64, so use type assertion if val, ok := obj.Value.(int64); ok { return &EvalResponse{ Result: val, @@ -215,7 +215,7 @@ func evalGET(args []string, store *dstore.Store) *EvalResponse { } case object.ObjEncodingEmbStr, object.ObjEncodingRaw: - // Cmd is stored as a string, use type assertion + // Value is stored as a string, use type assertion if val, ok := obj.Value.(string); ok { return &EvalResponse{ Result: val, @@ -228,7 +228,7 @@ func evalGET(args []string, store *dstore.Store) *EvalResponse { } case object.ObjEncodingByteArray: - // Cmd is stored as a bytearray, use type assertion + // Value is stored as a bytearray, use type assertion if val, ok := obj.Value.(*ByteArray); ok { return &EvalResponse{ Result: string(val.data), From 6d1f44d04c03ed89c77f3716c14434e6575960ac Mon Sep 17 00:00:00 2001 From: Jyotinder Singh Date: Sun, 6 Oct 2024 18:29:25 +0530 Subject: [PATCH 06/11] undo unwanted changes --- internal/clientio/client_identifier.go | 13 ------------- internal/querymanager/query_manager.go | 23 ++++++++++++++--------- 2 files changed, 14 insertions(+), 22 deletions(-) delete mode 100644 internal/clientio/client_identifier.go diff --git a/internal/clientio/client_identifier.go b/internal/clientio/client_identifier.go deleted file mode 100644 index 91ea3bce0..000000000 --- a/internal/clientio/client_identifier.go +++ /dev/null @@ -1,13 +0,0 @@ -package clientio - -type ClientIdentifier struct { - ClientIdentifierID int - IsHTTPClient bool -} - -func NewClientIdentifier(clientIdentifierID int, isHTTPClient bool) ClientIdentifier { - return ClientIdentifier{ - ClientIdentifierID: clientIdentifierID, - IsHTTPClient: isHTTPClient, - } -} diff --git a/internal/querymanager/query_manager.go b/internal/querymanager/query_manager.go index 41bfc58c0..eca014b43 100644 --- a/internal/querymanager/query_manager.go +++ b/internal/querymanager/query_manager.go @@ -64,6 +64,11 @@ type ( Query string `json:"query"` Data []any `json:"data"` } + + ClientIdentifier struct { + ClientIdentifierID int + IsHTTPClient bool + } ) var ( @@ -74,8 +79,8 @@ var ( AdhocQueryChan chan AdhocQuery ) -func NewClientIdentifier(clientIdentifierID int, isHTTPClient bool) clientio.ClientIdentifier { - return clientio.ClientIdentifier{ +func NewClientIdentifier(clientIdentifierID int, isHTTPClient bool) ClientIdentifier { + return ClientIdentifier{ ClientIdentifierID: clientIdentifierID, IsHTTPClient: isHTTPClient, } @@ -141,11 +146,11 @@ func (m *Manager) listenForSubscriptions(ctx context.Context) { for { select { case event := <-QuerySubscriptionChan: - var client clientio.ClientIdentifier + var client ClientIdentifier if event.QwatchClientChan != nil { - client = clientio.NewClientIdentifier(int(event.ClientIdentifierID), true) + client = NewClientIdentifier(int(event.ClientIdentifierID), true) } else { - client = clientio.NewClientIdentifier(event.ClientFD, false) + client = NewClientIdentifier(event.ClientFD, false) } if event.Subscribe { @@ -235,7 +240,7 @@ func (m *Manager) notifyClients(query *sql.DSQLQuery, clients *sync.Map, queryRe clients.Range(func(clientKey, clientVal interface{}) bool { // Identify the type of client and respond accordingly - switch clientIdentifier := clientKey.(clientio.ClientIdentifier); { + switch clientIdentifier := clientKey.(ClientIdentifier); { case clientIdentifier.IsHTTPClient: qwatchClientResponseChannel := clientVal.(chan comm.QwatchResponse) qwatchClientResponseChannel <- comm.QwatchResponse{ @@ -285,7 +290,7 @@ func (m *Manager) sendWithRetry(query *sql.DSQLQuery, clientFD int, data []byte) slog.Int("client", clientFD), slog.Any("error", err), ) - m.removeWatcher(query, clientio.NewClientIdentifier(clientFD, false), nil) + m.removeWatcher(query, NewClientIdentifier(clientFD, false), nil) return } } @@ -308,7 +313,7 @@ func (m *Manager) serveAdhocQueries(ctx context.Context) { } // addWatcher adds a client as a watcher to a query. -func (m *Manager) addWatcher(query *sql.DSQLQuery, clientIdentifier clientio.ClientIdentifier, +func (m *Manager) addWatcher(query *sql.DSQLQuery, clientIdentifier ClientIdentifier, qwatchClientChan chan comm.QwatchResponse, cacheChan chan *[]struct { Key string Value *object.Obj @@ -338,7 +343,7 @@ func (m *Manager) addWatcher(query *sql.DSQLQuery, clientIdentifier clientio.Cli } // removeWatcher removes a client from the watchlist for a query. -func (m *Manager) removeWatcher(query *sql.DSQLQuery, clientIdentifier clientio.ClientIdentifier, +func (m *Manager) removeWatcher(query *sql.DSQLQuery, clientIdentifier ClientIdentifier, qwatchClientChan chan comm.QwatchResponse) { queryString := query.String() if clients, ok := m.WatchList.Load(queryString); ok { From d5aa8730babf8fa12a7b123d4e18ea8ffe6c38e5 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh Date: Sun, 6 Oct 2024 23:21:27 +0530 Subject: [PATCH 07/11] cleaned up worker code --- .../commands/resp/getwatch_test.go | 2 +- internal/clientio/push_response.go | 9 +- internal/eval/eval.go | 2 +- internal/querymanager/query_manager.go | 2 +- internal/worker/worker.go | 147 +++++++++--------- 5 files changed, 83 insertions(+), 79 deletions(-) diff --git a/integration_tests/commands/resp/getwatch_test.go b/integration_tests/commands/resp/getwatch_test.go index 30c32d796..70e291311 100644 --- a/integration_tests/commands/resp/getwatch_test.go +++ b/integration_tests/commands/resp/getwatch_test.go @@ -72,7 +72,7 @@ func TestGETWATCH(t *testing.T) { t.Errorf("Type assertion to []interface{} failed for value: %v", v) } assert.Equal(t, 3, len(castedValue)) - assert.Equal(t, "GET.WATCH", castedValue[1]) + assert.Equal(t, "GET", castedValue[1]) assert.Equal(t, tc.val, castedValue[2]) } } diff --git a/internal/clientio/push_response.go b/internal/clientio/push_response.go index fc544a86a..94593742f 100644 --- a/internal/clientio/push_response.go +++ b/internal/clientio/push_response.go @@ -1,16 +1,17 @@ package clientio -import ( - "github.com/dicedb/dice/internal/sql" +const ( + ResponseTypeRegular = iota + ResponseTypePush ) // CreatePushResponse creates a push response. Push responses refer to messages that the server sends to clients without // the client explicitly requesting them. These are typically seen in scenarios where the client has subscribed to some // kind of event or data feed and is notified in real-time when changes occur. // `key` is the unique key that identifies the push response. -func CreatePushResponse[T any](key string, result T) (response []interface{}) { +func CreatePushResponse[T any](cmd, key string, result T) (response []interface{}) { response = make([]interface{}, 3) - response[0] = sql.Qwatch + response[0] = cmd response[1] = key response[2] = result return diff --git a/internal/eval/eval.go b/internal/eval/eval.go index 635669ec1..7e75fda43 100644 --- a/internal/eval/eval.go +++ b/internal/eval/eval.go @@ -2119,7 +2119,7 @@ func EvalQWATCH(args []string, httpOp bool, client *comm.Client, store *dstore.S } // TODO: We should return the list of all queries being watched by the client. - return clientio.Encode(clientio.CreatePushResponse(query.String(), *queryResult.Result), false) + return clientio.Encode(clientio.CreatePushResponse(sql.Qwatch, query.String(), *queryResult.Result), false) } // EvalQUNWATCH removes the specified key from the watch list for the caller client. diff --git a/internal/querymanager/query_manager.go b/internal/querymanager/query_manager.go index eca014b43..02d0e8975 100644 --- a/internal/querymanager/query_manager.go +++ b/internal/querymanager/query_manager.go @@ -236,7 +236,7 @@ func (m *Manager) updateQueryCache(queryFingerprint string, event dstore.QueryWa } func (m *Manager) notifyClients(query *sql.DSQLQuery, clients *sync.Map, queryResult *[]sql.QueryResultRow) { - encodedResult := clientio.Encode(clientio.CreatePushResponse(query.String(), *queryResult), false) + encodedResult := clientio.Encode(clientio.CreatePushResponse(sql.Qwatch, query.String(), *queryResult), false) clients.Range(func(clientKey, clientVal interface{}) bool { // Identify the type of client and respond accordingly diff --git a/internal/worker/worker.go b/internal/worker/worker.go index 705b2149c..296e528ab 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -164,12 +164,11 @@ func (w *BaseWorker) executeCommandHandler(execCtx context.Context, errChan chan } func (w *BaseWorker) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, isWatchNotification bool) error { + responseType := clientio.ResponseTypeRegular // Break down the single command into multiple commands if multisharding is supported. // The length of cmdList helps determine how many shards to wait for responses. cmdList := make([]*cmd.DiceDBCmd, 0) - localErrChan := make(chan error, 1) - // Retrieve metadata for the command to determine if multisharding is supported. meta, ok := CommandsMeta[diceDBCmd.Cmd] if !ok { @@ -187,10 +186,39 @@ func (w *BaseWorker) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCm case SingleShard: // For single-shard or custom commands, process them without breaking up. cmdList = append(cmdList, diceDBCmd) + responseType = clientio.ResponseTypeRegular case MultiShard: // If the command supports multisharding, break it down into multiple commands. cmdList = meta.decomposeCommand(diceDBCmd) + responseType = clientio.ResponseTypeRegular + case Watch: + // Generate the Cmd being watched. All we need to do is remove the .WATCH suffix from the command and pass + // it along as is. + watchCmd := &cmd.DiceDBCmd{ + Cmd: diceDBCmd.Cmd[:len(diceDBCmd.Cmd)-6], // Remove the .WATCH suffix + Args: diceDBCmd.Args, + } + + cmdList = append(cmdList, watchCmd) + + // Execute the command (scatter and gather) + if err := w.scatter(ctx, cmdList); err != nil { + return err + } + + if err := w.gather(ctx, diceDBCmd, len(cmdList), clientio.ResponseTypePush); err != nil { + return err + } + + // Proceed to subscribe after successful execution + watchmanager.CmdWatchSubscriptionChan <- watchmanager.WatchSubscription{ + Subscribe: true, + WatchCmd: watchCmd, + AdhocReqChan: w.adhocReqChan, + } + + return nil case Custom: switch diceDBCmd.Cmd { case CmdAuth: @@ -210,50 +238,24 @@ func (w *BaseWorker) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCm default: cmdList = append(cmdList, diceDBCmd) } - case Watch: - // Generate the Cmd being watched. All we need to do is remove the .WATCH suffix from the command and pass - // it along as is. - watchCmd := &cmd.DiceDBCmd{ - Cmd: diceDBCmd.Cmd[:len(diceDBCmd.Cmd)-6], // Remove the .WATCH suffix - Args: diceDBCmd.Args, - } - - cmdList = append(cmdList, watchCmd) - - go func() { - err := <-localErrChan - if err != nil { - return - } - watchmanager.CmdWatchSubscriptionChan <- watchmanager.WatchSubscription{ - Subscribe: true, - WatchCmd: watchCmd, - AdhocReqChan: w.adhocReqChan, - } - }() } } // Scatter the broken-down commands to the appropriate shards. - err := w.scatter(ctx, cmdList) - if err != nil { - localErrChan <- err + if err := w.scatter(ctx, cmdList); err != nil { return err } - cmdType := meta.CmdType + // For watch notifications, we need to set the responseType to push if isWatchNotification { - cmdType = Watch + responseType = clientio.ResponseTypePush } // Gather the responses from the shards and write them to the buffer. - err = w.gather(ctx, diceDBCmd.Cmd, len(cmdList), cmdType) - if err != nil { - localErrChan <- err + if err := w.gather(ctx, diceDBCmd, len(cmdList), responseType); err != nil { return err } - localErrChan <- nil return nil } @@ -294,8 +296,8 @@ func (w *BaseWorker) scatter(ctx context.Context, cmds []*cmd.DiceDBCmd) error { // gather collects the responses from multiple shards and writes the results into the provided buffer. // It first waits for responses from all the shards and then processes the result based on the command type (SingleShard, Custom, or Multishard). -func (w *BaseWorker) gather(ctx context.Context, c string, numCmds int, ct CmdType) error { - // Loop to wait for messages from numberof shards +func (w *BaseWorker) gather(ctx context.Context, cmd *cmd.DiceDBCmd, numCmds int, responseType int) error { + // Loop to wait for messages from number of shards var evalResp []eval.EvalResponse for numCmds != 0 { select { @@ -314,38 +316,38 @@ func (w *BaseWorker) gather(ctx context.Context, c string, numCmds int, ct CmdTy } } - // TODO: This is a temporary solution. In the future, all commands should be refactored to be multi-shard compatible. - // TODO: There are a few commands such as QWATCH, RENAME, MGET, MSET that wouldn't work in multi-shard mode without refactoring. - // TODO: These commands should be refactored to be multi-shard compatible before DICE-DB is completely multi-shard. - // Check if command is part of the new WorkerCommandsMeta map i.e. if the command has been refactored to be multi-shard compatible. - // If not found, treat it as a command that's not yet refactored, and write the response back to the client. - val, ok := CommandsMeta[c] - if !ok { - if evalResp[0].Error != nil { - err := w.ioHandler.Write(ctx, []byte(evalResp[0].Error.Error())) - if err != nil { - w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) - return err - } - } - - err := w.ioHandler.Write(ctx, evalResp[0].Result.([]byte)) + switch responseType { + case clientio.ResponseTypeRegular: + return w.handleRegularResponse(ctx, cmd, evalResp) + case clientio.ResponseTypePush: + return w.handlePushResponse(ctx, cmd, cmd.Cmd, evalResp) + default: + w.logger.Error("Unknown response type", slog.String("workerID", w.id), slog.Int("responseType", responseType)) + err := w.ioHandler.Write(ctx, diceerrors.ErrInternalServer) if err != nil { w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) return err } - return nil } +} - switch ct { - case SingleShard, Custom: +// handleRegularResponse handles the response for regular commands, i.e., responses for which are pushed from the server to the client. +func (w *BaseWorker) handleRegularResponse(ctx context.Context, cmd *cmd.DiceDBCmd, evalResp []eval.EvalResponse) error { + // Check if the command is multi-shard capable + // TODO: This is a temporary solution. In the future, all commands should be refactored to be multi-shard compatible. + // TODO: There are a few commands such as QWATCH, RENAME, MGET, MSET that wouldn't work in multi-shard mode without refactoring. + // TODO: These commands should be refactored to be multi-shard compatible before DICE-DB is completely multi-shard. + // Check if command is part of the new WorkerCommandsMeta map i.e. if the command has been refactored to be multi-shard compatible. + // If not found, treat it as a command that's not yet refactored, and write the response back to the client. + val, ok := CommandsMeta[cmd.Cmd] + if !ok || val.CmdType == SingleShard || val.CmdType == Custom { + // Handle single-shard or custom commands if evalResp[0].Error != nil { err := w.ioHandler.Write(ctx, evalResp[0].Error) if err != nil { w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) } - return err } @@ -354,38 +356,39 @@ func (w *BaseWorker) gather(ctx context.Context, c string, numCmds int, ct CmdTy w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) return err } - - case MultiShard: + } else if val.CmdType == MultiShard { + // Handle multi-shard commands err := w.ioHandler.Write(ctx, val.composeResponse(evalResp...)) if err != nil { w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) return err } - case Watch: - if evalResp[0].Error != nil { - err := w.ioHandler.Write(ctx, clientio.CreatePushResponse("GET.WATCH", evalResp[0].Error)) - if err != nil { - w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) - } - - return err - } - - err := w.ioHandler.Write(ctx, clientio.CreatePushResponse("GET.WATCH", evalResp[0].Result)) + } else { + w.logger.Error("Unknown command type", slog.String("workerID", w.id), slog.String("command", cmd.Cmd), slog.Any("evalResp", evalResp)) + err := w.ioHandler.Write(ctx, diceerrors.ErrInternalServer) if err != nil { w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) return err } + } + return nil +} - default: - w.logger.Error("Unknown command type", slog.String("workerID", w.id)) - err := w.ioHandler.Write(ctx, diceerrors.ErrInternalServer) +// handlePushResponse handles the response for push commands, i.e., responses for which are pushed from the server to the client. +func (w *BaseWorker) handlePushResponse(ctx context.Context, cmd *cmd.DiceDBCmd, pushResponseKey string, evalResp []eval.EvalResponse) error { + if evalResp[0].Error != nil { + err := w.ioHandler.Write(ctx, clientio.CreatePushResponse(cmd.Cmd, pushResponseKey, evalResp[0].Error)) if err != nil { - w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) - return err + w.logger.Debug("Error sending push response to client", slog.String("workerID", w.id), slog.Any("error", err)) } + return err } + err := w.ioHandler.Write(ctx, clientio.CreatePushResponse(cmd.Cmd, pushResponseKey, evalResp[0].Result)) + if err != nil { + w.logger.Debug("Error sending push response to client", slog.String("workerID", w.id), slog.Any("error", err)) + return err + } return nil } From 5c4b71e9205ff8316fc27a27070a83880c796660 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh Date: Sun, 6 Oct 2024 23:26:46 +0530 Subject: [PATCH 08/11] improve tests --- integration_tests/commands/resp/getwatch_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/integration_tests/commands/resp/getwatch_test.go b/integration_tests/commands/resp/getwatch_test.go index 70e291311..7dd748794 100644 --- a/integration_tests/commands/resp/getwatch_test.go +++ b/integration_tests/commands/resp/getwatch_test.go @@ -72,6 +72,7 @@ func TestGETWATCH(t *testing.T) { t.Errorf("Type assertion to []interface{} failed for value: %v", v) } assert.Equal(t, 3, len(castedValue)) + assert.Equal(t, "GET", castedValue[0]) assert.Equal(t, "GET", castedValue[1]) assert.Equal(t, tc.val, castedValue[2]) } From 7643136c5f50069a4adaeaa313bc87231274bd1c Mon Sep 17 00:00:00 2001 From: Jyotinder Singh Date: Sun, 6 Oct 2024 23:28:45 +0530 Subject: [PATCH 09/11] fix linter --- internal/worker/worker.go | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/internal/worker/worker.go b/internal/worker/worker.go index 296e528ab..45963b275 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -252,11 +252,8 @@ func (w *BaseWorker) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCm } // Gather the responses from the shards and write them to the buffer. - if err := w.gather(ctx, diceDBCmd, len(cmdList), responseType); err != nil { - return err - } - - return nil + err := w.gather(ctx, diceDBCmd, len(cmdList), responseType) + return err } // scatter distributes the DiceDB commands to the respective shards based on the key. @@ -296,7 +293,7 @@ func (w *BaseWorker) scatter(ctx context.Context, cmds []*cmd.DiceDBCmd) error { // gather collects the responses from multiple shards and writes the results into the provided buffer. // It first waits for responses from all the shards and then processes the result based on the command type (SingleShard, Custom, or Multishard). -func (w *BaseWorker) gather(ctx context.Context, cmd *cmd.DiceDBCmd, numCmds int, responseType int) error { +func (w *BaseWorker) gather(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, numCmds, responseType int) error { // Loop to wait for messages from number of shards var evalResp []eval.EvalResponse for numCmds != 0 { @@ -318,9 +315,9 @@ func (w *BaseWorker) gather(ctx context.Context, cmd *cmd.DiceDBCmd, numCmds int switch responseType { case clientio.ResponseTypeRegular: - return w.handleRegularResponse(ctx, cmd, evalResp) + return w.handleRegularResponse(ctx, diceDBCmd, evalResp) case clientio.ResponseTypePush: - return w.handlePushResponse(ctx, cmd, cmd.Cmd, evalResp) + return w.handlePushResponse(ctx, diceDBCmd, diceDBCmd.Cmd, evalResp) default: w.logger.Error("Unknown response type", slog.String("workerID", w.id), slog.Int("responseType", responseType)) err := w.ioHandler.Write(ctx, diceerrors.ErrInternalServer) @@ -333,14 +330,14 @@ func (w *BaseWorker) gather(ctx context.Context, cmd *cmd.DiceDBCmd, numCmds int } // handleRegularResponse handles the response for regular commands, i.e., responses for which are pushed from the server to the client. -func (w *BaseWorker) handleRegularResponse(ctx context.Context, cmd *cmd.DiceDBCmd, evalResp []eval.EvalResponse) error { +func (w *BaseWorker) handleRegularResponse(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, evalResp []eval.EvalResponse) error { // Check if the command is multi-shard capable // TODO: This is a temporary solution. In the future, all commands should be refactored to be multi-shard compatible. // TODO: There are a few commands such as QWATCH, RENAME, MGET, MSET that wouldn't work in multi-shard mode without refactoring. // TODO: These commands should be refactored to be multi-shard compatible before DICE-DB is completely multi-shard. // Check if command is part of the new WorkerCommandsMeta map i.e. if the command has been refactored to be multi-shard compatible. // If not found, treat it as a command that's not yet refactored, and write the response back to the client. - val, ok := CommandsMeta[cmd.Cmd] + val, ok := CommandsMeta[diceDBCmd.Cmd] if !ok || val.CmdType == SingleShard || val.CmdType == Custom { // Handle single-shard or custom commands if evalResp[0].Error != nil { @@ -364,7 +361,7 @@ func (w *BaseWorker) handleRegularResponse(ctx context.Context, cmd *cmd.DiceDBC return err } } else { - w.logger.Error("Unknown command type", slog.String("workerID", w.id), slog.String("command", cmd.Cmd), slog.Any("evalResp", evalResp)) + w.logger.Error("Unknown command type", slog.String("workerID", w.id), slog.String("command", diceDBCmd.Cmd), slog.Any("evalResp", evalResp)) err := w.ioHandler.Write(ctx, diceerrors.ErrInternalServer) if err != nil { w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) @@ -375,16 +372,16 @@ func (w *BaseWorker) handleRegularResponse(ctx context.Context, cmd *cmd.DiceDBC } // handlePushResponse handles the response for push commands, i.e., responses for which are pushed from the server to the client. -func (w *BaseWorker) handlePushResponse(ctx context.Context, cmd *cmd.DiceDBCmd, pushResponseKey string, evalResp []eval.EvalResponse) error { +func (w *BaseWorker) handlePushResponse(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, pushResponseKey string, evalResp []eval.EvalResponse) error { if evalResp[0].Error != nil { - err := w.ioHandler.Write(ctx, clientio.CreatePushResponse(cmd.Cmd, pushResponseKey, evalResp[0].Error)) + err := w.ioHandler.Write(ctx, clientio.CreatePushResponse(diceDBCmd.Cmd, pushResponseKey, evalResp[0].Error)) if err != nil { w.logger.Debug("Error sending push response to client", slog.String("workerID", w.id), slog.Any("error", err)) } return err } - err := w.ioHandler.Write(ctx, clientio.CreatePushResponse(cmd.Cmd, pushResponseKey, evalResp[0].Result)) + err := w.ioHandler.Write(ctx, clientio.CreatePushResponse(diceDBCmd.Cmd, pushResponseKey, evalResp[0].Result)) if err != nil { w.logger.Debug("Error sending push response to client", slog.String("workerID", w.id), slog.Any("error", err)) return err From 9dadb33fd1ee7dd020f86fc0764a1fb8aceff91e Mon Sep 17 00:00:00 2001 From: Jyotinder Singh Date: Sun, 6 Oct 2024 23:57:25 +0530 Subject: [PATCH 10/11] add fingerprint to push response --- integration_tests/commands/resp/getwatch_test.go | 2 +- internal/worker/worker.go | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/integration_tests/commands/resp/getwatch_test.go b/integration_tests/commands/resp/getwatch_test.go index 7dd748794..b5d5e6aef 100644 --- a/integration_tests/commands/resp/getwatch_test.go +++ b/integration_tests/commands/resp/getwatch_test.go @@ -73,7 +73,7 @@ func TestGETWATCH(t *testing.T) { } assert.Equal(t, 3, len(castedValue)) assert.Equal(t, "GET", castedValue[0]) - assert.Equal(t, "GET", castedValue[1]) + assert.Equal(t, "1768826704", castedValue[1]) assert.Equal(t, tc.val, castedValue[2]) } } diff --git a/internal/worker/worker.go b/internal/worker/worker.go index 45963b275..7f2b24473 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -317,7 +317,7 @@ func (w *BaseWorker) gather(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, numCm case clientio.ResponseTypeRegular: return w.handleRegularResponse(ctx, diceDBCmd, evalResp) case clientio.ResponseTypePush: - return w.handlePushResponse(ctx, diceDBCmd, diceDBCmd.Cmd, evalResp) + return w.handlePushResponse(ctx, diceDBCmd.Cmd, fmt.Sprintf("%d", diceDBCmd.GetFingerprint()), evalResp) default: w.logger.Error("Unknown response type", slog.String("workerID", w.id), slog.Int("responseType", responseType)) err := w.ioHandler.Write(ctx, diceerrors.ErrInternalServer) @@ -372,16 +372,16 @@ func (w *BaseWorker) handleRegularResponse(ctx context.Context, diceDBCmd *cmd.D } // handlePushResponse handles the response for push commands, i.e., responses for which are pushed from the server to the client. -func (w *BaseWorker) handlePushResponse(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, pushResponseKey string, evalResp []eval.EvalResponse) error { +func (w *BaseWorker) handlePushResponse(ctx context.Context, cmdName string, pushResponseKey string, evalResp []eval.EvalResponse) error { if evalResp[0].Error != nil { - err := w.ioHandler.Write(ctx, clientio.CreatePushResponse(diceDBCmd.Cmd, pushResponseKey, evalResp[0].Error)) + err := w.ioHandler.Write(ctx, clientio.CreatePushResponse(cmdName, pushResponseKey, evalResp[0].Error)) if err != nil { w.logger.Debug("Error sending push response to client", slog.String("workerID", w.id), slog.Any("error", err)) } return err } - err := w.ioHandler.Write(ctx, clientio.CreatePushResponse(diceDBCmd.Cmd, pushResponseKey, evalResp[0].Result)) + err := w.ioHandler.Write(ctx, clientio.CreatePushResponse(cmdName, pushResponseKey, evalResp[0].Result)) if err != nil { w.logger.Debug("Error sending push response to client", slog.String("workerID", w.id), slog.Any("error", err)) return err From 22527eb9faff45ce36a9ef84e4538b429d692871 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh Date: Sun, 6 Oct 2024 23:59:25 +0530 Subject: [PATCH 11/11] linter --- internal/worker/worker.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/worker/worker.go b/internal/worker/worker.go index 7f2b24473..3d3e04be1 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -372,7 +372,7 @@ func (w *BaseWorker) handleRegularResponse(ctx context.Context, diceDBCmd *cmd.D } // handlePushResponse handles the response for push commands, i.e., responses for which are pushed from the server to the client. -func (w *BaseWorker) handlePushResponse(ctx context.Context, cmdName string, pushResponseKey string, evalResp []eval.EvalResponse) error { +func (w *BaseWorker) handlePushResponse(ctx context.Context, cmdName, pushResponseKey string, evalResp []eval.EvalResponse) error { if evalResp[0].Error != nil { err := w.ioHandler.Write(ctx, clientio.CreatePushResponse(cmdName, pushResponseKey, evalResp[0].Error)) if err != nil {