From 94a8f37844da65d9dfa896884eaa8231d954575c Mon Sep 17 00:00:00 2001 From: Alex Dvoretsky Date: Sun, 7 Apr 2024 15:35:27 +0300 Subject: [PATCH 1/3] avoid parallel running a load function for the same key #242 --- lib/cache/loadable.go | 56 ++++++++++++++++++++++++++++++-------- lib/cache/loadable_test.go | 45 ++++++++++++++++++++++++------ lib/go.mod | 1 + lib/go.sum | 2 ++ 4 files changed, 83 insertions(+), 21 deletions(-) diff --git a/lib/cache/loadable.go b/lib/cache/loadable.go index 2612376a..a339b177 100644 --- a/lib/cache/loadable.go +++ b/lib/cache/loadable.go @@ -2,9 +2,12 @@ package cache import ( "context" + "errors" + "fmt" "sync" "github.com/eko/gocache/lib/v4/store" + "golang.org/x/sync/singleflight" ) const ( @@ -21,19 +24,21 @@ type LoadFunction[T any] func(ctx context.Context, key any) (T, error) // LoadableCache represents a cache that uses a function to load data type LoadableCache[T any] struct { - loadFunc LoadFunction[T] - cache CacheInterface[T] - setChannel chan *loadableKeyValue[T] - setterWg *sync.WaitGroup + singleFlight singleflight.Group + loadFunc LoadFunction[T] + cache CacheInterface[T] + setChannel chan *loadableKeyValue[T] + setterWg *sync.WaitGroup } -// NewLoadable instanciates a new cache that uses a function to load data +// NewLoadable instantiates a new cache that uses a function to load data func NewLoadable[T any](loadFunc LoadFunction[T], cache CacheInterface[T]) *LoadableCache[T] { loadable := &LoadableCache[T]{ - loadFunc: loadFunc, - cache: cache, - setChannel: make(chan *loadableKeyValue[T], 10000), - setterWg: &sync.WaitGroup{}, + singleFlight: singleflight.Group{}, + loadFunc: loadFunc, + cache: cache, + setChannel: make(chan *loadableKeyValue[T], 10000), + setterWg: &sync.WaitGroup{}, } loadable.setterWg.Add(1) @@ -47,6 +52,7 @@ func (c *LoadableCache[T]) setter() { for item := range c.setChannel { c.Set(context.Background(), item.key, item.value) + c.singleFlight.Forget(c.getCacheKey(item.key)) } } @@ -60,9 +66,21 @@ func (c *LoadableCache[T]) Get(ctx context.Context, key any) (T, error) { } // Unable to find in cache, try to load it from load function - object, err = c.loadFunc(ctx, key) - if err != nil { - return object, err + var r any + if r, err, _ = c.singleFlight.Do( + c.getCacheKey(key), + func() (any, error) { + return c.loadFunc(ctx, key) + }, + ); err != nil { + return *new(T), err + } + var ok bool + if object, ok = r.(T); !ok { + zero := *new(T) + return zero, errors.New( + fmt.Sprintf("returned value can't be cast to %T", zero), + ) } // Then, put it back in cache @@ -102,3 +120,17 @@ func (c *LoadableCache[T]) Close() error { return nil } + +// getCacheKey returns the cache key for the given key object by returning +// the key if type is string or by computing a checksum of key structure +// if its type is other than string +func (c *LoadableCache[T]) getCacheKey(key any) string { + switch v := key.(type) { + case string: + return v + case CacheKeyGenerator: + return v.GetCacheKey() + default: + return checksum(key) + } +} diff --git a/lib/cache/loadable_test.go b/lib/cache/loadable_test.go index 9e0e4870..b7c535cc 100644 --- a/lib/cache/loadable_test.go +++ b/lib/cache/loadable_test.go @@ -3,6 +3,8 @@ package cache import ( "context" "errors" + "sync" + "sync/atomic" "testing" "time" @@ -98,25 +100,50 @@ func TestLoadableGetWhenAvailableInLoadFunc(t *testing.T) { // Cache 1 cache1 := NewMockSetterCacheInterface[any](ctrl) cache1.EXPECT().Get(ctx, "my-key").Return(nil, errors.New("unable to find in cache 1")) + cache1.EXPECT().Get(ctx, "my-key").Return(nil, errors.New("unable to find in cache 1")) + cache1.EXPECT().Get(ctx, "my-key").Return(nil, errors.New("unable to find in cache 1")) cache1.EXPECT().Set(ctx, "my-key", cacheValue).AnyTimes().Return(nil) + var loadCallCount int32 + pauseLoadFn := make(chan struct{}) + loadFunc := func(_ context.Context, key any) (any, error) { + atomic.AddInt32(&loadCallCount, 1) + <-pauseLoadFn + time.Sleep(time.Millisecond * 10) return cacheValue, nil } cache := NewLoadable[any](loadFunc, cache1) - // When - value, err := cache.Get(ctx, "my-key") - - // Wait for data to be processed - for len(cache.setChannel) > 0 { - time.Sleep(1 * time.Millisecond) + const numRequests = 3 + var started sync.WaitGroup + started.Add(numRequests) + var finished sync.WaitGroup + finished.Add(numRequests) + for i := 0; i < numRequests; i++ { + go func() { + defer finished.Done() + started.Done() + // When + value, err := cache.Get(ctx, "my-key") + + // Wait for data to be processed + for len(cache.setChannel) > 0 { + time.Sleep(1 * time.Millisecond) + } + + // Then + assert.Nil(t, err) + assert.Equal(t, cacheValue, value) + }() } - // Then - assert.Nil(t, err) - assert.Equal(t, cacheValue, value) + started.Wait() + close(pauseLoadFn) + finished.Wait() + + assert.Equal(t, int32(1), loadCallCount) } func TestLoadableDelete(t *testing.T) { diff --git a/lib/go.mod b/lib/go.mod index ea5a4667..bf24f67d 100644 --- a/lib/go.mod +++ b/lib/go.mod @@ -8,6 +8,7 @@ require ( github.com/stretchr/testify v1.8.1 github.com/vmihailenco/msgpack/v5 v5.3.5 golang.org/x/exp v0.0.0-20221126150942-6ab00d035af9 + golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f ) require ( diff --git a/lib/go.sum b/lib/go.sum index 8dd57fa9..846614ad 100644 --- a/lib/go.sum +++ b/lib/go.sum @@ -304,6 +304,8 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f h1:Ax0t5p6N38Ga0dThY21weqDEyz2oklo4IvDkpigvkD8= +golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= From 96952063fcab58bc2aaeb17859e2da63677b6e91 Mon Sep 17 00:00:00 2001 From: Alex Dvoretsky Date: Tue, 16 Apr 2024 11:45:55 +0300 Subject: [PATCH 2/3] better var name and cleanup Co-authored-by: Vincent Composieux --- lib/cache/loadable.go | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/lib/cache/loadable.go b/lib/cache/loadable.go index a339b177..13404d83 100644 --- a/lib/cache/loadable.go +++ b/lib/cache/loadable.go @@ -66,18 +66,21 @@ func (c *LoadableCache[T]) Get(ctx context.Context, key any) (T, error) { } // Unable to find in cache, try to load it from load function - var r any - if r, err, _ = c.singleFlight.Do( - c.getCacheKey(key), + cacheKey := c.getCacheKey(key) + zero := *new(T) + + loadedResult, err, _ := c.singleFlight.Do( + cacheKey, func() (any, error) { return c.loadFunc(ctx, key) }, - ); err != nil { - return *new(T), err + ) + if err != nil { + return zero, err } + var ok bool - if object, ok = r.(T); !ok { - zero := *new(T) + if object, ok = loadedResult.(T); !ok { return zero, errors.New( fmt.Sprintf("returned value can't be cast to %T", zero), ) From dcc921ea60b3ef5164c629c519a374b2563c9f5c Mon Sep 17 00:00:00 2001 From: Alex Dvoretsky Date: Tue, 16 Apr 2024 11:47:36 +0300 Subject: [PATCH 3/3] cleanup Co-authored-by: Vincent Composieux --- lib/cache/loadable.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/cache/loadable.go b/lib/cache/loadable.go index 13404d83..e2e642eb 100644 --- a/lib/cache/loadable.go +++ b/lib/cache/loadable.go @@ -52,7 +52,9 @@ func (c *LoadableCache[T]) setter() { for item := range c.setChannel { c.Set(context.Background(), item.key, item.value) - c.singleFlight.Forget(c.getCacheKey(item.key)) + + cacheKey := c.getCacheKey(item.key) + c.singleFlight.Forget(cacheKey) } }