diff --git a/cmd_hll.go b/cmd_hll.go new file mode 100644 index 00000000..bd2f90c8 --- /dev/null +++ b/cmd_hll.go @@ -0,0 +1,95 @@ +package miniredis + +import "github.com/alicebob/miniredis/v2/server" + +// commandsHll handles all hll related operations. +func commandsHll(m *Miniredis) { + m.srv.Register("PFADD", m.cmdPfadd) + m.srv.Register("PFCOUNT", m.cmdPfcount) + m.srv.Register("PFMERGE", m.cmdPfmerge) +} + +// PFADD +func (m *Miniredis) cmdPfadd(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + if m.checkPubsub(c, cmd) { + return + } + + key, items := args[0], args[1:] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if db.exists(key) && db.t(key) != "hll" { + c.WriteError(ErrNotValidHllValue.Error()) + return + } + + altered := db.hllAdd(key, items...) + c.WriteInt(altered) + }) +} + +// PFCOUNT +func (m *Miniredis) cmdPfcount(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + if m.checkPubsub(c, cmd) { + return + } + + keys := args[:] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + count, err := db.hllCount(keys) + if err != nil { + c.WriteError(err.Error()) + return + } + + c.WriteInt(count) + }) +} + +// PFMERGE +func (m *Miniredis) cmdPfmerge(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + if m.checkPubsub(c, cmd) { + return + } + + keys := args + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if err := db.hllMerge(keys); err != nil { + c.WriteError(err.Error()) + return + } + c.WriteOK() + }) +} diff --git a/cmd_hll_test.go b/cmd_hll_test.go new file mode 100644 index 00000000..25d62dc9 --- /dev/null +++ b/cmd_hll_test.go @@ -0,0 +1,238 @@ +package miniredis + +import ( + "testing" + + "github.com/alicebob/miniredis/v2/proto" +) + +// Test PFADD +func TestPfadd(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := proto.Dial(s.Addr()) + ok(t, err) + defer c.Close() + + mustDo(t, c, + "PFADD", "h", "aap", "noot", "mies", + proto.Int(1), + ) + + mustDo(t, c, + "PFADD", "h", "aap", // already exists in hll => returns 0 + proto.Int(0), + ) + + mustDo(t, c, + "TYPE", "h", + proto.Inline("hll"), + ) + + t.Run("direct usage", func(t *testing.T) { + added, err := s.SetAdd("s1", "aap") + ok(t, err) + equals(t, 1, added) + + members, err := s.Members("s1") + ok(t, err) + equals(t, []string{"aap"}, members) + }) + + t.Run("errors", func(t *testing.T) { + mustOK(t, c, "SET", "str", "value") + mustDo(t, c, + "PFADD", "str", "hi", + proto.Error(msgNotValidHllValue), + ) + // Wrong argument counts + mustDo(t, c, + "PFADD", + proto.Error(errWrongNumber("pfadd")), + ) + }) +} + +// Test PFCOUNT +func TestPfcount(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := proto.Dial(s.Addr()) + ok(t, err) + defer c.Close() + + // Add 100 unique random values + for i := 0; i < 100; i++ { + mustDo(t, c, + "PFADD", "h1", randomStr(10), + proto.Int(1), // hll changes each time + ) + } + + // Add 1 more unique value + specificValue := randomStr(10) + mustDo(t, c, + "PFADD", "h1", specificValue, + proto.Int(1), // hll changes because of new element + ) + for i := 0; i < 50; i++ { + mustDo(t, c, + 
"PFADD", "h1", specificValue, + proto.Int(0), // hll doesn't change because this element has already been added before + ) + } + + mustDo(t, c, + "PFCOUNT", "h1", + proto.Int(101), + ) + + // Create a new hll + mustDo(t, c, + "PFADD", "h2", randomStr(10), randomStr(10), randomStr(10), + proto.Int(1), + ) + + mustDo(t, c, + "PFCOUNT", "h2", + proto.Int(3), + ) + + // Several hlls are involved - a sum of all the counts is returned + mustDo(t, c, + "PFCOUNT", + "h1", // has 101 unique values + "h2", // has 3 unique values + "h3", // empty key + proto.Int(104), + ) + + // A nonexisting key + mustDo(t, c, + "PFCOUNT", "h9", + proto.Int(0), + ) + + t.Run("errors", func(t *testing.T) { + s.Set("str", "value") + + mustDo(t, c, + "PFCOUNT", + proto.Error(errWrongNumber("pfcount")), + ) + mustDo(t, c, + "PFCOUNT", "str", + proto.Error(msgNotValidHllValue), + ) + mustDo(t, c, + "PFCOUNT", "h1", "str", + proto.Error(msgNotValidHllValue), + ) + }) +} + +// Test PFMERGE +func TestPfmerge(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := proto.Dial(s.Addr()) + ok(t, err) + defer c.Close() + + // Add 100 unique random values to h1 and 50 of these 100 to h2 + for i := 0; i < 100; i++ { + value := randomStr(10) + mustDo(t, c, + "PFADD", "h1", value, + proto.Int(1), // hll changes each time + ) + if i%2 == 0 { + mustDo(t, c, + "PFADD", "h2", value, + proto.Int(1), // hll changes each time + ) + } + } + + for i := 0; i < 100; i++ { + mustDo(t, c, + "PFADD", "h3", randomStr(10), + proto.Int(1), // hll changes each time + ) + } + + // Merge non-intersecting hlls + { + mustOK(t, c, + "PFMERGE", + "res1", + "h1", // count 100 + "h3", // count 100 + ) + mustDo(t, c, + "PFCOUNT", "res1", + proto.Int(200), + ) + } + + // Merge intersecting hlls + { + mustOK(t, c, + "PFMERGE", + "res2", + "h1", // count 100 + "h2", // count 50 (all 50 are presented in h1) + ) + mustDo(t, c, + "PFCOUNT", "res2", + proto.Int(100), + ) + } + + // Merge all hlls + { + mustOK(t, c, + "PFMERGE", + "res3", + "h1", // count 100 + "h2", // count 50 (all 50 are presented in h1) + "h3", // count 100 + "h4", // empty key + ) + mustDo(t, c, + "PFCOUNT", "res3", + proto.Int(200), + ) + } + + t.Run("direct", func(t *testing.T) { + commonElem := randomStr(10) + s.PfAdd("h5", commonElem, randomStr(10), randomStr(10), randomStr(10), randomStr(10)) + s.PfAdd("h6", commonElem, randomStr(10), randomStr(10)) + + sum, err := s.PfCount("h5", "h6", "h7") // h7 is empty + ok(t, err) + equals(t, sum, 8) + + s.PfMerge("h8", "h5", "h6") + sum, err = s.PfCount("h8") + ok(t, err) + equals(t, sum, 7) // common elem is counted once + }) + + t.Run("errors", func(t *testing.T) { + s.Set("str", "value") + + mustDo(t, c, + "PFMERGE", + proto.Error(errWrongNumber("pfmerge")), + ) + mustDo(t, c, + "PFMERGE", "h10", "str", + proto.Error(msgNotValidHllValue), + ) + }) +} diff --git a/cmd_set_test.go b/cmd_set_test.go index 716dcffb..6749a748 100644 --- a/cmd_set_test.go +++ b/cmd_set_test.go @@ -331,7 +331,7 @@ func TestSpop(t *testing.T) { s.SetAdd("s", "aap", "noot", "mies", "vuur") mustDo(t, c, "SPOP", "s", "2", - proto.Strings("vuur", "mies"), + proto.Strings("aap", "noot"), ) members, err := s.Members("s") ok(t, err) diff --git a/db.go b/db.go index 04cbae6c..b4e9d739 100644 --- a/db.go +++ b/db.go @@ -40,6 +40,7 @@ func (db *RedisDB) flush() { db.hashKeys = map[string]hashKey{} db.listKeys = map[string]listKey{} db.setKeys = map[string]setKey{} + db.hllKeys = map[string]*hll{} db.sortedsetKeys = map[string]sortedSet{} db.ttl = 
map[string]time.Duration{} db.streamKeys = map[string]*streamKey{} @@ -69,6 +70,8 @@ func (db *RedisDB) move(key string, to *RedisDB) bool { to.sortedsetKeys[key] = db.sortedsetKeys[key] case "stream": to.streamKeys[key] = db.streamKeys[key] + case "hll": + to.hllKeys[key] = db.hllKeys[key] default: panic("unhandled key type") } @@ -95,6 +98,8 @@ func (db *RedisDB) rename(from, to string) { db.sortedsetKeys[to] = db.sortedsetKeys[from] case "stream": db.streamKeys[to] = db.streamKeys[from] + case "hll": + db.hllKeys[to] = db.hllKeys[from] default: panic("missing case") } @@ -130,6 +135,8 @@ func (db *RedisDB) del(k string, delTTL bool) { delete(db.sortedsetKeys, k) case "stream": delete(db.streamKeys, k) + case "hll": + delete(db.hllKeys, k) default: panic("Unknown key type: " + t) } @@ -617,3 +624,69 @@ func (db *RedisDB) checkTTL(key string) { db.del(key, true) } } + +// hllAdd adds members to a hll. Returns 1 if at least 1 if internal HyperLogLog was altered, otherwise 0 +func (db *RedisDB) hllAdd(k string, elems ...string) int { + s, ok := db.hllKeys[k] + if !ok { + s = newHll() + db.keys[k] = "hll" + } + hllAltered := 0 + for _, e := range elems { + if s.Add([]byte(e)) { + hllAltered = 1 + } + } + db.hllKeys[k] = s + db.keyVersion[k]++ + return hllAltered +} + +// hllCount estimates the amount of members added to hll by hllAdd. If called with several arguments, hllCount returns a sum of estimations +func (db *RedisDB) hllCount(keys []string) (int, error) { + countOverall := 0 + for _, key := range keys { + if db.exists(key) && db.t(key) != "hll" { + return 0, ErrNotValidHllValue + } + if !db.exists(key) { + continue + } + countOverall += db.hllKeys[key].Count() + } + + return countOverall, nil +} + +// hllMerge merges all the hlls provided as keys to the first key. Creates a new hll in the first key if it contains nothing +func (db *RedisDB) hllMerge(keys []string) error { + for _, key := range keys { + if db.exists(key) && db.t(key) != "hll" { + return ErrNotValidHllValue + } + } + + destKey := keys[0] + restKeys := keys[1:] + + var destHll *hll + if db.exists(destKey) { + destHll = db.hllKeys[destKey] + } else { + destHll = newHll() + } + + for _, key := range restKeys { + if !db.exists(key) { + continue + } + destHll.Merge(db.hllKeys[key]) + } + + db.hllKeys[destKey] = destHll + db.keys[destKey] = "hll" + db.keyVersion[destKey]++ + + return nil +} diff --git a/direct.go b/direct.go index 708fdd10..23b6703a 100644 --- a/direct.go +++ b/direct.go @@ -15,6 +15,9 @@ var ( // ErrWrongType when a key is not the right type. ErrWrongType = errors.New(msgWrongType) + // ErrNotValidHllValue when a key is not a valid HyperLogLog string value. + ErrNotValidHllValue = errors.New(msgNotValidHllValue) + // ErrIntValueError can returned by INCRBY ErrIntValueError = errors.New(msgInvalidInt) @@ -748,3 +751,45 @@ func (m *Miniredis) PubSubNumPat() int { return countPsubs(m.allSubscribers()) } + +// PfAdd adds keys to a hll. Returns the flag which equals to 1 if the inner hll value has been changed. +func (m *Miniredis) PfAdd(k string, elems ...string) (int, error) { + return m.DB(m.selectedDB).HllAdd(k, elems...) +} + +// HllAdd adds keys to a hll. Returns the flag which equals to true if the inner hll value has been changed. 
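+// A usage sketch (illustrative only, built from the Pf*/Hll* helpers defined in
+// this file plus the existing Run constructor):
+//
+//	m, _ := miniredis.Run()
+//	changed, _ := m.PfAdd("visitors", "alice", "bob") // 1: the HLL was altered
+//	count, _ := m.PfCount("visitors")                 // estimated cardinality, 2 here
+//	_ = m.PfMerge("all", "visitors")                  // "all" now holds the merged HLL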
+func (db *RedisDB) HllAdd(k string, elems ...string) (int, error) { + db.master.Lock() + defer db.master.Unlock() + + if db.exists(k) && db.t(k) != "hll" { + return 0, ErrWrongType + } + return db.hllAdd(k, elems...), nil +} + +// PfCount returns an estimation of the amount of elements previously added to a hll. +func (m *Miniredis) PfCount(keys ...string) (int, error) { + return m.DB(m.selectedDB).HllCount(keys...) +} + +// HllCount returns an estimation of the amount of elements previously added to a hll. +func (db *RedisDB) HllCount(keys ...string) (int, error) { + db.master.Lock() + defer db.master.Unlock() + + return db.hllCount(keys) +} + +// PfMerge merges all the input hlls into a hll under destKey key. +func (m *Miniredis) PfMerge(destKey string, sourceKeys ...string) error { + return m.DB(m.selectedDB).HllMerge(destKey, sourceKeys...) +} + +// HllMerge merges all the input hlls into a hll under destKey key. +func (db *RedisDB) HllMerge(destKey string, sourceKeys ...string) error { + db.master.Lock() + defer db.master.Unlock() + + return db.hllMerge(append([]string{destKey}, sourceKeys...)) +} diff --git a/go.sum b/go.sum index 378d5b0f..e7d8f685 100644 --- a/go.sum +++ b/go.sum @@ -3,8 +3,6 @@ github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGn github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= -github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb h1:ZkM6LRnq40pR1Ox0hTHlnpkcOTuFIDQpZ1IN8rKKhX0= -github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb/go.mod h1:gqRgreBUhTSL0GeU64rtZ3Uq3wtjOa/TB2YfrtkCbVQ= github.com/yuin/gopher-lua v0.0.0-20200816102855-ee81675732da h1:NimzV1aGyq29m5ukMK0AMWEhFaL/lrEOaephfuoiARg= github.com/yuin/gopher-lua v0.0.0-20200816102855-ee81675732da/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA= golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/hll.go b/hll.go new file mode 100644 index 00000000..2f55fac9 --- /dev/null +++ b/hll.go @@ -0,0 +1,36 @@ +package miniredis + +import ( + "github.com/alicebob/miniredis/v2/hyperloglog" +) + +type hll struct { + inner *hyperloglog.Sketch +} + +func newHll() *hll { + return &hll{ + inner: hyperloglog.New14(), + } +} + +// Add returns true if cardinality has been changed, or false otherwise. +func (h *hll) Add(item []byte) bool { + return h.inner.Insert(item) +} + +// Count returns the estimation of a set cardinality. +func (h *hll) Count() int { + return int(h.inner.Estimate()) +} + +// Merge merges the other hll into original one (not making a copy but doing this in place). +func (h *hll) Merge(other *hll) { + _ = h.inner.Merge(other.inner) +} + +// Bytes returns raw-bytes representation of hll data structure. +func (h *hll) Bytes() []byte { + dataBytes, _ := h.inner.MarshalBinary() + return dataBytes +} diff --git a/hyperloglog/LICENSE b/hyperloglog/LICENSE new file mode 100644 index 00000000..8436fdb4 --- /dev/null +++ b/hyperloglog/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017 Axiom Inc. 
+ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/hyperloglog/README.md b/hyperloglog/README.md new file mode 100644 index 00000000..0fac68df --- /dev/null +++ b/hyperloglog/README.md @@ -0,0 +1 @@ +This is a copy of github.com/axiomhq/hyperloglog. \ No newline at end of file diff --git a/hyperloglog/compressed.go b/hyperloglog/compressed.go new file mode 100644 index 00000000..4b908be4 --- /dev/null +++ b/hyperloglog/compressed.go @@ -0,0 +1,180 @@ +package hyperloglog + +import "encoding/binary" + +// Original author of this file is github.com/clarkduvall/hyperloglog +type iterable interface { + decode(i int, last uint32) (uint32, int) + Len() int + Iter() *iterator +} + +type iterator struct { + i int + last uint32 + v iterable +} + +func (iter *iterator) Next() uint32 { + n, i := iter.v.decode(iter.i, iter.last) + iter.last = n + iter.i = i + return n +} + +func (iter *iterator) Peek() uint32 { + n, _ := iter.v.decode(iter.i, iter.last) + return n +} + +func (iter iterator) HasNext() bool { + return iter.i < iter.v.Len() +} + +type compressedList struct { + count uint32 + last uint32 + b variableLengthList +} + +func (v *compressedList) Clone() *compressedList { + if v == nil { + return nil + } + + newV := &compressedList{ + count: v.count, + last: v.last, + } + + newV.b = make(variableLengthList, len(v.b)) + copy(newV.b, v.b) + return newV +} + +func (v *compressedList) MarshalBinary() (data []byte, err error) { + // Marshal the variableLengthList + bdata, err := v.b.MarshalBinary() + if err != nil { + return nil, err + } + + // At least 4 bytes for the two fixed sized values plus the size of bdata. + data = make([]byte, 0, 4+4+len(bdata)) + + // Marshal the count and last values. + data = append(data, []byte{ + // Number of items in the list. + byte(v.count >> 24), + byte(v.count >> 16), + byte(v.count >> 8), + byte(v.count), + // The last item in the list. + byte(v.last >> 24), + byte(v.last >> 16), + byte(v.last >> 8), + byte(v.last), + }...) + + // Append the list + return append(data, bdata...), nil +} + +func (v *compressedList) UnmarshalBinary(data []byte) error { + if len(data) < 12 { + return ErrorTooShort + } + + // Set the count. + v.count, data = binary.BigEndian.Uint32(data[:4]), data[4:] + + // Set the last value. + v.last, data = binary.BigEndian.Uint32(data[:4]), data[4:] + + // Set the list. 
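+ // (Layout written by MarshalBinary above: 4-byte count, 4-byte last, then the
+ // variable-length list as a 4-byte size followed by that many bytes.)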
+ sz, data := binary.BigEndian.Uint32(data[:4]), data[4:] + v.b = make([]uint8, sz) + if uint32(len(data)) < sz { + return ErrorTooShort + } + for i := uint32(0); i < sz; i++ { + v.b[i] = data[i] + } + return nil +} + +func newCompressedList() *compressedList { + v := &compressedList{} + v.b = make(variableLengthList, 0) + return v +} + +func (v *compressedList) Len() int { + return len(v.b) +} + +func (v *compressedList) decode(i int, last uint32) (uint32, int) { + n, i := v.b.decode(i, last) + return n + last, i +} + +func (v *compressedList) Append(x uint32) { + v.count++ + v.b = v.b.Append(x - v.last) + v.last = x +} + +func (v *compressedList) Iter() *iterator { + return &iterator{0, 0, v} +} + +type variableLengthList []uint8 + +func (v variableLengthList) MarshalBinary() (data []byte, err error) { + // 4 bytes for the size of the list, and a byte for each element in the + // list. + data = make([]byte, 0, 4+v.Len()) + + // Length of the list. We only need 32 bits because the size of the set + // couldn't exceed that on 32 bit architectures. + sz := v.Len() + data = append(data, []byte{ + byte(sz >> 24), + byte(sz >> 16), + byte(sz >> 8), + byte(sz), + }...) + + // Marshal each element in the list. + for i := 0; i < sz; i++ { + data = append(data, v[i]) + } + + return data, nil +} + +func (v variableLengthList) Len() int { + return len(v) +} + +func (v *variableLengthList) Iter() *iterator { + return &iterator{0, 0, v} +} + +func (v variableLengthList) decode(i int, last uint32) (uint32, int) { + var x uint32 + j := i + for ; v[j]&0x80 != 0; j++ { + x |= uint32(v[j]&0x7f) << (uint(j-i) * 7) + } + x |= uint32(v[j]) << (uint(j-i) * 7) + return x, j + 1 +} + +func (v variableLengthList) Append(x uint32) variableLengthList { + for x&0xffffff80 != 0 { + v = append(v, uint8((x&0x7f)|0x80)) + x >>= 7 + } + return append(v, uint8(x&0x7f)) +} diff --git a/hyperloglog/hyperloglog.go b/hyperloglog/hyperloglog.go new file mode 100644 index 00000000..82663915 --- /dev/null +++ b/hyperloglog/hyperloglog.go @@ -0,0 +1,424 @@ +package hyperloglog + +import ( + "encoding/binary" + "errors" + "fmt" + "math" + "sort" +) + +const ( + capacity = uint8(16) + pp = uint8(25) + mp = uint32(1) << pp + version = 1 +) + +// Sketch is a HyperLogLog data-structure for the count-distinct problem, +// approximating the number of distinct elements in a multiset. 
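+// A Sketch starts in a sparse representation (tmpSet plus sparseList, holding
+// encoded hashes) and is promoted to the dense 4-bit register array (regs) once
+// the sparse data grows too large; see maybeToNormal and toNormal below.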
+type Sketch struct { + p uint8 + b uint8 + m uint32 + alpha float64 + tmpSet set + sparseList *compressedList + regs *registers +} + +// New returns a HyperLogLog Sketch with 2^14 registers (precision 14) +func New() *Sketch { + return New14() +} + +// New14 returns a HyperLogLog Sketch with 2^14 registers (precision 14) +func New14() *Sketch { + sk, _ := newSketch(14, true) + return sk +} + +// New16 returns a HyperLogLog Sketch with 2^16 registers (precision 16) +func New16() *Sketch { + sk, _ := newSketch(16, true) + return sk +} + +// NewNoSparse returns a HyperLogLog Sketch with 2^14 registers (precision 14) +// that will not use a sparse representation +func NewNoSparse() *Sketch { + sk, _ := newSketch(14, false) + return sk +} + +// New16NoSparse returns a HyperLogLog Sketch with 2^16 registers (precision 16) +// that will not use a sparse representation +func New16NoSparse() *Sketch { + sk, _ := newSketch(16, false) + return sk +} + +// newSketch returns a HyperLogLog Sketch with 2^precision registers +func newSketch(precision uint8, sparse bool) (*Sketch, error) { + if precision < 4 || precision > 18 { + return nil, fmt.Errorf("p has to be >= 4 and <= 18") + } + m := uint32(math.Pow(2, float64(precision))) + s := &Sketch{ + m: m, + p: precision, + alpha: alpha(float64(m)), + } + if sparse { + s.tmpSet = set{} + s.sparseList = newCompressedList() + } else { + s.regs = newRegisters(m) + } + return s, nil +} + +func (sk *Sketch) sparse() bool { + return sk.sparseList != nil +} + +// Clone returns a deep copy of sk. +func (sk *Sketch) Clone() *Sketch { + return &Sketch{ + b: sk.b, + p: sk.p, + m: sk.m, + alpha: sk.alpha, + tmpSet: sk.tmpSet.Clone(), + sparseList: sk.sparseList.Clone(), + regs: sk.regs.clone(), + } +} + +// Converts to normal if the sparse list is too large. +func (sk *Sketch) maybeToNormal() { + if uint32(len(sk.tmpSet))*100 > sk.m { + sk.mergeSparse() + if uint32(sk.sparseList.Len()) > sk.m { + sk.toNormal() + } + } +} + +// Merge takes another Sketch and combines it with Sketch h. +// If Sketch h is using the sparse Sketch, it will be converted +// to the normal Sketch. +func (sk *Sketch) Merge(other *Sketch) error { + if other == nil { + // Nothing to do + return nil + } + cpOther := other.Clone() + + if sk.p != cpOther.p { + return errors.New("precisions must be equal") + } + + if sk.sparse() && other.sparse() { + for k := range other.tmpSet { + sk.tmpSet.add(k) + } + for iter := other.sparseList.Iter(); iter.HasNext(); { + sk.tmpSet.add(iter.Next()) + } + sk.maybeToNormal() + return nil + } + + if sk.sparse() { + sk.toNormal() + } + + if cpOther.sparse() { + for k := range cpOther.tmpSet { + i, r := decodeHash(k, cpOther.p, pp) + sk.insert(i, r) + } + + for iter := cpOther.sparseList.Iter(); iter.HasNext(); { + i, r := decodeHash(iter.Next(), cpOther.p, pp) + sk.insert(i, r) + } + } else { + if sk.b < cpOther.b { + sk.regs.rebase(cpOther.b - sk.b) + sk.b = cpOther.b + } else { + cpOther.regs.rebase(sk.b - cpOther.b) + cpOther.b = sk.b + } + + for i, v := range cpOther.regs.tailcuts { + v1 := v.get(0) + if v1 > sk.regs.get(uint32(i)*2) { + sk.regs.set(uint32(i)*2, v1) + } + v2 := v.get(1) + if v2 > sk.regs.get(1+uint32(i)*2) { + sk.regs.set(1+uint32(i)*2, v2) + } + } + } + return nil +} + +// Convert from sparse Sketch to dense Sketch. 
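+// Pending tmpSet entries are folded into the sparse list first, and every
+// encoded hash is then replayed into a freshly allocated register array.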
+func (sk *Sketch) toNormal() { + if len(sk.tmpSet) > 0 { + sk.mergeSparse() + } + + sk.regs = newRegisters(sk.m) + for iter := sk.sparseList.Iter(); iter.HasNext(); { + i, r := decodeHash(iter.Next(), sk.p, pp) + sk.insert(i, r) + } + + sk.tmpSet = nil + sk.sparseList = nil +} + +func (sk *Sketch) insert(i uint32, r uint8) bool { + changed := false + if r-sk.b >= capacity { + //overflow + db := sk.regs.min() + if db > 0 { + sk.b += db + sk.regs.rebase(db) + changed = true + } + } + if r > sk.b { + val := r - sk.b + if c1 := capacity - 1; c1 < val { + val = c1 + } + + if val > sk.regs.get(i) { + sk.regs.set(i, val) + changed = true + } + } + return changed +} + +// Insert adds element e to sketch +func (sk *Sketch) Insert(e []byte) bool { + x := hash(e) + return sk.InsertHash(x) +} + +// InsertHash adds hash x to sketch +func (sk *Sketch) InsertHash(x uint64) bool { + if sk.sparse() { + changed := sk.tmpSet.add(encodeHash(x, sk.p, pp)) + if !changed { + return false + } + if uint32(len(sk.tmpSet))*100 > sk.m/2 { + sk.mergeSparse() + if uint32(sk.sparseList.Len()) > sk.m/2 { + sk.toNormal() + } + } + return true + } else { + i, r := getPosVal(x, sk.p) + return sk.insert(uint32(i), r) + } +} + +// Estimate returns the cardinality of the Sketch +func (sk *Sketch) Estimate() uint64 { + if sk.sparse() { + sk.mergeSparse() + return uint64(linearCount(mp, mp-sk.sparseList.count)) + } + + sum, ez := sk.regs.sumAndZeros(sk.b) + m := float64(sk.m) + var est float64 + + var beta func(float64) float64 + if sk.p < 16 { + beta = beta14 + } else { + beta = beta16 + } + + if sk.b == 0 { + est = (sk.alpha * m * (m - ez) / (sum + beta(ez))) + } else { + est = (sk.alpha * m * m / sum) + } + + return uint64(est + 0.5) +} + +func (sk *Sketch) mergeSparse() { + if len(sk.tmpSet) == 0 { + return + } + + keys := make(uint64Slice, 0, len(sk.tmpSet)) + for k := range sk.tmpSet { + keys = append(keys, k) + } + sort.Sort(keys) + + newList := newCompressedList() + for iter, i := sk.sparseList.Iter(), 0; iter.HasNext() || i < len(keys); { + if !iter.HasNext() { + newList.Append(keys[i]) + i++ + continue + } + + if i >= len(keys) { + newList.Append(iter.Next()) + continue + } + + x1, x2 := iter.Peek(), keys[i] + if x1 == x2 { + newList.Append(iter.Next()) + i++ + } else if x1 > x2 { + newList.Append(x2) + i++ + } else { + newList.Append(iter.Next()) + } + } + + sk.sparseList = newList + sk.tmpSet = set{} +} + +// MarshalBinary implements the encoding.BinaryMarshaler interface. +func (sk *Sketch) MarshalBinary() (data []byte, err error) { + // Marshal a version marker. + data = append(data, version) + // Marshal p. + data = append(data, sk.p) + // Marshal b + data = append(data, sk.b) + + if sk.sparse() { + // It's using the sparse Sketch. + data = append(data, byte(1)) + + // Add the tmp_set + tsdata, err := sk.tmpSet.MarshalBinary() + if err != nil { + return nil, err + } + data = append(data, tsdata...) + + // Add the sparse Sketch + sdata, err := sk.sparseList.MarshalBinary() + if err != nil { + return nil, err + } + return append(data, sdata...), nil + } + + // It's using the dense Sketch. + data = append(data, byte(0)) + + // Add the dense sketch Sketch. + sz := len(sk.regs.tailcuts) + data = append(data, []byte{ + byte(sz >> 24), + byte(sz >> 16), + byte(sz >> 8), + byte(sz), + }...) + + // Marshal each element in the list. 
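+ // (Each tailcuts entry is one byte packing two 4-bit registers; see registers.go.)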
+ for i := 0; i < len(sk.regs.tailcuts); i++ { + data = append(data, byte(sk.regs.tailcuts[i])) + } + + return data, nil +} + +// ErrorTooShort is an error that UnmarshalBinary try to parse too short +// binary. +var ErrorTooShort = errors.New("too short binary") + +// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. +func (sk *Sketch) UnmarshalBinary(data []byte) error { + if len(data) < 8 { + return ErrorTooShort + } + + // Unmarshal version. We may need this in the future if we make + // non-compatible changes. + _ = data[0] + + // Unmarshal p. + p := data[1] + + // Unmarshal b. + sk.b = data[2] + + // Determine if we need a sparse Sketch + sparse := data[3] == byte(1) + + // Make a newSketch Sketch if the precision doesn't match or if the Sketch was used + if sk.p != p || sk.regs != nil || len(sk.tmpSet) > 0 || (sk.sparseList != nil && sk.sparseList.Len() > 0) { + newh, err := newSketch(p, sparse) + if err != nil { + return err + } + newh.b = sk.b + *sk = *newh + } + + // h is now initialised with the correct p. We just need to fill the + // rest of the details out. + if sparse { + // Using the sparse Sketch. + + // Unmarshal the tmp_set. + tssz := binary.BigEndian.Uint32(data[4:8]) + sk.tmpSet = make(map[uint32]struct{}, tssz) + + // We need to unmarshal tssz values in total, and each value requires us + // to read 4 bytes. + tsLastByte := int((tssz * 4) + 8) + for i := 8; i < tsLastByte; i += 4 { + k := binary.BigEndian.Uint32(data[i : i+4]) + sk.tmpSet[k] = struct{}{} + } + + // Unmarshal the sparse Sketch. + return sk.sparseList.UnmarshalBinary(data[tsLastByte:]) + } + + // Using the dense Sketch. + sk.sparseList = nil + sk.tmpSet = nil + dsz := binary.BigEndian.Uint32(data[4:8]) + sk.regs = newRegisters(dsz * 2) + data = data[8:] + + for i, val := range data { + sk.regs.tailcuts[i] = reg(val) + if uint8(sk.regs.tailcuts[i]<<4>>4) > 0 { + sk.regs.nz-- + } + if uint8(sk.regs.tailcuts[i]>>4) > 0 { + sk.regs.nz-- + } + } + + return nil +} diff --git a/hyperloglog/registers.go b/hyperloglog/registers.go new file mode 100644 index 00000000..19bb5d47 --- /dev/null +++ b/hyperloglog/registers.go @@ -0,0 +1,114 @@ +package hyperloglog + +import ( + "math" +) + +type reg uint8 +type tailcuts []reg + +type registers struct { + tailcuts + nz uint32 +} + +func (r *reg) set(offset, val uint8) bool { + var isZero bool + if offset == 0 { + isZero = *r < 16 + tmpVal := uint8((*r) << 4 >> 4) + *r = reg(tmpVal | (val << 4)) + } else { + isZero = *r&0x0f == 0 + tmpVal := uint8((*r) >> 4 << 4) + *r = reg(tmpVal | val) + } + return isZero +} + +func (r *reg) get(offset uint8) uint8 { + if offset == 0 { + return uint8((*r) >> 4) + } + return uint8((*r) << 4 >> 4) +} + +func newRegisters(size uint32) *registers { + return ®isters{ + tailcuts: make(tailcuts, size/2), + nz: size, + } +} + +func (rs *registers) clone() *registers { + if rs == nil { + return nil + } + tc := make([]reg, len(rs.tailcuts)) + copy(tc, rs.tailcuts) + return ®isters{ + tailcuts: tc, + nz: rs.nz, + } +} + +func (rs *registers) rebase(delta uint8) { + nz := uint32(len(rs.tailcuts)) * 2 + for i := range rs.tailcuts { + for j := uint8(0); j < 2; j++ { + val := rs.tailcuts[i].get(j) + if val >= delta { + rs.tailcuts[i].set(j, val-delta) + if val-delta > 0 { + nz-- + } + } + } + } + rs.nz = nz +} + +func (rs *registers) set(i uint32, val uint8) { + offset, index := uint8(i)&1, i/2 + if rs.tailcuts[index].set(offset, val) { + rs.nz-- + } +} + +func (rs *registers) get(i uint32) uint8 { + offset, index := uint8(i)&1, 
i/2 + return rs.tailcuts[index].get(offset) +} + +func (rs *registers) sumAndZeros(base uint8) (res, ez float64) { + for _, r := range rs.tailcuts { + for j := uint8(0); j < 2; j++ { + v := float64(base + r.get(j)) + if v == 0 { + ez++ + } + res += 1.0 / math.Pow(2.0, v) + } + } + rs.nz = uint32(ez) + return res, ez +} + +func (rs *registers) min() uint8 { + if rs.nz > 0 { + return 0 + } + min := uint8(math.MaxUint8) + for _, r := range rs.tailcuts { + if r == 0 || min == 0 { + return 0 + } + if val := uint8(r << 4 >> 4); val < min { + min = val + } + if val := uint8(r >> 4); val < min { + min = val + } + } + return min +} diff --git a/hyperloglog/sparse.go b/hyperloglog/sparse.go new file mode 100644 index 00000000..8c457d32 --- /dev/null +++ b/hyperloglog/sparse.go @@ -0,0 +1,92 @@ +package hyperloglog + +import ( + "math/bits" +) + +func getIndex(k uint32, p, pp uint8) uint32 { + if k&1 == 1 { + return bextr32(k, 32-p, p) + } + return bextr32(k, pp-p+1, p) +} + +// Encode a hash to be used in the sparse representation. +func encodeHash(x uint64, p, pp uint8) uint32 { + idx := uint32(bextr(x, 64-pp, pp)) + if bextr(x, 64-pp, pp-p) == 0 { + zeros := bits.LeadingZeros64((bextr(x, 0, 64-pp)<> 24), + byte(sl >> 16), + byte(sl >> 8), + byte(sl), + }...) + + // Marshal each element in the set. + for k := range s { + data = append(data, []byte{ + byte(k >> 24), + byte(k >> 16), + byte(k >> 8), + byte(k), + }...) + } + + return data, nil +} + +type uint64Slice []uint32 + +func (p uint64Slice) Len() int { return len(p) } +func (p uint64Slice) Less(i, j int) bool { return p[i] < p[j] } +func (p uint64Slice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } diff --git a/hyperloglog/utils.go b/hyperloglog/utils.go new file mode 100644 index 00000000..896bf7e7 --- /dev/null +++ b/hyperloglog/utils.go @@ -0,0 +1,69 @@ +package hyperloglog + +import ( + "github.com/alicebob/miniredis/v2/metro" + "math" + "math/bits" +) + +var hash = hashFunc + +func beta14(ez float64) float64 { + zl := math.Log(ez + 1) + return -0.370393911*ez + + 0.070471823*zl + + 0.17393686*math.Pow(zl, 2) + + 0.16339839*math.Pow(zl, 3) + + -0.09237745*math.Pow(zl, 4) + + 0.03738027*math.Pow(zl, 5) + + -0.005384159*math.Pow(zl, 6) + + 0.00042419*math.Pow(zl, 7) +} + +func beta16(ez float64) float64 { + zl := math.Log(ez + 1) + return -0.37331876643753059*ez + + -1.41704077448122989*zl + + 0.40729184796612533*math.Pow(zl, 2) + + 1.56152033906584164*math.Pow(zl, 3) + + -0.99242233534286128*math.Pow(zl, 4) + + 0.26064681399483092*math.Pow(zl, 5) + + -0.03053811369682807*math.Pow(zl, 6) + + 0.00155770210179105*math.Pow(zl, 7) +} + +func alpha(m float64) float64 { + switch m { + case 16: + return 0.673 + case 32: + return 0.697 + case 64: + return 0.709 + } + return 0.7213 / (1 + 1.079/m) +} + +func getPosVal(x uint64, p uint8) (uint64, uint8) { + i := bextr(x, 64-p, p) // {x63,...,x64-p} + w := x<
<p | 1<<(p-1) // {x63-p,...,x0} + rho := uint8(bits.LeadingZeros64(w)) + 1 // ρ(w) + return i, rho +} + +func linearCount(m uint32, v uint32) float64 { + fm := float64(m) + return fm * math.Log(fm/float64(v)) +} + +func bextr(v uint64, start, length uint8) uint64 { + return (v >
> start) & ((1 << length) - 1) +} + +func bextr32(v uint32, start, length uint8) uint32 { + return (v >> start) & ((1 << length) - 1) +} + +func hashFunc(e []byte) uint64 { + return metro.Hash64(e, 1337) +} diff --git a/integration/hll_test.go b/integration/hll_test.go new file mode 100644 index 00000000..207898b5 --- /dev/null +++ b/integration/hll_test.go @@ -0,0 +1,100 @@ +// +build int + +package main + +// Hash keys. + +import ( + "math/rand" + "testing" +) + +func TestHll(t *testing.T) { + t.Run("basics", func(t *testing.T) { + testRaw(t, func(c *client) { + // Add 100 unique random values to h1 and 50 of these 100 to h2 + for i := 0; i < 100; i++ { + value := randomStr(10) + c.Do("PFADD", "h1", value) + if i%2 == 0 { + c.Do("PFADD", "h2", value) + } + } + + for i := 0; i < 100; i++ { + c.Do("PFADD", "h3", randomStr(10)) + } + + // Merge non-intersecting hlls + { + c.Do( + "PFMERGE", + "res1", + "h1", // count 100 + "h3", // count 100 + ) + c.DoApprox(2, "PFCOUNT", "res1") + } + + // Merge intersecting hlls + { + c.Do( + "PFMERGE", + "res2", + "h1", // count 100 + "h2", // count 50 (all 50 are presented in h1) + ) + c.DoApprox(2, "PFCOUNT", "res2") + } + + // Merge all hlls + { + c.Do( + "PFMERGE", + "res3", + "h1", // count 100 + "h2", // count 50 (all 50 are presented in h1) + "h3", // count 100 + "h4", // empty key + ) + c.DoApprox(2, "PFCOUNT", "res3") + } + + // failure cases + c.Error("wrong number", "PFADD") + c.Error("wrong number", "PFCOUNT") + c.Error("wrong number", "PFMERGE") + c.Do("SET", "str", "I am a string") + c.Error("not a valid HyperLogLog", "PFADD", "str", "noot", "mies") + c.Error("not a valid HyperLogLog", "PFCOUNT", "str", "h1") + c.Error("not a valid HyperLogLog", "PFMERGE", "str", "noot") + c.Error("not a valid HyperLogLog", "PFMERGE", "noot", "str") + + c.Do("DEL", "h1", "h2", "h3", "h4", "res1", "res2", "res3") + c.Do("PFCOUNT", "h1", "h2", "h3", "h4", "res1", "res2", "res3") + }) + }) + + t.Run("tx", func(t *testing.T) { + testRaw(t, func(c *client) { + c.Do("MULTI") + c.Do("PFADD", "h1", "noot", "mies", "vuur", "wim") + c.Do("PFADD", "h2", "noot1", "mies1", "vuur1", "wim1") + c.Do("PFMERGE", "h3", "h1", "h2") + c.Do("PFCOUNT", "h1") + c.Do("PFCOUNT", "h2") + c.Do("PFCOUNT", "h3") + c.Do("EXEC") + }) + }) +} + +const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +func randomStr(length int) string { + b := make([]byte, length) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + return string(b) +} diff --git a/integration/test.go b/integration/test.go index 5496db58..0a044aa8 100644 --- a/integration/test.go +++ b/integration/test.go @@ -4,6 +4,7 @@ package main import ( "fmt" + "math" "reflect" "sort" "strconv" @@ -481,6 +482,48 @@ func (c *client) DoRounded(rounded int, cmd string, args ...string) { } } +// result must match, with floats rounded +func (c *client) DoApprox(threshold int, cmd string, args ...string) { + c.t.Helper() + + resReal, errReal := c.real.Do(append([]string{cmd}, args...)...) + if errReal != nil { + c.t.Errorf("error from realredis: %s", errReal) + return + } + resMini, errMini := c.mini.Do(append([]string{cmd}, args...)...) 
+ if errMini != nil { + c.t.Errorf("error from miniredis: %s", errMini) + return + } + + // c.t.Logf("real:%q mini:%q", string(resReal), string(resMini)) + + mini, err := proto.Parse(resMini) + if err != nil { + c.t.Errorf("parse error miniredis: %s", err) + return + } + real, err := proto.Parse(resReal) + if err != nil { + c.t.Errorf("parse error realredis: %s", err) + return + } + miniInt, ok := mini.(int) + if !ok { + c.t.Errorf("parse int error miniredis: %T found (%#v)", mini, mini) + return + } + realInt, ok := real.(int) + if !ok { + c.t.Errorf("parse int error miniredis: %T found (%#v)", real, real) + return + } + if math.Abs(float64(miniInt-realInt)) > float64(threshold) { + c.t.Errorf("expected an approximated match (threshold is %d) want: %#v have: %#v", threshold, real, mini) + } +} + // both must return an error, which much both Contain() the message. func (c *client) Error(msg string, cmd string, args ...string) { c.t.Helper() diff --git a/metro/LICENSE b/metro/LICENSE new file mode 100644 index 00000000..6243b617 --- /dev/null +++ b/metro/LICENSE @@ -0,0 +1,24 @@ +This package is a mechanical translation of the reference C++ code for +MetroHash, available at https://github.com/jandrewrogers/MetroHash + +The MIT License (MIT) + +Copyright (c) 2016 Damian Gryski + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/metro/README.md b/metro/README.md new file mode 100644 index 00000000..07e4ee9f --- /dev/null +++ b/metro/README.md @@ -0,0 +1 @@ +This is a partial copy of github.com/dgryski/go-metro. 
\ No newline at end of file diff --git a/metro/metro64.go b/metro/metro64.go new file mode 100644 index 00000000..5b3db9a9 --- /dev/null +++ b/metro/metro64.go @@ -0,0 +1,87 @@ +package metro + +import "encoding/binary" + +func Hash64(buffer []byte, seed uint64) uint64 { + + const ( + k0 = 0xD6D018F5 + k1 = 0xA2AA033B + k2 = 0x62992FC1 + k3 = 0x30BC5B29 + ) + + ptr := buffer + + hash := (seed + k2) * k0 + + if len(ptr) >= 32 { + v := [4]uint64{hash, hash, hash, hash} + + for len(ptr) >= 32 { + v[0] += binary.LittleEndian.Uint64(ptr[:8]) * k0 + v[0] = rotate_right(v[0], 29) + v[2] + v[1] += binary.LittleEndian.Uint64(ptr[8:16]) * k1 + v[1] = rotate_right(v[1], 29) + v[3] + v[2] += binary.LittleEndian.Uint64(ptr[16:24]) * k2 + v[2] = rotate_right(v[2], 29) + v[0] + v[3] += binary.LittleEndian.Uint64(ptr[24:32]) * k3 + v[3] = rotate_right(v[3], 29) + v[1] + ptr = ptr[32:] + } + + v[2] ^= rotate_right(((v[0]+v[3])*k0)+v[1], 37) * k1 + v[3] ^= rotate_right(((v[1]+v[2])*k1)+v[0], 37) * k0 + v[0] ^= rotate_right(((v[0]+v[2])*k0)+v[3], 37) * k1 + v[1] ^= rotate_right(((v[1]+v[3])*k1)+v[2], 37) * k0 + hash += v[0] ^ v[1] + } + + if len(ptr) >= 16 { + v0 := hash + (binary.LittleEndian.Uint64(ptr[:8]) * k2) + v0 = rotate_right(v0, 29) * k3 + v1 := hash + (binary.LittleEndian.Uint64(ptr[8:16]) * k2) + v1 = rotate_right(v1, 29) * k3 + v0 ^= rotate_right(v0*k0, 21) + v1 + v1 ^= rotate_right(v1*k3, 21) + v0 + hash += v1 + ptr = ptr[16:] + } + + if len(ptr) >= 8 { + hash += binary.LittleEndian.Uint64(ptr[:8]) * k3 + ptr = ptr[8:] + hash ^= rotate_right(hash, 55) * k1 + } + + if len(ptr) >= 4 { + hash += uint64(binary.LittleEndian.Uint32(ptr[:4])) * k3 + hash ^= rotate_right(hash, 26) * k1 + ptr = ptr[4:] + } + + if len(ptr) >= 2 { + hash += uint64(binary.LittleEndian.Uint16(ptr[:2])) * k3 + ptr = ptr[2:] + hash ^= rotate_right(hash, 48) * k1 + } + + if len(ptr) >= 1 { + hash += uint64(ptr[0]) * k3 + hash ^= rotate_right(hash, 37) * k1 + } + + hash ^= rotate_right(hash, 28) + hash *= k0 + hash ^= rotate_right(hash, 29) + + return hash +} + +func Hash64Str(buffer string, seed uint64) uint64 { + return Hash64([]byte(buffer), seed) +} + +func rotate_right(v uint64, k uint) uint64 { + return (v >> k) | (v << (64 - k)) +} diff --git a/miniredis.go b/miniredis.go index d5acccc3..465635c1 100644 --- a/miniredis.go +++ b/miniredis.go @@ -42,6 +42,7 @@ type RedisDB struct { hashKeys map[string]hashKey // MGET/MSET &c. keys listKeys map[string]listKey // LPUSH &c. keys setKeys map[string]setKey // SADD &c. keys + hllKeys map[string]*hll // PFADD &c. keys sortedsetKeys map[string]sortedSet // ZADD &c. keys streamKeys map[string]*streamKey // XADD &c. 
keys ttl map[string]time.Duration // effective TTL values @@ -105,6 +106,7 @@ func newRedisDB(id int, m *Miniredis) RedisDB { hashKeys: map[string]hashKey{}, listKeys: map[string]listKey{}, setKeys: map[string]setKey{}, + hllKeys: map[string]*hll{}, sortedsetKeys: map[string]sortedSet{}, streamKeys: map[string]*streamKey{}, ttl: map[string]time.Duration{}, @@ -174,6 +176,7 @@ func (m *Miniredis) start(s *server.Server) error { commandsGeo(m) commandsCluster(m) commandsCommand(m) + commandsHll(m) return nil } @@ -369,6 +372,10 @@ func (m *Miniredis) Dump() string { r += fmt.Sprintf("%s%s%s: %s\n", indent, indent, v(ev[2*i]), v(ev[2*i+1])) } } + case "hll": + for _, entry := range db.hllKeys { + r += fmt.Sprintf("%s%s\n", indent, v(string(entry.Bytes()))) + } default: r += fmt.Sprintf("%s(a %s, fixme!)\n", indent, t) } diff --git a/redis.go b/redis.go index 479de015..4f6ea641 100644 --- a/redis.go +++ b/redis.go @@ -12,6 +12,7 @@ import ( const ( msgWrongType = "WRONGTYPE Operation against a key holding the wrong kind of value" + msgNotValidHllValue = "WRONGTYPE Key is not a valid HyperLogLog string value." msgInvalidInt = "ERR value is not an integer or out of range" msgInvalidFloat = "ERR value is not a valid float" msgInvalidMinMax = "ERR min or max is not a float" diff --git a/test_test.go b/test_test.go index 73515b83..0c49fff8 100644 --- a/test_test.go +++ b/test_test.go @@ -1,6 +1,7 @@ package miniredis import ( + "math/rand" "reflect" "strings" "testing" @@ -115,3 +116,13 @@ func mustContain(tb testing.TB, c *proto.Client, args ...string) { func useRESP3(t *testing.T, c *proto.Client) { mustContain(t, c, "HELLO", "3", "miniredis") } + +const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +func randomStr(length int) string { + b := make([]byte, length) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + return string(b) +}
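For reference, a minimal end-to-end sketch of the new commands over the wire, reusing the proto test client that the tests above already dial (the replies shown are the expected RESP values, assuming Do returns the raw reply):

	s, _ := miniredis.Run()
	defer s.Close()
	c, _ := proto.Dial(s.Addr())
	defer c.Close()

	c.Do("PFADD", "h", "a", "b", "c") // :1, the HLL was altered
	c.Do("PFCOUNT", "h")              // :3, estimated cardinality
	c.Do("PFMERGE", "dest", "h")      // +OK
	c.Do("PFCOUNT", "dest")           // :3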