Skip to content

Commit

Permalink
genericize key
Browse files Browse the repository at this point in the history
  • Loading branch information
phuslu committed Sep 24, 2023
1 parent 8dd46e9 commit 79bda29
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 82 deletions.
167 changes: 118 additions & 49 deletions map.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,60 @@ package shardmap
import (
"runtime"
"sync"
"sync/atomic"
"unsafe"

"github.com/zeebo/xxh3"
)

// Map is a hashmap. Like map[string]any, but sharded and thread-safe.
type Map[V any] struct {
init sync.Once
// Map is a hashmap. Like map[comparable]any, but sharded and thread-safe.
type Map[K comparable, V any] struct {
inited uint32
initmu sync.Mutex
cap int
shards int
mus []sync.RWMutex
maps []*rhhMap[V]
maps []*rhhMap[K, V]
ksize int
kstr bool
zero V
}

// New returns a new hashmap with the specified capacity. This function is only
// needed when you must define a minimum capacity, otherwise just use:
//
// var m shardmap.Map[any]
func New[V any](cap int) *Map[V] {
return &Map[V]{cap: cap}
// var m shardmap.Map[string, any]
func New[K comparable, V any](cap int) *Map[K, V] {
return &Map[K, V]{cap: cap}
}

// Clear out all values from map
func (m *Map[V]) Clear() {
m.initDo()
func (m *Map[K, V]) Clear() {
if atomic.LoadUint32(&m.inited) == 0 {
m.initDo()
}
for i := 0; i < m.shards; i++ {
m.mus[i].Lock()
m.maps[i] = rhhNew[V](m.cap / m.shards)
m.maps[i] = rhhNew[K, V](m.cap / m.shards)
m.mus[i].Unlock()
}
}

// Set assigns a value to a key.
// Returns the previous value, or false when no value was assigned.
func (m *Map[V]) Set(key string, value V) (prev V, replaced bool) {
m.initDo()
xxh := xxh3.HashString(key)
func (m *Map[K, V]) Set(key K, value V) (prev V, replaced bool) {
if atomic.LoadUint32(&m.inited) == 0 {
m.initDo()
}
var xxh uint64
if m.kstr {
xxh = xxh3.HashString(*(*string)(unsafe.Pointer(&key)))
} else {
xxh = xxh3.HashString(*(*string)(unsafe.Pointer(&struct {
data unsafe.Pointer
len int
}{unsafe.Pointer(&key), m.ksize})))
}
shard := int(xxh & uint64(m.shards-1))
m.mus[shard].Lock()
prev, replaced = m.maps[shard].Set(xxh, key, value)
Expand All @@ -52,12 +69,19 @@ func (m *Map[V]) Set(key string, value V) (prev V, replaced bool) {
// It's also provides a safe way to block other others from writing to the
// same shard while inspecting.
// Returns the previous value, or false when no value was assigned.
func (m *Map[V]) SetAccept(
key string, value V,
accept func(prev V, replaced bool) bool,
) (prev V, replaced bool) {
m.initDo()
xxh := xxh3.HashString(key)
func (m *Map[K, V]) SetAccept(key K, value V, accept func(prev V, replaced bool) bool) (prev V, replaced bool) {
if atomic.LoadUint32(&m.inited) == 0 {
m.initDo()
}
var xxh uint64
if m.kstr {
xxh = xxh3.HashString(*(*string)(unsafe.Pointer(&key)))
} else {
xxh = xxh3.HashString(*(*string)(unsafe.Pointer(&struct {
data unsafe.Pointer
len int
}{unsafe.Pointer(&key), m.ksize})))
}
shard := int(xxh & uint64(m.shards-1))
m.mus[shard].Lock()
defer m.mus[shard].Unlock()
Expand All @@ -80,9 +104,19 @@ func (m *Map[V]) SetAccept(

// Get returns a value for a key.
// Returns false when no value has been assign for key.
func (m *Map[V]) Get(key string) (value V, ok bool) {
m.initDo()
xxh := xxh3.HashString(key)
func (m *Map[K, V]) Get(key K) (value V, ok bool) {
if atomic.LoadUint32(&m.inited) == 0 {
m.initDo()
}
var xxh uint64
if m.kstr {
xxh = xxh3.HashString(*(*string)(unsafe.Pointer(&key)))
} else {
xxh = xxh3.HashString(*(*string)(unsafe.Pointer(&struct {
data unsafe.Pointer
len int
}{unsafe.Pointer(&key), m.ksize})))
}
shard := int(xxh & uint64(m.shards-1))
m.mus[shard].RLock()
value, ok = m.maps[shard].Get(xxh, key)
Expand All @@ -92,9 +126,19 @@ func (m *Map[V]) Get(key string) (value V, ok bool) {

// Delete deletes a value for a key.
// Returns the deleted value, or false when no value was assigned.
func (m *Map[V]) Delete(key string) (prev V, deleted bool) {
m.initDo()
xxh := xxh3.HashString(key)
func (m *Map[K, V]) Delete(key K) (prev V, deleted bool) {
if atomic.LoadUint32(&m.inited) == 0 {
m.initDo()
}
var xxh uint64
if m.kstr {
xxh = xxh3.HashString(*(*string)(unsafe.Pointer(&key)))
} else {
xxh = xxh3.HashString(*(*string)(unsafe.Pointer(&struct {
data unsafe.Pointer
len int
}{unsafe.Pointer(&key), m.ksize})))
}
shard := int(xxh & uint64(m.shards-1))
m.mus[shard].Lock()
prev, deleted = m.maps[shard].Delete(xxh, key)
Expand All @@ -107,12 +151,19 @@ func (m *Map[V]) Delete(key string) (prev V, deleted bool) {
// It's also provides a safe way to block other others from writing to the
// same shard while inspecting.
// Returns the deleted value, or false when no value was assigned.
func (m *Map[V]) DeleteAccept(
key string,
accept func(prev V, replaced bool) bool,
) (prev V, deleted bool) {
m.initDo()
xxh := xxh3.HashString(key)
func (m *Map[K, V]) DeleteAccept(key K, accept func(prev V, replaced bool) bool) (prev V, deleted bool) {
if atomic.LoadUint32(&m.inited) == 0 {
m.initDo()
}
var xxh uint64
if m.kstr {
xxh = xxh3.HashString(*(*string)(unsafe.Pointer(&key)))
} else {
xxh = xxh3.HashString(*(*string)(unsafe.Pointer(&struct {
data unsafe.Pointer
len int
}{unsafe.Pointer(&key), m.ksize})))
}
shard := int(xxh & uint64(m.shards-1))
m.mus[shard].Lock()
defer m.mus[shard].Unlock()
Expand All @@ -132,8 +183,10 @@ func (m *Map[V]) DeleteAccept(
}

// Len returns the number of values in map.
func (m *Map[V]) Len() int {
m.initDo()
func (m *Map[K, V]) Len() int {
if atomic.LoadUint32(&m.inited) == 0 {
m.initDo()
}
var len int
for i := 0; i < m.shards; i++ {
m.mus[i].Lock()
Expand All @@ -145,14 +198,16 @@ func (m *Map[V]) Len() int {

// Range iterates overall all key/values.
// It's not safe to call or Set or Delete while ranging.
func (m *Map[V]) Range(iter func(key string, value V) bool) {
m.initDo()
func (m *Map[K, V]) Range(iter func(key K, value V) bool) {
if atomic.LoadUint32(&m.inited) == 0 {
m.initDo()
}
var done bool
for i := 0; i < m.shards; i++ {
func(i int) {
m.mus[i].RLock()
defer m.mus[i].RUnlock()
m.maps[i].Range(func(key string, value V) bool {
m.maps[i].Range(func(key K, value V) bool {
if !iter(key, value) {
done = true
return false
Expand All @@ -166,17 +221,31 @@ func (m *Map[V]) Range(iter func(key string, value V) bool) {
}
}

func (m *Map[V]) initDo() {
m.init.Do(func() {
m.shards = 1
for m.shards < runtime.NumCPU()*16 {
m.shards *= 2
}
scap := m.cap / m.shards
m.mus = make([]sync.RWMutex, m.shards)
m.maps = make([]*rhhMap[V], m.shards)
for i := 0; i < len(m.maps); i++ {
m.maps[i] = rhhNew[V](scap)
}
})
func (m *Map[K, V]) initDo() {
m.initmu.Lock()
defer m.initmu.Unlock()

if atomic.LoadUint32(&m.inited) != 0 {
return
}
defer atomic.StoreUint32(&m.inited, 1)

m.shards = 1
for m.shards < runtime.NumCPU()*16 {
m.shards *= 2
}
scap := m.cap / m.shards
m.mus = make([]sync.RWMutex, m.shards)
m.maps = make([]*rhhMap[K, V], m.shards)
for i := 0; i < len(m.maps); i++ {
m.maps[i] = rhhNew[K, V](scap)
}

var k K
switch ((any)(k)).(type) {
case string:
m.kstr = true
default:
m.ksize = int(unsafe.Sizeof(k))
}
}
14 changes: 7 additions & 7 deletions map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@ func TestRandomData(t *testing.T) {
start := time.Now()
for time.Since(start) < time.Second*2 {
nums := random(N, true)
var m *Map[string]
var m *Map[string, string]
switch rand.Int() % 5 {
default:
m = New[string](N / ((rand.Int() % 3) + 1))
m = New[string, string](N / ((rand.Int() % 3) + 1))
case 1:
m = new(Map[string])
m = new(Map[string, string])
case 2:
m = New[string](0)
m = New[string, string](0)
}
v, ok := m.Get(k(999))
if ok || v != "" {
Expand Down Expand Up @@ -183,7 +183,7 @@ func TestRandomData(t *testing.T) {
}

func TestSetAccept(t *testing.T) {
var m Map[string]
var m Map[string, string]
m.Set("hello", "world")
prev, replaced := m.SetAccept("hello", "planet", nil)
if !replaced {
Expand Down Expand Up @@ -247,7 +247,7 @@ func TestSetAccept(t *testing.T) {
}

func TestDeleteAccept(t *testing.T) {
var m Map[string]
var m Map[string, string]
m.Set("hello", "world")
prev, deleted := m.DeleteAccept("hello", nil)
if !deleted {
Expand Down Expand Up @@ -299,7 +299,7 @@ func TestDeleteAccept(t *testing.T) {
}

func TestClear(t *testing.T) {
var m Map[int]
var m Map[string, int]
for i := 0; i < 1000; i++ {
m.Set(fmt.Sprintf("%d", i), i)
}
Expand Down
Loading

0 comments on commit 79bda29

Please sign in to comment.