From 846b2322976cff49d153fcabe656a949eb0bb47b Mon Sep 17 00:00:00 2001 From: Lianbo Date: Thu, 28 Nov 2024 17:23:02 +0800 Subject: [PATCH] fix invalid cache length when set with zero value key (#24) * fix invalid cache length when set with zero value key * Update ttl_shard.go Co-authored-by: ccoVeille <3875889+ccoVeille@users.noreply.github.com> --------- Co-authored-by: ccoVeille <3875889+ccoVeille@users.noreply.github.com> --- lru_shard.go | 4 +++- lru_shard_test.go | 39 +++++++++++++++++++++++++++++++++++++++ ttl_cache_test.go | 31 +++++++++++++++++++++++++++++++ ttl_shard.go | 4 +++- 4 files changed, 76 insertions(+), 2 deletions(-) diff --git a/lru_shard.go b/lru_shard.go index 4dbbce2..f6f8333 100644 --- a/lru_shard.go +++ b/lru_shard.go @@ -131,7 +131,9 @@ func (s *lrushard[K, V]) Set(hash uint32, key K, value V) (prev V, replaced bool index := s.list[0].prev node := (*lrunode[K, V])(unsafe.Add(unsafe.Pointer(&s.list[0]), uintptr(index)*unsafe.Sizeof(s.list[0]))) evictedValue := node.value - if key != node.key { + + // delete the old key if the list is full, note that the list length is size+1 + if uint32(len(s.list)-1) < s.tableLength+1 && key != node.key { s.tableDelete(uint32(s.tableHasher(noescape(unsafe.Pointer(&node.key)), s.tableSeed)), node.key) } diff --git a/lru_shard_test.go b/lru_shard_test.go index 650b0bd..de30e70 100644 --- a/lru_shard_test.go +++ b/lru_shard_test.go @@ -1,7 +1,9 @@ package lru import ( + "fmt" "testing" + "time" "unsafe" ) @@ -41,3 +43,40 @@ func TestLRUShardTableSet(t *testing.T) { t.Errorf("foobar should be set to 42: %v %v", i, ok) } } + +func TestLRUCacheLengthWithZeroValue(t *testing.T) { + cache := NewTTLCache[string, string](128, WithShards[string, string](1)) + + cache.Set("", "", time.Hour) + cache.Set("1", "1", time.Hour) + + if got, want := cache.Len(), 2; got != want { + t.Fatalf("curent cache length %v should be %v", got, want) + } + + for i := 2; i < 128; i++ { + k := fmt.Sprintf("%d", i) + if _, replace := cache.Set(k, k, time.Hour); replace { + t.Fatalf("key %v should not be replaced", k) + } + } + + if l := cache.Len(); l != 128 { + t.Fatalf("cache length %v should be 128", l) + } + + for i := 128; i < 256; i++ { + k := fmt.Sprintf("%d", i) + v := "" + if i-128 > 0 { + v = fmt.Sprintf("%d", i-128) + } + if prev, _ := cache.Set(k, k, time.Hour); prev != v { + t.Fatalf("value %v should be evicted", prev) + } + } + + if l := cache.Len(); l != 128 { + t.Fatalf("cache length %v should be 128", l) + } +} diff --git a/ttl_cache_test.go b/ttl_cache_test.go index 0b78312..f487d04 100644 --- a/ttl_cache_test.go +++ b/ttl_cache_test.go @@ -73,6 +73,37 @@ func TestTTLCacheGetSet(t *testing.T) { } } +func TestTTLCacheLengthWithZeroValue(t *testing.T) { + cache := NewTTLCache[int, int](128, WithShards[int, int](1)) + + cache.Set(0, 0, time.Hour) + cache.Set(1, 1, time.Hour) + + if got, want := cache.Len(), 2; got != want { + t.Fatalf("curent cache length %v should be %v", got, want) + } + + for i := 2; i < 128; i++ { + if _, replace := cache.Set(i, i, time.Hour); replace { + t.Fatalf("no value should be replaced") + } + } + + if l := cache.Len(); l != 128 { + t.Fatalf("cache length %v should be 128", l) + } + + for i := 128; i < 256; i++ { + if prev, _ := cache.Set(i, i, time.Hour); prev != i-128 { + t.Fatalf("value %v should be evicted", prev) + } + } + + if l := cache.Len(); l != 128 { + t.Fatalf("cache length %v should be 128", l) + } +} + func TestTTLCacheSetIfAbsent(t *testing.T) { cache := NewTTLCache[int, int](128) diff --git a/ttl_shard.go b/ttl_shard.go index 9a55201..b6ecb87 100644 --- a/ttl_shard.go +++ b/ttl_shard.go @@ -183,7 +183,9 @@ func (s *ttlshard[K, V]) Set(hash uint32, key K, value V, ttl time.Duration) (pr index := s.list[0].prev node := (*ttlnode[K, V])(unsafe.Add(unsafe.Pointer(&s.list[0]), uintptr(index)*unsafe.Sizeof(s.list[0]))) evictedValue := node.value - if key != node.key { + + // delete the old key if the list is full, note that the list length is size+1 + if len(s.list)-1 < int(s.tableLength+1) && key != node.key { s.tableDelete(uint32(s.tableHasher(noescape(unsafe.Pointer(&node.key)), s.tableSeed)), node.key) }