diff --git a/getfindmutex.go b/getfindmutex.go index 4efc659..af2963c 100644 --- a/getfindmutex.go +++ b/getfindmutex.go @@ -1,6 +1,9 @@ package discovery -import "sync" +import ( + "fmt" + "sync" +) // GetFindMutex A modified version of a RWMutex. Many get locks can be held but // only one Find lock. A waiting Find lock (even if it hasn't been locked, just @@ -11,28 +14,67 @@ import "sync" // requests, since it's likely that once FIND has been run, subsequent GET // requests will be able to be served from cache type GetFindMutex struct { - mutex sync.RWMutex + mutexMap map[string]*sync.RWMutex + mapLock sync.Mutex } // GetLock Gets a lock that can be held by an unlimited number of goroutines, -// these locks are only blocked by FindLocks -func (g *GetFindMutex) GetLock() { - g.mutex.RLock() +// these locks are only blocked by FindLocks. A type and context must be +// provided since a Get in one type (or context) should not be blocked by a Find +// in another +func (g *GetFindMutex) GetLock(itemContext string, typ string) { + g.mutexFor(itemContext, typ).RLock() } // GetUnlock Unlocks the GetLock. This must be called once for each GetLock // otherwise it will be impossible to ever obtain a FindLock -func (g *GetFindMutex) GetUnlock() { - g.mutex.RUnlock() +func (g *GetFindMutex) GetUnlock(itemContext string, typ string) { + g.mutexFor(itemContext, typ).RUnlock() } -// FindLock An exclusive lock. Ensure that all GetLocks have been unlocked -// and stops any more from being obtained -func (g *GetFindMutex) FindLock() { - g.mutex.Lock() +// FindLock An exclusive lock. Ensure that all GetLocks have been unlocked and +// stops any more from being obtained. Provide a type and context to ensure that +// the lock is only help for that type and context combination rather than +// locking the whole engine +func (g *GetFindMutex) FindLock(itemContext string, typ string) { + g.mutexFor(itemContext, typ).Lock() } // FindUnlock Unlocks a FindLock -func (g *GetFindMutex) FindUnlock() { - g.mutex.Unlock() +func (g *GetFindMutex) FindUnlock(itemContext string, typ string) { + g.mutexFor(itemContext, typ).Unlock() +} + +// mutexFor Returns the relevant RWMutex for a given context and type, creating +// and storing a new one if needed +func (g *GetFindMutex) mutexFor(itemContext string, typ string) *sync.RWMutex { + var mutex *sync.RWMutex + var ok bool + + keyName := g.keyName(itemContext, typ) + + g.mapLock.Lock() + defer g.mapLock.Unlock() + + // Create the map if needed + if g.mutexMap == nil { + g.mutexMap = make(map[string]*sync.RWMutex) + } + + // Get the mutex from storage + mutex, ok = g.mutexMap[keyName] + + // If the mutex wasn't found for this key, create a new one + if !ok { + mutex = &sync.RWMutex{} + g.mutexMap[keyName] = mutex + } + + return mutex +} + +// keyName Returns the name of the key for a given context and type combo for +// use with the mutexMap +func (g *GetFindMutex) keyName(itemContext string, typ string) string { + return fmt.Sprintf("%v.%v", itemContext, typ) } diff --git a/getfindmutex_test.go b/getfindmutex_test.go index cd3c139..73caab6 100644 --- a/getfindmutex_test.go +++ b/getfindmutex_test.go @@ -14,12 +14,38 @@ func TestGetLock(t *testing.T) { doneChan := make(chan bool) go func() { - gfm.GetLock() - gfm.GetLock() - gfm.GetLock() - gfm.GetUnlock() - gfm.GetUnlock() - gfm.GetUnlock() + gfm.GetLock("testContext", "testType") + gfm.GetLock("testContext", "testType") + gfm.GetLock("testContext", "testType") + gfm.GetUnlock("testContext", "testType") + gfm.GetUnlock("testContext", "testType") + gfm.GetUnlock("testContext", "testType") + doneChan <- true + }() + + select { + case <-ctx.Done(): + t.Error("Timeout") + case <-doneChan: + } + + cancel() + }) + + t.Run("many find locks from different types and contexts can be held at once", func(t *testing.T) { + var gfm GetFindMutex + ctx, cancel := context.WithTimeout(context.Background(), (1 * time.Second)) + doneChan := make(chan bool) + + go func() { + gfm.FindLock("testContext1", "testType1") + gfm.FindLock("testContext1", "testType2") + gfm.FindLock("testContext2", "testType") + gfm.FindLock("testContext3", "testType") + gfm.FindUnlock("testContext1", "testType1") + gfm.FindUnlock("testContext1", "testType2") + gfm.FindUnlock("testContext2", "testType") + gfm.FindUnlock("testContext3", "testType") doneChan <- true }() @@ -38,15 +64,15 @@ func TestGetLock(t *testing.T) { getChan := make(chan bool) findChan := make(chan bool) - gfm.FindLock() + gfm.FindLock("testContext", "testType") go func() { - gfm.GetLock() - gfm.GetLock() - gfm.GetLock() - gfm.GetUnlock() - gfm.GetUnlock() - gfm.GetUnlock() + gfm.GetLock("testContext", "testType") + gfm.GetLock("testContext", "testType") + gfm.GetLock("testContext", "testType") + gfm.GetUnlock("testContext", "testType") + gfm.GetUnlock("testContext", "testType") + gfm.GetUnlock("testContext", "testType") getChan <- true }() @@ -83,13 +109,13 @@ func TestGetLock(t *testing.T) { go func() { defer wg.Done() - gfm.GetLock() + gfm.GetLock("testContext", "testType") actionChan <- "getLock1" // do some work time.Sleep(50 * time.Millisecond) - gfm.GetUnlock() + gfm.GetUnlock("testContext", "testType") }() @@ -97,14 +123,14 @@ func TestGetLock(t *testing.T) { defer wg.Done() time.Sleep(10 * time.Millisecond) - gfm.FindLock() + gfm.FindLock("testContext", "testType") actionChan <- "findLock1" // do some work time.Sleep(50 * time.Millisecond) - gfm.FindUnlock() + gfm.FindUnlock("testContext", "testType") }() @@ -112,14 +138,14 @@ func TestGetLock(t *testing.T) { defer wg.Done() time.Sleep(20 * time.Millisecond) - gfm.GetLock() + gfm.GetLock("testContext", "testType") actionChan <- "getLock2" // do some work time.Sleep(50 * time.Millisecond) - gfm.GetUnlock() + gfm.GetUnlock("testContext", "testType") }() diff --git a/source.go b/source.go index ffe063d..695e952 100644 --- a/source.go +++ b/source.go @@ -107,7 +107,7 @@ func (e *Engine) Get(typ string, context string, query string) (*sdp.Item, error } } - e.gfm.GetLock() + e.gfm.GetLock(context, typ) for _, src := range relevantSources { tags := sdpcache.Tags{ @@ -143,7 +143,7 @@ func (e *Engine) Get(typ string, context string, query string) (*sdp.Item, error // If the cache found something then just return that log.WithFields(logFields).Debug("Found item from cache") - e.gfm.GetUnlock() + e.gfm.GetUnlock(context, typ) return cached[0], nil } @@ -208,12 +208,12 @@ func (e *Engine) Get(typ string, context string, query string) (*sdp.Item, error // Store the new item in the cache e.cache.StoreItem(item, GetCacheDuration(src), tags) - e.gfm.GetUnlock() + e.gfm.GetUnlock(context, typ) return item, nil } - e.gfm.GetUnlock() + e.gfm.GetUnlock(context, typ) } // If we don't find anything then we should raise an error @@ -240,8 +240,8 @@ func (e *Engine) Find(typ string, context string) ([]*sdp.Item, error) { } } - e.gfm.FindLock() - defer e.gfm.FindUnlock() + e.gfm.FindLock(context, typ) + defer e.gfm.FindUnlock(context, typ) items := make([]*sdp.Item, 0) errors := make([]error, 0) @@ -397,8 +397,8 @@ func (e *Engine) Search(typ string, context string, query string) ([]*sdp.Item, } } - e.gfm.GetLock() - defer e.gfm.GetUnlock() + e.gfm.GetLock(context, typ) + defer e.gfm.GetUnlock(context, typ) items := make([]*sdp.Item, 0) errors := make([]error, 0)