diff --git a/cache.go b/cache.go index 0f6d4163..7343360b 100644 --- a/cache.go +++ b/cache.go @@ -147,6 +147,16 @@ type Config[K Key, V any] struct { // as well as on rejection of the value. OnExit func(val V) + // ShouldUpdate is called when a value already exists in cache and is being updated. + // If ShouldUpdate returns true, the cache continues with the update (Set). If the + // function returns false, no changes are made in the cache. If the value doesn't + // already exist, the cache continue with setting that value for the given key. + // + // In this function, you can check whether the new value is valid. For example, if + // your value has timestamp assosicated with it, you could check whether the new + // value has the latest timestamp, preventing you from setting an older value. + ShouldUpdate func(cur, prev V) bool + // KeyToHash function is used to customize the key hashing algorithm. // Each key will be hashed using the provided function. If keyToHash value // is not set, the default keyToHash function is used. @@ -231,6 +241,7 @@ func NewCache[K Key, V any](config *Config[K, V]) (*Cache[K, V], error) { ignoreInternalCost: config.IgnoreInternalCost, cleanupTicker: time.NewTicker(time.Duration(config.TtlTickerDurationInSec) * time.Second / 2), } + cache.storedItems.SetShouldUpdateFn(config.ShouldUpdate) cache.onExit = func(val V) { if config.OnExit != nil { config.OnExit(val) diff --git a/store.go b/store.go index dfcfaa8d..8c4fe4c7 100644 --- a/store.go +++ b/store.go @@ -21,6 +21,8 @@ import ( "time" ) +type updateFn[V any] func(cur, prev V) bool + // TODO: Do we need this to be a separate struct from Item? type storeItem[V any] struct { key uint64 @@ -53,6 +55,7 @@ type store[V any] interface { Cleanup(policy *defaultPolicy[V], onEvict func(item *Item[V])) // Clear clears all contents of the store. Clear(onEvict func(item *Item[V])) + SetShouldUpdateFn(f updateFn[V]) } // newStore returns the default store implementation. @@ -78,6 +81,12 @@ func newShardedMap[V any]() *shardedMap[V] { return sm } +func (m *shardedMap[V]) SetShouldUpdateFn(f updateFn[V]) { + for i := range m.shards { + m.shards[i].setShouldUpdateFn(f) + } +} + func (sm *shardedMap[V]) Get(key, conflict uint64) (V, bool) { return sm.shards[key%numShards].get(key, conflict) } @@ -116,17 +125,25 @@ func (sm *shardedMap[V]) Clear(onEvict func(item *Item[V])) { type lockedMap[V any] struct { sync.RWMutex - data map[uint64]storeItem[V] - em *expirationMap[V] + data map[uint64]storeItem[V] + em *expirationMap[V] + shouldUpdate updateFn[V] } func newLockedMap[V any](em *expirationMap[V]) *lockedMap[V] { return &lockedMap[V]{ data: make(map[uint64]storeItem[V]), em: em, + shouldUpdate: func(cur, prev V) bool { + return true + }, } } +func (m *lockedMap[V]) setShouldUpdateFn(f updateFn[V]) { + m.shouldUpdate = f +} + func (m *lockedMap[V]) get(key, conflict uint64) (V, bool) { m.RLock() item, ok := m.data[key] @@ -167,6 +184,9 @@ func (m *lockedMap[V]) Set(i *Item[V]) { if i.Conflict != 0 && (i.Conflict != item.conflict) { return } + if m.shouldUpdate != nil && !m.shouldUpdate(i.Value, item.value) { + return + } m.em.update(i.Key, i.Conflict, item.expiration, i.Expiration) } else { // The value is not in the map already. There's no need to return anything. @@ -205,15 +225,17 @@ func (m *lockedMap[V]) Del(key, conflict uint64) (uint64, V) { func (m *lockedMap[V]) Update(newItem *Item[V]) (V, bool) { m.Lock() + defer m.Unlock() item, ok := m.data[newItem.Key] if !ok { - m.Unlock() return zeroValue[V](), false } if newItem.Conflict != 0 && (newItem.Conflict != item.conflict) { - m.Unlock() return zeroValue[V](), false } + if m.shouldUpdate != nil && !m.shouldUpdate(newItem.Value, item.value) { + return item.value, false + } m.em.update(newItem.Key, newItem.Conflict, item.expiration, newItem.Expiration) m.data[newItem.Key] = storeItem[V]{ @@ -223,7 +245,6 @@ func (m *lockedMap[V]) Update(newItem *Item[V]) (V, bool) { expiration: newItem.Expiration, } - m.Unlock() return item.value, true } diff --git a/store_test.go b/store_test.go index 2cc99004..9ffa1bc2 100644 --- a/store_test.go +++ b/store_test.go @@ -76,6 +76,29 @@ func TestStoreClear(t *testing.T) { } } +func TestShouldUpdate(t *testing.T) { + // Create a should update function where the value only increases. + s := newStore[int]() + s.SetShouldUpdateFn(func(cur, prev int) bool { + return cur > prev + }) + + key, conflict := z.KeyToHash(1) + i := Item[int]{ + Key: key, + Conflict: conflict, + Value: 2, + } + s.Set(&i) + i.Value = 1 + _, ok := s.Update(&i) + require.False(t, ok) + + i.Value = 3 + _, ok = s.Update(&i) + require.True(t, ok) +} + func TestStoreUpdate(t *testing.T) { s := newStore[int]() key, conflict := z.KeyToHash(1)