Skip to content

Commit

Permalink
Refactors worker logic (DiceDB#1257)
Browse files Browse the repository at this point in the history
  • Loading branch information
JyotinderSingh authored Nov 8, 2024
1 parent e32bf3c commit 9f8a69e
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 112 deletions.
2 changes: 1 addition & 1 deletion integration_tests/commands/resp/getunwatch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

const (
getUnwatchKey = "getunwatchkey"
fingerprint = "3557732805"
fingerprint = "426696421"
)

type getUnwatchTestCase struct {
Expand Down
250 changes: 139 additions & 111 deletions internal/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,29 +270,7 @@ func (w *BaseWorker) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCm

// Unsubscribe Unwatch command type
if meta.CmdType == Unwatch {
// extract the fingerprint
command := cmdList[len(cmdList)-1]
fp, fperr := strconv.ParseUint(command.Args[0], 10, 32)
if fperr != nil {
err := w.ioHandler.Write(ctx, diceerrors.ErrInvalidFingerprint)
if err != nil {
return fmt.Errorf("error sending push response to client: %v", err)
}
return fperr
}

// send the unsubscribe request
w.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{
Subscribe: false,
AdhocReqChan: w.adhocReqChan,
Fingerprint: uint32(fp),
}

err := w.ioHandler.Write(ctx, "OK")
if err != nil {
return fmt.Errorf("error sending push response to client: %v", err)
}
return nil
return w.handleCommandUnwatch(ctx, cmdList)
}

// Scatter the broken-down commands to the appropriate shards.
Expand All @@ -307,13 +285,46 @@ func (w *BaseWorker) executeCommand(ctx context.Context, diceDBCmd *cmd.DiceDBCm

if meta.CmdType == Watch {
// Proceed to subscribe after successful execution
w.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{
Subscribe: true,
WatchCmd: cmdList[len(cmdList)-1],
AdhocReqChan: w.adhocReqChan,
w.handleCommandWatch(cmdList)
}

return nil
}

// handleCommandWatch sends a watch subscription request to the watch manager.
func (w *BaseWorker) handleCommandWatch(cmdList []*cmd.DiceDBCmd) {
w.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{
Subscribe: true,
WatchCmd: cmdList[len(cmdList)-1],
AdhocReqChan: w.adhocReqChan,
}
}

// handleCommandUnwatch sends an unwatch subscription request to the watch manager. It also sends a response to the client.
// The response is sent before the unwatch request is processed by the watch manager.
func (w *BaseWorker) handleCommandUnwatch(ctx context.Context, cmdList []*cmd.DiceDBCmd) error {
// extract the fingerprint
command := cmdList[len(cmdList)-1]
fp, parseErr := strconv.ParseUint(command.Args[0], 10, 32)
if parseErr != nil {
err := w.ioHandler.Write(ctx, diceerrors.ErrInvalidFingerprint)
if err != nil {
return fmt.Errorf("error sending push response to client: %v", err)
}
return parseErr
}

// send the unsubscribe request
w.cmdWatchSubscriptionChan <- watchmanager.WatchSubscription{
Subscribe: false,
AdhocReqChan: w.adhocReqChan,
Fingerprint: uint32(fp),
}

err := w.ioHandler.Write(ctx, clientio.RespOK)
if err != nil {
return fmt.Errorf("error sending push response to client: %v", err)
}
return nil
}

Expand All @@ -327,131 +338,148 @@ func (w *BaseWorker) scatter(ctx context.Context, cmds []*cmd.DiceDBCmd) error {
return ctx.Err()
default:
for i := uint8(0); i < uint8(len(cmds)); i++ {
var rc chan *ops.StoreOp
var sid shard.ShardID
var key string
if len(cmds[i].Args) > 0 {
key = cmds[i].Args[0]
} else {
key = cmds[i].Cmd
}
shardID, responseChan := w.shardManager.GetShardInfo(getRoutingKeyFromCommand(cmds[i]))

sid, rc = w.shardManager.GetShardInfo(key)

rc <- &ops.StoreOp{
responseChan <- &ops.StoreOp{
SeqID: i,
RequestID: GenerateUniqueRequestID(),
Cmd: cmds[i],
WorkerID: w.id,
ShardID: sid,
ShardID: shardID,
Client: nil,
}
}
}

return nil
}

// getRoutingKeyFromCommand determines the key used for shard routing
func getRoutingKeyFromCommand(diceDBCmd *cmd.DiceDBCmd) string {
if len(diceDBCmd.Args) > 0 {
return diceDBCmd.Args[0]
}
return diceDBCmd.Cmd
}

// gather collects the responses from multiple shards and writes the results into the provided buffer.
// It first waits for responses from all the shards and then processes the result based on the command type (SingleShard, Custom, or Multishard).
func (w *BaseWorker) gather(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, numCmds int, isWatchNotification bool) error {
// Loop to wait for messages from number of shards
// Collect responses from all shards
storeOp, err := w.gatherResponses(ctx, numCmds)
if err != nil {
return err
}

if len(storeOp) == 0 {
slog.Error("No response from shards",
slog.String("workerID", w.id),
slog.String("command", diceDBCmd.Cmd))
return fmt.Errorf("no response from shards for command: %s", diceDBCmd.Cmd)
}

if isWatchNotification {
return w.handleWatchNotification(ctx, diceDBCmd, storeOp[0])
}

// Process command based on its type
cmdMeta, ok := CommandsMeta[diceDBCmd.Cmd]
if !ok {
return w.handleLegacyCommand(ctx, storeOp[0])
}

return w.handleCommand(ctx, cmdMeta, diceDBCmd, storeOp)
}

// gatherResponses collects responses from all shards
func (w *BaseWorker) gatherResponses(ctx context.Context, numCmds int) ([]ops.StoreResponse, error) {
var storeOp []ops.StoreResponse
for numCmds != 0 {

for numCmds > 0 {
select {
case <-ctx.Done():
slog.Error("Timed out waiting for response from shards", slog.String("workerID", w.id), slog.Any("error", ctx.Err()))
slog.Error("Timed out waiting for response from shards",
slog.String("workerID", w.id),
slog.Any("error", ctx.Err()))
return nil, ctx.Err()

case resp, ok := <-w.responseChan:
if ok {
storeOp = append(storeOp, *resp)
}
numCmds--
continue

case sError, ok := <-w.shardManager.ShardErrorChan:
if ok {
slog.Error("Error from shard", slog.String("workerID", w.id), slog.Any("error", sError))
slog.Error("Error from shard",
slog.String("workerID", w.id),
slog.Any("error", sError))
return nil, sError.Error
}
}
}

val, ok := CommandsMeta[diceDBCmd.Cmd]
return storeOp, nil
}

// handleWatchNotification processes watch notification responses
func (w *BaseWorker) handleWatchNotification(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, resp ops.StoreResponse) error {
fingerprint := fmt.Sprintf("%d", diceDBCmd.GetFingerprint())

if isWatchNotification {
if storeOp[0].EvalResponse.Error != nil {
err := w.ioHandler.Write(ctx, querymanager.GenericWatchResponse(diceDBCmd.Cmd, fmt.Sprintf("%d", diceDBCmd.GetFingerprint()), storeOp[0].EvalResponse.Error))
if err != nil {
slog.Debug("Error sending push response to client", slog.String("workerID", w.id), slog.Any("error", err))
}
return err
}
if resp.EvalResponse.Error != nil {
return w.writeResponse(ctx, querymanager.GenericWatchResponse(diceDBCmd.Cmd, fingerprint, resp.EvalResponse.Error))
}

err := w.ioHandler.Write(ctx, querymanager.GenericWatchResponse(diceDBCmd.Cmd, fmt.Sprintf("%d", diceDBCmd.GetFingerprint()), storeOp[0].EvalResponse.Result))
if err != nil {
slog.Debug("Error sending push response to client", slog.String("workerID", w.id), slog.Any("error", err))
return err
}
return nil // Exit after handling watch case
return w.writeResponse(ctx, querymanager.GenericWatchResponse(diceDBCmd.Cmd, fingerprint, resp.EvalResponse.Result))
}

// handleLegacyCommand processes commands not in CommandsMeta
func (w *BaseWorker) handleLegacyCommand(ctx context.Context, resp ops.StoreResponse) error {
if resp.EvalResponse.Error != nil {
return w.writeResponse(ctx, resp.EvalResponse.Error)
}
return w.writeResponse(ctx, resp.EvalResponse.Result)
}

// TODO: Remove it once we have migrated all the commands
if !ok {
// handleCommand processes commands based on their type
func (w *BaseWorker) handleCommand(ctx context.Context, cmdMeta CmdMeta, diceDBCmd *cmd.DiceDBCmd, storeOp []ops.StoreResponse) error {
var err error

switch cmdMeta.CmdType {
case SingleShard, Custom:
if storeOp[0].EvalResponse.Error != nil {
err := w.ioHandler.Write(ctx, storeOp[0].EvalResponse.Error)
if err != nil {
slog.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err))
}
return err
err = w.writeResponse(ctx, storeOp[0].EvalResponse.Error)
} else {
err = w.writeResponse(ctx, storeOp[0].EvalResponse.Result)
}

err := w.ioHandler.Write(ctx, storeOp[0].EvalResponse.Result)
if err != nil {
slog.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err))
return err
if err == nil && w.wl != nil {
w.wl.LogCommand(diceDBCmd)
}
} else {
switch val.CmdType {
case SingleShard, Custom:
// Handle single-shard or custom commands
if storeOp[0].EvalResponse.Error != nil {
err := w.ioHandler.Write(ctx, storeOp[0].EvalResponse.Error)
if err != nil {
slog.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err))
}
return err
}

err := w.ioHandler.Write(ctx, storeOp[0].EvalResponse.Result)
if err != nil {
slog.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err))
return err
}
case MultiShard:
err = w.writeResponse(ctx, cmdMeta.composeResponse(storeOp...))

if w.wl != nil {
w.wl.LogCommand(diceDBCmd)
}

case MultiShard:
err := w.ioHandler.Write(ctx, val.composeResponse(storeOp...))
if err != nil {
slog.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err))
return err
}

if w.wl != nil {
w.wl.LogCommand(diceDBCmd)
}

default:
slog.Error("Unknown command type", slog.String("workerID", w.id), slog.String("command", diceDBCmd.Cmd), slog.Any("evalResp", storeOp))
err := w.ioHandler.Write(ctx, diceerrors.ErrInternalServer)
if err != nil {
slog.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err))
return err
}
if err == nil && w.wl != nil {
w.wl.LogCommand(diceDBCmd)
}
default:
slog.Error("Unknown command type",
slog.String("workerID", w.id),
slog.String("command", diceDBCmd.Cmd),
slog.Any("evalResp", storeOp))
err = w.writeResponse(ctx, diceerrors.ErrInternalServer)
}
return err
}

return nil
// writeResponse handles writing responses and logging errors
func (w *BaseWorker) writeResponse(ctx context.Context, response interface{}) error {
err := w.ioHandler.Write(ctx, response)
if err != nil {
slog.Debug("Error sending response to client",
slog.String("workerID", w.id),
slog.Any("error", err))
}
return err
}

func (w *BaseWorker) isAuthenticated(diceDBCmd *cmd.DiceDBCmd) error {
Expand Down

0 comments on commit 9f8a69e

Please sign in to comment.