From 0991f9bd3aa1eaeefe9a01a25fb9a1da9355d4bb Mon Sep 17 00:00:00 2001
From: Vlasashk <79715336+vlasashk@users.noreply.github.com>
Date: Thu, 29 Aug 2024 12:15:47 +0300
Subject: [PATCH] Use `defer` to unlock mutex (#66)

* FIX: deadlock fix - mutex unlock via defer in GetOrSet

* FIX: move all mutex Unlocks into defer

* FIX: remove unnecessary new lines
---
 cleaner.go |  6 ++---
 imcache.go | 72 ++++++++++++------------------------------------------
 2 files changed, 18 insertions(+), 60 deletions(-)

diff --git a/cleaner.go b/cleaner.go
index 019839d..1122624 100644
--- a/cleaner.go
+++ b/cleaner.go
@@ -30,14 +30,13 @@ func (c *cleaner) start(r eremover, interval time.Duration) error {
 		return errors.New("imcache: interval must be greater than 0")
 	}
 	c.mu.Lock()
+	defer c.mu.Unlock()
 	if c.running {
-		c.mu.Unlock()
 		return errors.New("imcache: cleaner already running")
 	}
 	c.running = true
 	c.stopCh = make(chan struct{})
 	c.doneCh = make(chan struct{})
-	c.mu.Unlock()
 	go func() {
 		ticker := time.NewTicker(interval)
 		defer ticker.Stop()
@@ -56,13 +55,12 @@ func (c *cleaner) start(r eremover, interval time.Duration) error {
 
 func (c *cleaner) stop() {
 	c.mu.Lock()
+	defer c.mu.Unlock()
 	if !c.running {
-		c.mu.Unlock()
 		return
 	}
 	c.running = false
 	close(c.stopCh)
 	// Wait for the cleaner goroutine to stop while holding the lock.
 	<-c.doneCh
-	c.mu.Unlock()
 }
diff --git a/imcache.go b/imcache.go
index fb592f3..29a0a25 100644
--- a/imcache.go
+++ b/imcache.go
@@ -99,20 +99,18 @@ func (c *Cache[K, V]) Get(key K) (V, bool) {
 	now := time.Now()
 	var zero V
 	c.mu.Lock()
+	defer c.mu.Unlock()
 	if c.closed {
-		c.mu.Unlock()
 		return zero, false
 	}
 	node, ok := c.m[key]
 	if !ok {
-		c.mu.Unlock()
 		return zero, false
 	}
 	entry := node.entry()
 	if entry.expired(now) {
 		c.queue.remove(node)
 		delete(c.m, key)
-		c.mu.Unlock()
 		if c.onEviction != nil {
 			go c.onEviction(key, entry.val, EvictionReasonExpired)
 		}
@@ -121,7 +119,6 @@ func (c *Cache[K, V]) Get(key K) (V, bool) {
 	entry.slide(now)
 	node.setEntry(entry)
 	c.queue.touch(node)
-	c.mu.Unlock()
 	return entry.val, true
 }
 
@@ -136,8 +133,8 @@ func (c *Cache[K, V]) GetMultiple(keys ...K) map[K]V {
 
 func (c *Cache[K, V]) getMultiple(now time.Time, keys ...K) map[K]V {
 	c.mu.Lock()
+	defer c.mu.Unlock()
 	if c.closed {
-		c.mu.Unlock()
 		return nil
 	}
 	got := make(map[K]V, len(keys))
@@ -159,7 +156,6 @@ func (c *Cache[K, V]) getMultiple(now time.Time, keys ...K) map[K]V {
 			c.queue.touch(node)
 			got[key] = entry.val
 		}
-		c.mu.Unlock()
 		return got
 	}
 	var expired []entry[K, V]
@@ -180,7 +176,6 @@ func (c *Cache[K, V]) getMultiple(now time.Time, keys ...K) map[K]V {
 		c.queue.touch(node)
 		got[key] = entry.val
 	}
-	c.mu.Unlock()
 	go func() {
 		for _, entry := range expired {
 			c.onEviction(entry.key, entry.val, EvictionReasonExpired)
@@ -198,8 +193,8 @@ func (c *Cache[K, V]) GetAll() map[K]V {
 
 func (c *Cache[K, V]) getAll(now time.Time) map[K]V {
 	c.mu.Lock()
+	defer c.mu.Unlock()
 	if c.closed {
-		c.mu.Unlock()
 		return nil
 	}
 	// To avoid copying the expired entries if there's no eviction callback.
@@ -217,7 +212,6 @@ func (c *Cache[K, V]) getAll(now time.Time) map[K]V {
 			got[key] = entry.val
 		}
 		c.queue.touchall()
-		c.mu.Unlock()
 		return got
 	}
 	var expired []entry[K, V]
@@ -235,7 +229,6 @@ func (c *Cache[K, V]) getAll(now time.Time) map[K]V {
 		got[key] = entry.val
 	}
 	c.queue.touchall()
-	c.mu.Unlock()
 	if len(expired) != 0 {
 		go func() {
 			for _, kv := range expired {
@@ -333,8 +326,8 @@ func (c *Cache[K, V]) peekAll(now time.Time) map[K]V {
 func (c *Cache[K, V]) Set(key K, val V, exp Expiration) {
 	now := time.Now()
 	c.mu.Lock()
+	defer c.mu.Unlock()
 	if c.closed {
-		c.mu.Unlock()
 		return
 	}
 	// Make sure that the shard is initialized.
@@ -346,7 +339,6 @@ func (c *Cache[K, V]) Set(key K, val V, exp Expiration) {
 		if !currentEntry.expired(now) {
 			currentNode.setEntry(newEntry)
 			c.queue.touch(currentNode)
-			c.mu.Unlock()
 			if c.onEviction != nil {
 				go c.onEviction(key, currentEntry.val, EvictionReasonReplaced)
 			}
@@ -354,7 +346,6 @@ func (c *Cache[K, V]) Set(key K, val V, exp Expiration) {
 		}
 		c.m[key] = c.queue.add(newEntry)
 		c.queue.remove(currentNode)
-		c.mu.Unlock()
 		if c.onEviction != nil {
 			go c.onEviction(key, currentEntry.val, EvictionReasonExpired)
 		}
@@ -366,7 +357,6 @@ func (c *Cache[K, V]) Set(key K, val V, exp Expiration) {
 		delete(c.m, evictedNode.entry().key)
 	}
 	c.m[key] = c.queue.add(entry[K, V]{key: key, val: val, exp: exp.new(now, c.defaultExp, c.sliding)})
-	c.mu.Unlock()
 	if c.onEviction == nil || evictedNode == nil {
 		return
 	}
@@ -386,8 +376,8 @@ func (c *Cache[K, V]) Set(key K, val V, exp Expiration) {
 func (c *Cache[K, V]) GetOrSet(key K, val V, exp Expiration) (value V, present bool) {
 	now := time.Now()
 	c.mu.Lock()
+	defer c.mu.Unlock()
 	if c.closed {
-		c.mu.Unlock()
 		var zero V
 		return zero, false
 	}
@@ -399,7 +389,6 @@ func (c *Cache[K, V]) GetOrSet(key K, val V, exp Expiration) (value V, present b
 		currentEntry.slide(now)
 		currentNode.setEntry(currentEntry)
 		c.queue.touch(currentNode)
-		c.mu.Unlock()
 		return currentEntry.val, true
 	}
 	if ok {
@@ -416,7 +405,6 @@ func (c *Cache[K, V]) GetOrSet(key K, val V, exp Expiration) (value V, present b
 		delete(c.m, evictedNode.entry().key)
 	}
 	c.m[key] = c.queue.add(entry[K, V]{key: key, val: val, exp: exp.new(now, c.defaultExp, c.sliding)})
-	c.mu.Unlock()
 	if c.onEviction == nil || evictedNode == nil {
 		return val, false
 	}
@@ -441,20 +429,18 @@ func (c *Cache[K, V]) GetOrSet(key K, val V, exp Expiration) (value V, present b
 func (c *Cache[K, V]) Replace(key K, val V, exp Expiration) (present bool) {
 	now := time.Now()
 	c.mu.Lock()
+	defer c.mu.Unlock()
 	if c.closed {
-		c.mu.Unlock()
 		return false
 	}
 	currentNode, ok := c.m[key]
 	if !ok {
-		c.mu.Unlock()
 		return false
 	}
 	currentEntry := currentNode.entry()
 	if currentEntry.expired(now) {
 		c.queue.remove(currentNode)
 		delete(c.m, key)
-		c.mu.Unlock()
 		if c.onEviction != nil {
 			go c.onEviction(key, currentEntry.val, EvictionReasonExpired)
 		}
@@ -462,7 +448,6 @@ func (c *Cache[K, V]) Replace(key K, val V, exp Expiration) (present bool) {
 	}
 	currentNode.setEntry(entry[K, V]{key: key, val: val, exp: exp.new(now, c.defaultExp, c.sliding)})
 	c.queue.touch(currentNode)
-	c.mu.Unlock()
 	if c.onEviction != nil {
 		go c.onEviction(key, currentEntry.val, EvictionReasonReplaced)
 	}
@@ -518,20 +503,18 @@ func Decrement[V Number](old V) V {
 func (c *Cache[K, V]) ReplaceWithFunc(key K, f func(current V) (new V), exp Expiration) (present bool) {
 	now := time.Now()
 	c.mu.Lock()
+	defer c.mu.Unlock()
 	if c.closed {
-		c.mu.Unlock()
 		return false
 	}
 	currentNode, ok := c.m[key]
 	if !ok {
-		c.mu.Unlock()
 		return false
 	}
 	currentEntry := currentNode.entry()
 	if currentEntry.expired(now) {
 		c.queue.remove(currentNode)
 		delete(c.m, key)
-		c.mu.Unlock()
 		if c.onEviction != nil {
 			go c.onEviction(key, currentEntry.val, EvictionReasonExpired)
 		}
@@ -539,7 +522,6 @@ func (c *Cache[K, V]) ReplaceWithFunc(key K, f func(current V) (new V), exp Expi
 	}
 	currentNode.setEntry(entry[K, V]{key: key, val: f(currentEntry.val), exp: exp.new(now, c.defaultExp, c.sliding)})
 	c.queue.touch(currentNode)
-	c.mu.Unlock()
 	if c.onEviction != nil {
 		go c.onEviction(key, currentEntry.val, EvictionReasonReplaced)
 	}
@@ -555,20 +537,18 @@ func (c *Cache[K, V]) ReplaceWithFunc(key K, f func(current V) (new V), exp Expi
 func (c *Cache[K, V]) ReplaceKey(old, new K, exp Expiration) (present bool) {
 	now := time.Now()
 	c.mu.Lock()
+	defer c.mu.Unlock()
 	if c.closed {
-		c.mu.Unlock()
 		return false
 	}
 	oldNode, ok := c.m[old]
 	if !ok {
-		c.mu.Unlock()
 		return false
 	}
 	oldEntry := oldNode.entry()
 	delete(c.m, old)
 	c.queue.remove(oldNode)
 	if oldEntry.expired(now) {
-		c.mu.Unlock()
 		if c.onEviction != nil {
 			go c.onEviction(old, oldEntry.val, EvictionReasonExpired)
 		}
@@ -578,7 +558,6 @@ func (c *Cache[K, V]) ReplaceKey(old, new K, exp Expiration) (present bool) {
 	currentNode, ok := c.m[new]
 	if !ok {
 		c.m[new] = c.queue.add(newEntry)
-		c.mu.Unlock()
 		if c.onEviction != nil {
 			go c.onEviction(old, oldEntry.val, EvictionReasonKeyReplaced)
 		}
@@ -587,7 +566,6 @@ func (c *Cache[K, V]) ReplaceKey(old, new K, exp Expiration) (present bool) {
 	currentEntry := currentNode.entry()
 	currentNode.setEntry(newEntry)
 	c.queue.touch(currentNode)
-	c.mu.Unlock()
 	if c.onEviction != nil {
 		go func() {
 			c.onEviction(old, oldEntry.val, EvictionReasonKeyReplaced)
@@ -610,20 +588,18 @@ func (c *Cache[K, V]) ReplaceKey(old, new K, exp Expiration) (present bool) {
 func (c *Cache[K, V]) CompareAndSwap(key K, expected, new V, compare func(V, V) bool, exp Expiration) (swapped, present bool) {
 	now := time.Now()
 	c.mu.Lock()
+	defer c.mu.Unlock()
 	if c.closed {
-		c.mu.Unlock()
 		return false, false
 	}
 	currentNode, ok := c.m[key]
 	if !ok {
-		c.mu.Unlock()
 		return false, false
 	}
 	currentEntry := currentNode.entry()
 	if currentEntry.expired(now) {
 		c.queue.remove(currentNode)
 		delete(c.m, key)
-		c.mu.Unlock()
 		if c.onEviction != nil {
 			go c.onEviction(key, currentEntry.val, EvictionReasonExpired)
 		}
@@ -631,12 +607,10 @@ func (c *Cache[K, V]) CompareAndSwap(key K, expected, new V, compare func(V, V)
 	}
 	if !compare(currentEntry.val, expected) {
 		c.queue.touch(currentNode)
-		c.mu.Unlock()
 		return false, true
 	}
 	currentNode.setEntry(entry[K, V]{key: key, val: new, exp: exp.new(now, c.defaultExp, c.sliding)})
 	c.queue.touch(currentNode)
-	c.mu.Unlock()
 	if c.onEviction != nil {
 		go c.onEviction(key, currentEntry.val, EvictionReasonReplaced)
 	}
@@ -654,19 +628,17 @@ func (c *Cache[K, V]) CompareAndSwap(key K, expected, new V, compare func(V, V)
 func (c *Cache[K, V]) Remove(key K) (present bool) {
 	now := time.Now()
 	c.mu.Lock()
+	defer c.mu.Unlock()
 	if c.closed {
-		c.mu.Unlock()
 		return false
 	}
 	currentNode, ok := c.m[key]
 	if !ok {
-		c.mu.Unlock()
 		return false
 	}
 	c.queue.remove(currentNode)
 	delete(c.m, key)
 	currentEntry := currentNode.entry()
-	c.mu.Unlock()
 	if c.onEviction == nil {
 		return !currentEntry.expired(now)
 	}
@@ -691,14 +663,13 @@ func (c *Cache[K, V]) RemoveAll() {
 
 func (c *Cache[K, V]) removeAll(now time.Time) {
 	c.mu.Lock()
+	defer c.mu.Unlock()
 	if c.closed {
-		c.mu.Unlock()
 		return
 	}
 	removed := c.m
 	c.m = make(map[K]node[K, V], c.limit)
 	c.queue = newEvictionQueue[K, V](c.limit, c.policy)
-	c.mu.Unlock()
 	if c.onEviction != nil && len(removed) != 0 {
 		go func() {
 			for key, node := range removed {
@@ -722,8 +693,8 @@ func (c *Cache[K, V]) RemoveExpired() {
 
 func (c *Cache[K, V]) removeExpired(now time.Time) {
 	c.mu.Lock()
+	defer c.mu.Unlock()
 	if c.closed {
-		c.mu.Unlock()
 		return
 	}
 	// To avoid copying the expired entries if there's no eviction callback.
@@ -735,7 +706,6 @@ func (c *Cache[K, V]) removeExpired(now time.Time) {
 				delete(c.m, key)
 			}
 		}
-		c.mu.Unlock()
 		return
 	}
 	var removed []entry[K, V]
@@ -747,7 +717,6 @@ func (c *Cache[K, V]) removeExpired(now time.Time) {
 			c.queue.remove(node)
 		}
 	}
-	c.mu.Unlock()
 	if len(removed) != 0 {
 		go func() {
 			for _, entry := range removed {
@@ -760,12 +729,11 @@ func (c *Cache[K, V]) removeExpired(now time.Time) {
 // Len returns the number of entries in the cache.
 func (c *Cache[K, V]) Len() int {
 	c.mu.Lock()
+	defer c.mu.Unlock()
 	if c.closed {
-		c.mu.Unlock()
 		return 0
 	}
 	n := c.len()
-	c.mu.Unlock()
 	return n
 }
 
@@ -785,13 +753,12 @@ func (c *Cache[K, V]) len() int {
 // and there is no cleaner running. Garbage collector will collect the cache.
 func (c *Cache[K, V]) Close() {
 	c.mu.Lock()
+	defer c.mu.Unlock()
 	if c.closed {
-		c.mu.Unlock()
 		return
 	}
 	c.m = nil
 	c.closed = true
-	c.mu.Unlock()
 	// If the cleaner is running, stop it.
 	// It's safe to access c.cleaner without a lock because
 	// it's only set during initialization and never modified.
@@ -1090,22 +1057,20 @@ func (s *Sharded[K, V]) ReplaceKey(old, new K, exp Expiration) (present bool) {
 		return oldShard.ReplaceKey(old, new, exp)
 	}
 	oldShard.mu.Lock()
+	defer oldShard.mu.Unlock()
 	// Check if the old shard is closed.
 	// If so it means that the Sharded is closed as well.
 	if oldShard.closed {
-		oldShard.mu.Unlock()
 		return false
 	}
 	oldShardNode, ok := oldShard.m[old]
 	if !ok {
-		oldShard.mu.Unlock()
 		return false
 	}
 	oldShard.queue.remove(oldShardNode)
 	delete(oldShard.m, old)
 	oldShardEntry := oldShardNode.entry()
 	if oldShardEntry.expired(now) {
-		oldShard.mu.Unlock()
 		if oldShard.onEviction != nil {
 			go oldShard.onEviction(old, oldShardEntry.val, EvictionReasonExpired)
 		}
@@ -1113,6 +1078,7 @@ func (s *Sharded[K, V]) ReplaceKey(old, new K, exp Expiration) (present bool) {
 	}
 	newEntry := entry[K, V]{key: new, val: oldShardEntry.val, exp: exp.new(now, oldShard.defaultExp, oldShard.sliding)}
 	newShard.mu.Lock()
+	defer newShard.mu.Unlock()
 	newShardNode, ok := newShard.m[new]
 	if ok {
 		newShardEntry := newShardNode.entry()
@@ -1126,8 +1092,6 @@ func (s *Sharded[K, V]) ReplaceKey(old, new K, exp Expiration) (present bool) {
 			newShard.queue.remove(newShardNode)
 			evictionReason = EvictionReasonExpired
 		}
-		oldShard.mu.Unlock()
-		newShard.mu.Unlock()
 		// Both eviction callbacks point to the same function.
 		if oldShard.onEviction != nil {
 			go func() {
@@ -1139,8 +1103,6 @@ func (s *Sharded[K, V]) ReplaceKey(old, new K, exp Expiration) (present bool) {
 	}
 	if newShard.limit == 0 || newShard.len() < newShard.limit {
 		newShard.m[new] = newShard.queue.add(newEntry)
-		oldShard.mu.Unlock()
-		newShard.mu.Unlock()
 		if oldShard.onEviction != nil {
 			go oldShard.onEviction(old, oldShardEntry.val, EvictionReasonKeyReplaced)
 		}
@@ -1150,8 +1112,6 @@ func (s *Sharded[K, V]) ReplaceKey(old, new K, exp Expiration) (present bool) {
 	evictedEntry := evictedNode.entry()
 	delete(newShard.m, evictedEntry.key)
 	newShard.m[new] = newShard.queue.add(newEntry)
-	oldShard.mu.Unlock()
-	newShard.mu.Unlock()
 	if oldShard.onEviction == nil {
 		return true
 	}