Skip to content

Commit

Permalink
feat: support transaction in cmd
Browse files Browse the repository at this point in the history
  • Loading branch information
zjregee committed Sep 24, 2024
1 parent 15f53d9 commit 6c9f814
Show file tree
Hide file tree
Showing 9 changed files with 476 additions and 815 deletions.
33 changes: 22 additions & 11 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,29 @@ const (
reset = "\033[0m"
)

var commands = map[string]func(*session, ...string) (string, error){
var commands = map[string]func(*session, ...string) string{
"GET": Get,
"SET": Modify,
"APPEND": Modify,
"DEL": Modify,
"MULTI": Modify,
"EXEC": Modify,
"DELETE": Modify,
"MULTI": Multi,
"EXEC": Exec,
}

func Get(session *session, args ...string) string {
return session.TxnGet(args...)
}

func Modify(session *session, args ...string) string {
return session.TxnModify(args...)
}

func Multi(session *session, args ...string) string {
return session.TxnMulti()
}

func Exec(session *session, args ...string) string {
return session.TxnExec()
}

func repl(session *session) {
Expand Down Expand Up @@ -77,13 +93,8 @@ func repl(session *session) {
fmt.Println("unknown command")
continue
}
result, err := handler(session, args...)
if err != nil {
success = false
fmt.Println(err.Error())
} else {
fmt.Println(result)
}
result := handler(session, args...)
fmt.Println(result)
}
}

Expand Down
63 changes: 63 additions & 0 deletions cmd/cmd_buffer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package main

import (
"hash/fnv"

pb "github.com/zjregee/shardkv/proto"
)

type buffers map[uint64]*pb.Mutation

func newBuffers() *buffers {
bs := make(buffers)
return &bs
}

func (bs *buffers) addMutation(m *pb.Mutation) {
key := encodeKey(m.Key)
if _, ok := (*bs)[key]; !ok {
(*bs)[encodeKey(m.Key)] = m
return
}
switch (*bs)[key].Kind {
case pb.MutationKind_Put:
switch m.Kind {
case pb.MutationKind_Put, pb.MutationKind_Delete:
(*bs)[key] = m
case pb.MutationKind_Append:
(*bs)[key].Value = append((*bs)[key].Value, m.Value...)
}
case pb.MutationKind_Append:
switch m.Kind {
case pb.MutationKind_Put, pb.MutationKind_Delete:
(*bs)[key] = m
case pb.MutationKind_Append:
(*bs)[key].Value = append((*bs)[key].Value, m.Value...)
}
case pb.MutationKind_Delete:
switch m.Kind {
case pb.MutationKind_Put, pb.MutationKind_Delete:
(*bs)[key] = m
case pb.MutationKind_Append:
newM := &pb.Mutation{}
newM.Kind = pb.MutationKind_Put
newM.Key = m.Key
newM.Value = m.Value
(*bs)[key] = newM
}
}
}

func (bs *buffers) mutations() []*pb.Mutation {
var ms []*pb.Mutation
for _, m := range *bs {
ms = append(ms, m)
}
return ms
}

func encodeKey(key []byte) uint64 {
h := fnv.New64a()
h.Write(key)
return h.Sum64()
}
181 changes: 181 additions & 0 deletions cmd/cmd_session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
package main

import (
"context"
"fmt"
"strings"
"time"

"github.com/zjregee/shardkv/common/utils"
pb "github.com/zjregee/shardkv/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)

type session struct {
isInMultiTxn bool
startTs uint64
buffers *buffers
peers []*grpc.ClientConn
leaderIndex int32
}

func newSession(peers []string, leaderIndex int32) (*session, error) {
session := &session{}
session.isInMultiTxn = false
session.peers = make([]*grpc.ClientConn, 0)
session.leaderIndex = leaderIndex
for _, peer := range peers {
conn, err := grpc.NewClient(peer, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
session.peers = append(session.peers, conn)
}
return session, nil
}

func (session *session) Close() {
for _, conn := range session.peers {
_ = conn.Close()
}
}

func (session *session) TxnGet(args ...string) string {
if len(args) != 2 {
return "usage: GET <key>"
}
getArgs := &pb.TxnGetArgs{}
getArgs.Id = utils.Nrand()
getArgs.Key = []byte(args[0])
if session.isInMultiTxn {
getArgs.StartTs = session.startTs
} else {
getArgs.StartTs = uint64(time.Now().UnixMilli())
}
for {
if session.leaderIndex == -1 {
session.leaderIndex = int32(utils.Nrand()) % int32(len(session.peers))
}
client := pb.NewKvServiceClient(session.peers[session.leaderIndex])
ctx, cancel := context.WithTimeout(context.Background(), RPC_TIMEOUT)
reply, err := client.HandleTxnGet(ctx, getArgs)
cancel()
if err != nil {
session.leaderIndex = -1
continue
}
switch reply.Err {
case pb.Err_ErrClosed, pb.Err_ErrWrongLeader:
session.leaderIndex = -1
case pb.Err_OK:
return string(reply.Value)
case pb.Err_ErrNoKey:
return ""
}
}
}

func (session *session) TxnModify(args ...string) string {
kind := strings.ToUpper(args[0])
if kind == "SET" || kind == "APPEND" {
if len(args) != 3 {
return fmt.Sprintf("usage: %s <key> <value>", kind)
}
} else {
if len(args) != 2 {
return fmt.Sprintf("usage: %s <key> <value>", kind)
}
}
m := &pb.Mutation{}
m.Key = []byte(args[1])
switch kind {
case "SET":
m.Kind = pb.MutationKind_Put
m.Value = []byte(args[2])
case "APPEND":
m.Kind = pb.MutationKind_Append
m.Value = []byte(args[2])
case "DEL":
m.Kind = pb.MutationKind_Delete
}
if session.isInMultiTxn {
session.buffers.addMutation(m)
return ""
}
return session.commitTxn([]*pb.Mutation{m})
}

func (session *session) TxnMulti() string {
if session.isInMultiTxn {
return ""
}
session.isInMultiTxn = true
session.startTs = uint64(time.Now().UnixMilli())
session.buffers = newBuffers()
return ""
}

func (session *session) TxnExec() string {
if !session.isInMultiTxn {
return ""
}
session.isInMultiTxn = false
return session.commitTxn(session.buffers.mutations())
}

func (session *session) commitTxn(ms []*pb.Mutation) string {
prewriteArgs := &pb.TxnPrewriteArgs{}
prewriteArgs.Id = utils.Nrand()
prewriteArgs.StartTs = session.startTs
prewriteArgs.Mutations = ms
success := false
for !success {
if session.leaderIndex == -1 {
session.leaderIndex = int32(utils.Nrand()) % int32(len(session.peers))
}
client := pb.NewKvServiceClient(session.peers[session.leaderIndex])
ctx, cancel := context.WithTimeout(context.Background(), RPC_TIMEOUT)
reply, err := client.HandleTxnPrewrite(ctx, prewriteArgs)
cancel()
if err != nil {
session.leaderIndex = -1
continue
}
switch reply.Err {
case pb.Err_ErrClosed, pb.Err_ErrWrongLeader:
session.leaderIndex = -1
case pb.Err_OK, pb.Err_Duplicate:
success = true
case pb.Err_ErrConflict:
return "can't commit due to transaction conflict"
}
}
commitArgs := &pb.TxnCommitArgs{}
commitArgs.Id = utils.Nrand()
commitArgs.StartTs = session.startTs
commitArgs.CommitTs = uint64(time.Now().UnixMilli())
commitArgs.Keys = make([][]byte, 0)
for _, m := range ms {
commitArgs.Keys = append(commitArgs.Keys, m.Key)
}
for {
if session.leaderIndex == -1 {
session.leaderIndex = int32(utils.Nrand()) % int32(len(session.peers))
}
client := pb.NewKvServiceClient(session.peers[session.leaderIndex])
ctx, cancel := context.WithTimeout(context.Background(), RPC_TIMEOUT)
reply, err := client.HandleTxnCommit(ctx, commitArgs)
cancel()
if err != nil {
session.leaderIndex = -1
continue
}
switch reply.Err {
case pb.Err_ErrClosed, pb.Err_ErrWrongLeader:
session.leaderIndex = -1
case pb.Err_OK, pb.Err_Duplicate:
return ""
}
}
}
103 changes: 0 additions & 103 deletions cmd/cmd_utils.go

This file was deleted.

Loading

0 comments on commit 6c9f814

Please sign in to comment.