Skip to content

Commit

Permalink
Merge pull request #224 from distributedio/feature/zsets
Browse files Browse the repository at this point in the history
Feature/zsets
  • Loading branch information
nioshield authored Jul 19, 2021
2 parents 545791b + e6e8782 commit 7d7dc88
Show file tree
Hide file tree
Showing 6 changed files with 312 additions and 9 deletions.
5 changes: 3 additions & 2 deletions command/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,9 @@ func init() {
"zrevrangebyscore": Desc{Proc: AutoCommit(ZRevRangeByScore), Txn: ZRevRangeByScore, Cons: Constraint{-4, flags("rF"), 1, 1, 1}},
"zrem": Desc{Proc: AutoCommit(ZRem), Txn: ZRem, Cons: Constraint{-3, flags("wF"), 1, 1, 1}},
"zcard": Desc{Proc: AutoCommit(ZCard), Txn: ZCard, Cons: Constraint{2, flags("rF"), 1, 1, 1}},
//"zcount": Desc{Proc: AutoCommit(ZCount), Txn: ZCount, Cons: Constraint{4, flags("rF"), 1, 1, 1}},
"zscore": Desc{Proc: AutoCommit(ZScore), Txn: ZScore, Cons: Constraint{3, flags("rF"), 1, 1, 1}},
"zcount": Desc{Proc: AutoCommit(ZCount), Txn: ZCount, Cons: Constraint{-4, flags("rF"), 1, 1, 1}},
"zscore": Desc{Proc: AutoCommit(ZScore), Txn: ZScore, Cons: Constraint{3, flags("rF"), 1, 1, 1}},
"zscan": Desc{Proc: AutoCommit(ZScan), Txn: ZScan, Cons: Constraint{-3, flags("rF"), 1, 1, 1}},

// extension commands
"escan": Desc{Proc: AutoCommit(Escan), Txn: Escan, Cons: Constraint{-1, flags("rR"), 0, 0, 0}},
Expand Down
127 changes: 124 additions & 3 deletions command/zsets.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,18 @@ package command

import (
"errors"
"fmt"
"math"
"strconv"
"strings"

"github.com/distributedio/titan/db"
"github.com/distributedio/titan/encoding/resp"
)

// ZAdd adds the specified members with scores to the sorted set
func ZAdd(ctx *Context, txn *db.Transaction) (OnCommit, error) {
key := []byte(ctx.Args[0])

fmt.Println("zadd", ctx.Args)

kvs := ctx.Args[1:]
if len(kvs)%2 != 0 {
return nil, errors.New("ERR syntax error")
Expand Down Expand Up @@ -111,6 +109,41 @@ func ZRevRangeByScore(ctx *Context, txn *db.Transaction) (OnCommit, error) {
return zAnyOrderRangeByScore(ctx, txn, false)
}

func ZCount(ctx *Context, txn *db.Transaction) (OnCommit, error) {
key := []byte(ctx.Args[0])
startScore, startInclude, err := getFloatAndInclude(ctx.Args[1])
if err != nil {
return nil, ErrMinOrMaxNotFloat
}
endScore, endInclude, err := getFloatAndInclude(ctx.Args[2])
if err != nil {
return nil, ErrMinOrMaxNotFloat
}
zset, err := txn.ZSet(key)
if err != nil {
if err == db.ErrTypeMismatch {
return nil, ErrTypeMismatch
}
return nil, errors.New("ERR " + err.Error())
}
if !zset.Exist() {
return Integer(ctx.Out, 0), nil
}

items, err := zset.ZAnyOrderRangeByScore(startScore, startInclude,
endScore, endInclude,
false,
int64(0), math.MaxInt64,
true)
if err != nil {
return nil, errors.New("ERR " + err.Error())
}
if len(items) == 0 {
return Integer(ctx.Out, 0), nil
}
return Integer(ctx.Out, int64(len(items))), nil
}

func zAnyOrderRangeByScore(ctx *Context, txn *db.Transaction, positiveOrder bool) (OnCommit, error) {
key := []byte(ctx.Args[0])

Expand Down Expand Up @@ -240,3 +273,91 @@ func ZScore(ctx *Context, txn *db.Transaction) (OnCommit, error) {

return BulkString(ctx.Out, string(score)), nil
}

func ZScan(ctx *Context, txn *db.Transaction) (OnCommit, error) {
var (
key []byte
cursor []byte
lastCursor = []byte("0")
count = uint64(defaultScanCount)
kvs = [][]byte{}
pattern []byte
isAll bool
err error
)
key = []byte(ctx.Args[0])
if strings.Compare(ctx.Args[1], "0") != 0 {
cursor = []byte(ctx.Args[1])
}

// define return result
result := func() {
if _, err := resp.ReplyArray(ctx.Out, 2); err != nil {
return
}
resp.ReplyBulkString(ctx.Out, string(lastCursor))
if _, err := resp.ReplyArray(ctx.Out, len(kvs)); err != nil {
return
}
for i := range kvs {
resp.ReplyBulkString(ctx.Out, string(kvs[i]))
}
}
zset, err := txn.ZSet(key)
if err != nil {
if err == db.ErrTypeMismatch {
return nil, ErrTypeMismatch
}
return nil, errors.New("ERR " + err.Error())
}

if !zset.Exist() {
return result, nil
}

if len(ctx.Args)%2 != 0 {
return nil, ErrSyntax
}

for i := 2; i < len(ctx.Args); i += 2 {
arg := strings.ToLower(ctx.Args[i])
next := ctx.Args[i+1]
switch arg {
case "count":
if count, err = strconv.ParseUint(next, 10, 64); err != nil {
return nil, ErrInteger
}
if count > ScanMaxCount {
count = ScanMaxCount
}
if count == 0 {
count = uint64(defaultScanCount)
}
case "match":
pattern = []byte(next)
isAll = (pattern[0] == '*' && len(pattern) == 1)
}
}

if len(pattern) == 0 {
isAll = true
}
f := func(member, score []byte) bool {
if count <= 0 {
lastCursor = member
return false
}
if isAll || globMatch(pattern, member, false) {
kvs = append(kvs, member)
kvs = append(kvs, score)
count--
}
return true
}

if err := zset.ZScan(cursor, f); err != nil {
return nil, errors.New("ERR " + err.Error())
}
return result, nil

}
55 changes: 51 additions & 4 deletions db/zset.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,21 +418,68 @@ func (zset *ZSet) ZCard() int64 {
}

func (zset *ZSet) ZScore(member []byte) ([]byte, error) {
dkey := DataKey(zset.txn.db, zset.meta.ID)
memberKey := zsetMemberKey(dkey, member)
bytesScore, err := zset.txn.t.Get(zset.txn.ctx, memberKey)
bScore, err := zset.zScoreBytes(member)
if err != nil {
if IsErrNotFound(err) {
return nil, nil
}
return nil, err
}

fscore := DecodeFloat64(bytesScore)
fscore := DecodeFloat64(bScore)
sscore := strconv.FormatFloat(fscore, 'f', -1, 64)
return []byte(sscore), nil
}

func (zset *ZSet) zScoreBytes(member []byte) ([]byte, error) {
dkey := DataKey(zset.txn.db, zset.meta.ID)
memberKey := zsetMemberKey(dkey, member)
bScore, err := zset.txn.t.Get(zset.txn.ctx, memberKey)
if err != nil {
return nil, err
}
return bScore, nil
}

func (zset *ZSet) ZScan(cursor []byte, f func(key, val []byte) bool) error {
if !zset.Exist() {
return nil
}
dkey := DataKey(zset.txn.db, zset.meta.ID)
prefix := ZSetScorePrefix(dkey)
endPrefix := kv.Key(prefix).PrefixNext()
ikey := prefix
if len(cursor) > 0 {
bScore, err := zset.zScoreBytes(cursor)
if err != nil {
if IsErrNotFound(err) {
return nil
}
return err
}
if len(bScore) > 0 {
ikey = append(ikey, bScore...)
}
}
iter, err := zset.txn.t.Iter(ikey, endPrefix)
if err != nil {
return err
}
for iter.Valid() && iter.Key().HasPrefix(prefix) {
scoreAndMember := iter.Key()[len(prefix):]
member := scoreAndMember[byteScoreLen+len(":"):]
byteScore := scoreAndMember[0:byteScoreLen]
score := []byte(strconv.FormatFloat(DecodeFloat64(byteScore), 'f', -1, 64))
if !f(member, score) {
break
}
if err := iter.Next(); err != nil {
return err
}
}
return nil
}

func zsetMemberKey(dkey []byte, member []byte) []byte {
var memberKey []byte
memberKey = append(memberKey, dkey...)
Expand Down
64 changes: 64 additions & 0 deletions db/zset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,3 +464,67 @@ func TestZSetZAnyOrderRangeScore(t *testing.T) {
})
}
}

func TestZSet_ZScan(t *testing.T) {
var members [][]byte
var score []float64

members = append(members, []byte("abc"))
members = append(members, []byte("aec"))
members = append(members, []byte("acc"))
members = append(members, []byte("bc"))
score = append(score, -1.1, -1, 1, 2.1)

zset, txn, err := getZSet(t, []byte("TestZSet_ZScan"))
assert.NoError(t, err)
assert.NotNil(t, txn)
assert.NotNil(t, zset)
count, err := zset.ZAdd(members, score)
assert.NoError(t, err)
assert.Equal(t, count, int64(len(members)))
txn.Commit(context.TODO())

type args struct {
cursor []byte
f func(key, val []byte) bool
}
var value [][]byte
count = 2

tests := []struct {
name string
args args
want [][]byte
}{
{
name: "TestZSet_ZScan",
args: args{
f: func(member, score []byte) bool {
if count == 0 {
return false
}
value = append(value, member, score)
count--
return true

},
},
want: append([][]byte{}, []byte("abc"), []byte("-1.1"), []byte("aec"), []byte("-1")),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

zset, txn, err := getZSet(t, []byte("TestZSet_ZScan"))
assert.NoError(t, err)
assert.NotNil(t, txn)
assert.NotNil(t, zset)

err = zset.ZScan(tt.args.cursor, tt.args.f)
txn.Commit(context.TODO())

assert.Equal(t, value, tt.want)
assert.NoError(t, err)
})
}
}
30 changes: 30 additions & 0 deletions tools/autotest/auto.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,36 @@ func (ac *AutoClient) ZSetCase(t *testing.T) {
ac.ez.ZRangeByScoreEqual(t, "key-zset", "(2", "3.6", true, "", "member6 2.05 member3 3.6")
ac.ez.ZRangeByScoreEqual(t, "key-zset", "0", "(2", true, "", "member5 0 member2 1.5")

ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "+inf", "-inf", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5 member5 0 member4 -3.5")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "(+inf", "-inf", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5 member5 0 member4 -3.5")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "+inf", "(-inf", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5 member5 0 member4 -3.5")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "inf", "-inf", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5 member5 0 member4 -3.5")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "inf", "-inf", false, "", "member3 member6 member11 member1 member2 member5 member4")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "inf", "-3.5", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5 member5 0 member4 -3.5")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "inf", "(-3.5", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5 member5 0")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "inf", "0.0", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5 member5 0")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "inf", "(0.0", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "3.6", "(0.0", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "+3.6", "(0.0", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "(3.6", "(0.0", true, "", "member6 2.05 member11 2 member1 2 member2 1.5")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "", "member6 2.05 member11 2 member1 2 member2 1.5")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "LIMIT -1 1", "")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "limit 0 -1", "member6 2.05 member11 2 member1 2 member2 1.5")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "LIMIT 0 0", "")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "LIMIT 0 2", "member6 2.05 member11 2")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "LIMIT 0 4", "member6 2.05 member11 2 member1 2 member2 1.5")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "LIMIT 0 5", "member6 2.05 member11 2 member1 2 member2 1.5")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "LIMIT 1 2", "member11 2 member1 2")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "LIMIT 3 2", "member2 1.5")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "LIMIT 4 2", "")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "3.6", "(2", true, "", "member3 3.6 member6 2.05")
ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "(2", "0", true, "", "member2 1.5 member5 0")

ac.ez.ZCountEqual(t, "key-zset", "0", "(2", int64(2))

ac.ez.ZScanEqual(t, "key-zset", "0", "*", 2, "member2 member4 -3.5 member5 0")
ac.ez.ZScanEqual(t, "key-zset", "member2", "member*", 2, "member11 member2 1.5 member1 2")

ac.ez.ZRemEqual(t, "key-zset", "member2", "member1", "member3", "member4", "member1")
ac.ez.ZRangeEqual(t, "key-zset", 0, -1, true)

Expand Down
40 changes: 40 additions & 0 deletions tools/autotest/cmd/zset.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"fmt"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -174,6 +175,45 @@ func (ez *ExampleZSet) ZRevRangeEqualErr(t *testing.T, errValue string, args ...
assert.EqualError(t, err, errValue)
}

func (ez *ExampleZSet) ZScanEqual(t *testing.T, key string, cursor string, pattern string, count int, expected string) {
cmd := "zscan"
req := make([]interface{}, 0)
req = append(req, key)
req = append(req, cursor)
req = append(req, "match", pattern)
req = append(req, "count", count)

reply, err := redis.MultiBulk(ez.conn.Do(cmd, req...))
lastCursor, _ := redis.String(reply[0], err)
strs, _ := redis.Strings(reply[1], err)
fmt.Println(lastCursor, strs)
if expected != "" {
expectedStrs := strings.Split(expected, " ")
assert.Equal(t, expectedStrs[0], lastCursor)
assert.Equal(t, expectedStrs[1:], strs)
} else {
assert.Equal(t, "0", lastCursor)
}
assert.Nil(t, err)
}

func (ez *ExampleZSet) ZCountEqual(t *testing.T, key string, start string, stop string, expected int64) {
cmd := "zcount"
req := make([]interface{}, 0)
req = append(req, key)
req = append(req, start)
req = append(req, stop)

reply, err := redis.Int64(ez.conn.Do(cmd, req...))
assert.Equal(t, expected, reply)
assert.Nil(t, err)
}

func (ez *ExampleZSet) ZCountEqualErr(t *testing.T, errValue string, args ...interface{}) {
_, err := ez.conn.Do("zcount", args...)
assert.EqualError(t, err, errValue)
}

func (ez *ExampleZSet) ZRangeByScoreEqual(t *testing.T, key string, start string, stop string, withScores bool, limit string, expected string) {
ez.ZAnyOrderRangeByScoreEqual(t, key, start, stop, withScores, true, limit, expected)
}
Expand Down

0 comments on commit 7d7dc88

Please sign in to comment.