From 9f8a69ed6f72cb3358bc4bef9738005b7f586d60 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh Date: Sat, 9 Nov 2024 03:26:26 +0530 Subject: [PATCH] Refactors worker logic (#1257) --- .../commands/resp/getunwatch_test.go | 2 +- internal/worker/worker.go | 250 ++++++++++-------- 2 files changed, 140 insertions(+), 112 deletions(-) diff --git a/integration_tests/commands/resp/getunwatch_test.go b/integration_tests/commands/resp/getunwatch_test.go index 542260974..790bc6fdb 100644 --- a/integration_tests/commands/resp/getunwatch_test.go +++ b/integration_tests/commands/resp/getunwatch_test.go @@ -14,7 +14,7 @@ import ( const ( getUnwatchKey = "getunwatchkey" - fingerprint = "3557732805" + fingerprint = "426696421" ) type getUnwatchTestCase struct { diff --git a/internal/worker/worker.go b/internal/worker/worker.go index a8f03a999..8a33c28e5 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -270,29 +270,7 @@ func (w *BaseWorker) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCm // Unsubscribe Unwatch command type if meta.CmdType == Unwatch { - // extract the fingerprint - command := cmdList[len(cmdList)-1] - fp, fperr := strconv.ParseUint(command.Args[0], 10, 32) - if fperr != nil { - err := w.ioHandler.Write(ctx, diceerrors.ErrInvalidFingerprint) - if err != nil { - return fmt.Errorf("error sending push response to client: %v", err) - } - return fperr - } - - // send the unsubscribe request - w.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{ - Subscribe: false, - AdhocReqChan: w.adhocReqChan, - Fingerprint: uint32(fp), - } - - err := w.ioHandler.Write(ctx, "OK") - if err != nil { - return fmt.Errorf("error sending push response to client: %v", err) - } - return nil + return w.handleCommandUnwatch(ctx, cmdList) } // Scatter the broken-down commands to the appropriate shards. @@ -307,13 +285,46 @@ func (w *BaseWorker) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCm if meta.CmdType == Watch { // Proceed to subscribe after successful execution - w.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{ - Subscribe: true, - WatchCmd: cmdList[len(cmdList)-1], - AdhocReqChan: w.adhocReqChan, + w.handleCommandWatch(cmdList) + } + + return nil +} + +// handleCommandWatch sends a watch subscription request to the watch manager. +func (w *BaseWorker) handleCommandWatch(cmdList []*cmd.DiceDBCmd) { + w.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{ + Subscribe: true, + WatchCmd: cmdList[len(cmdList)-1], + AdhocReqChan: w.adhocReqChan, + } +} + +// handleCommandUnwatch sends an unwatch subscription request to the watch manager. It also sends a response to the client. +// The response is sent before the unwatch request is processed by the watch manager. +func (w *BaseWorker) handleCommandUnwatch(ctx context.Context, cmdList []*cmd.DiceDBCmd) error { + // extract the fingerprint + command := cmdList[len(cmdList)-1] + fp, parseErr := strconv.ParseUint(command.Args[0], 10, 32) + if parseErr != nil { + err := w.ioHandler.Write(ctx, diceerrors.ErrInvalidFingerprint) + if err != nil { + return fmt.Errorf("error sending push response to client: %v", err) } + return parseErr + } + + // send the unsubscribe request + w.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{ + Subscribe: false, + AdhocReqChan: w.adhocReqChan, + Fingerprint: uint32(fp), } + err := w.ioHandler.Write(ctx, clientio.RespOK) + if err != nil { + return fmt.Errorf("error sending push response to client: %v", err) + } return nil } @@ -327,131 +338,148 @@ func (w *BaseWorker) scatter(ctx context.Context, cmds []*cmd.DiceDBCmd) error { return ctx.Err() default: for i := uint8(0); i < uint8(len(cmds)); i++ { - var rc chan *ops.StoreOp - var sid shard.ShardID - var key string - if len(cmds[i].Args) > 0 { - key = cmds[i].Args[0] - } else { - key = cmds[i].Cmd - } + shardID, responseChan := w.shardManager.GetShardInfo(getRoutingKeyFromCommand(cmds[i])) - sid, rc = w.shardManager.GetShardInfo(key) - - rc <- &ops.StoreOp{ + responseChan <- &ops.StoreOp{ SeqID: i, RequestID: GenerateUniqueRequestID(), Cmd: cmds[i], WorkerID: w.id, - ShardID: sid, + ShardID: shardID, Client: nil, } } } - return nil } +// getRoutingKeyFromCommand determines the key used for shard routing +func getRoutingKeyFromCommand(diceDBCmd *cmd.DiceDBCmd) string { + if len(diceDBCmd.Args) > 0 { + return diceDBCmd.Args[0] + } + return diceDBCmd.Cmd +} + // gather collects the responses from multiple shards and writes the results into the provided buffer. // It first waits for responses from all the shards and then processes the result based on the command type (SingleShard, Custom, or Multishard). func (w *BaseWorker) gather(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, numCmds int, isWatchNotification bool) error { - // Loop to wait for messages from number of shards + // Collect responses from all shards + storeOp, err := w.gatherResponses(ctx, numCmds) + if err != nil { + return err + } + + if len(storeOp) == 0 { + slog.Error("No response from shards", + slog.String("workerID", w.id), + slog.String("command", diceDBCmd.Cmd)) + return fmt.Errorf("no response from shards for command: %s", diceDBCmd.Cmd) + } + + if isWatchNotification { + return w.handleWatchNotification(ctx, diceDBCmd, storeOp[0]) + } + + // Process command based on its type + cmdMeta, ok := CommandsMeta[diceDBCmd.Cmd] + if !ok { + return w.handleLegacyCommand(ctx, storeOp[0]) + } + + return w.handleCommand(ctx, cmdMeta, diceDBCmd, storeOp) +} + +// gatherResponses collects responses from all shards +func (w *BaseWorker) gatherResponses(ctx context.Context, numCmds int) ([]ops.StoreResponse, error) { var storeOp []ops.StoreResponse - for numCmds != 0 { + + for numCmds > 0 { select { case <-ctx.Done(): - slog.Error("Timed out waiting for response from shards", slog.String("workerID", w.id), slog.Any("error", ctx.Err())) + slog.Error("Timed out waiting for response from shards", + slog.String("workerID", w.id), + slog.Any("error", ctx.Err())) + return nil, ctx.Err() + case resp, ok := <-w.responseChan: if ok { storeOp = append(storeOp, *resp) } numCmds-- - continue + case sError, ok := <-w.shardManager.ShardErrorChan: if ok { - slog.Error("Error from shard", slog.String("workerID", w.id), slog.Any("error", sError)) + slog.Error("Error from shard", + slog.String("workerID", w.id), + slog.Any("error", sError)) + return nil, sError.Error } } } - val, ok := CommandsMeta[diceDBCmd.Cmd] + return storeOp, nil +} + +// handleWatchNotification processes watch notification responses +func (w *BaseWorker) handleWatchNotification(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, resp ops.StoreResponse) error { + fingerprint := fmt.Sprintf("%d", diceDBCmd.GetFingerprint()) - if isWatchNotification { - if storeOp[0].EvalResponse.Error != nil { - err := w.ioHandler.Write(ctx, querymanager.GenericWatchResponse(diceDBCmd.Cmd, fmt.Sprintf("%d", diceDBCmd.GetFingerprint()), storeOp[0].EvalResponse.Error)) - if err != nil { - slog.Debug("Error sending push response to client", slog.String("workerID", w.id), slog.Any("error", err)) - } - return err - } + if resp.EvalResponse.Error != nil { + return w.writeResponse(ctx, querymanager.GenericWatchResponse(diceDBCmd.Cmd, fingerprint, resp.EvalResponse.Error)) + } - err := w.ioHandler.Write(ctx, querymanager.GenericWatchResponse(diceDBCmd.Cmd, fmt.Sprintf("%d", diceDBCmd.GetFingerprint()), storeOp[0].EvalResponse.Result)) - if err != nil { - slog.Debug("Error sending push response to client", slog.String("workerID", w.id), slog.Any("error", err)) - return err - } - return nil // Exit after handling watch case + return w.writeResponse(ctx, querymanager.GenericWatchResponse(diceDBCmd.Cmd, fingerprint, resp.EvalResponse.Result)) +} + +// handleLegacyCommand processes commands not in CommandsMeta +func (w *BaseWorker) handleLegacyCommand(ctx context.Context, resp ops.StoreResponse) error { + if resp.EvalResponse.Error != nil { + return w.writeResponse(ctx, resp.EvalResponse.Error) } + return w.writeResponse(ctx, resp.EvalResponse.Result) +} - // TODO: Remove it once we have migrated all the commands - if !ok { +// handleCommand processes commands based on their type +func (w *BaseWorker) handleCommand(ctx context.Context, cmdMeta CmdMeta, diceDBCmd *cmd.DiceDBCmd, storeOp []ops.StoreResponse) error { + var err error + + switch cmdMeta.CmdType { + case SingleShard, Custom: if storeOp[0].EvalResponse.Error != nil { - err := w.ioHandler.Write(ctx, storeOp[0].EvalResponse.Error) - if err != nil { - slog.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) - } - return err + err = w.writeResponse(ctx, storeOp[0].EvalResponse.Error) + } else { + err = w.writeResponse(ctx, storeOp[0].EvalResponse.Result) } - err := w.ioHandler.Write(ctx, storeOp[0].EvalResponse.Result) - if err != nil { - slog.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) - return err + if err == nil && w.wl != nil { + w.wl.LogCommand(diceDBCmd) } - } else { - switch val.CmdType { - case SingleShard, Custom: - // Handle single-shard or custom commands - if storeOp[0].EvalResponse.Error != nil { - err := w.ioHandler.Write(ctx, storeOp[0].EvalResponse.Error) - if err != nil { - slog.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) - } - return err - } - - err := w.ioHandler.Write(ctx, storeOp[0].EvalResponse.Result) - if err != nil { - slog.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) - return err - } + case MultiShard: + err = w.writeResponse(ctx, cmdMeta.composeResponse(storeOp...)) - if w.wl != nil { - w.wl.LogCommand(diceDBCmd) - } - - case MultiShard: - err := w.ioHandler.Write(ctx, val.composeResponse(storeOp...)) - if err != nil { - slog.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) - return err - } - - if w.wl != nil { - w.wl.LogCommand(diceDBCmd) - } - - default: - slog.Error("Unknown command type", slog.String("workerID", w.id), slog.String("command", diceDBCmd.Cmd), slog.Any("evalResp", storeOp)) - err := w.ioHandler.Write(ctx, diceerrors.ErrInternalServer) - if err != nil { - slog.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) - return err - } + if err == nil && w.wl != nil { + w.wl.LogCommand(diceDBCmd) } + default: + slog.Error("Unknown command type", + slog.String("workerID", w.id), + slog.String("command", diceDBCmd.Cmd), + slog.Any("evalResp", storeOp)) + err = w.writeResponse(ctx, diceerrors.ErrInternalServer) } + return err +} - return nil +// writeResponse handles writing responses and logging errors +func (w *BaseWorker) writeResponse(ctx context.Context, response interface{}) error { + err := w.ioHandler.Write(ctx, response) + if err != nil { + slog.Debug("Error sending response to client", + slog.String("workerID", w.id), + slog.Any("error", err)) + } + return err } func (w *BaseWorker) isAuthenticated(diceDBCmd *cmd.DiceDBCmd) error {