diff --git a/packages/protocol/engine/ledger/mempool/newconflictdag/conflict/conflict.go b/packages/protocol/engine/ledger/mempool/newconflictdag/conflict/conflict.go index 4adcc5b6ba..29a7036805 100644 --- a/packages/protocol/engine/ledger/mempool/newconflictdag/conflict/conflict.go +++ b/packages/protocol/engine/ledger/mempool/newconflictdag/conflict/conflict.go @@ -5,6 +5,7 @@ import ( "fmt" "sync" + "github.com/iotaledger/goshimmer/packages/protocol/engine/ledger/mempool/newconflictdag/reentrantmutex" "github.com/iotaledger/goshimmer/packages/protocol/engine/ledger/mempool/newconflictdag/weight" "github.com/iotaledger/hive.go/ds/advancedset" "github.com/iotaledger/hive.go/lo" @@ -15,7 +16,7 @@ import ( type Conflict[ConflictID, ResourceID IDType] struct { // PreferredInsteadUpdated is triggered whenever preferred conflict is updated. It carries two values: // the new preferred conflict and a set of conflicts visited - PreferredInsteadUpdated *event.Event2[*Conflict[ConflictID, ResourceID], TriggerContext[ConflictID]] + PreferredInsteadUpdated *event.Event2[*Conflict[ConflictID, ResourceID], reentrantmutex.ThreadID] id ConflictID parents *advancedset.AdvancedSet[ConflictID] @@ -30,7 +31,7 @@ type Conflict[ConflictID, ResourceID IDType] struct { func New[ConflictID, ResourceID IDType](id ConflictID, parents *advancedset.AdvancedSet[ConflictID], conflictSets map[ResourceID]*Set[ConflictID, ResourceID], initialWeight *weight.Weight) *Conflict[ConflictID, ResourceID] { c := &Conflict[ConflictID, ResourceID]{ - PreferredInsteadUpdated: event.New2[*Conflict[ConflictID, ResourceID], TriggerContext[ConflictID]](), + PreferredInsteadUpdated: event.New2[*Conflict[ConflictID, ResourceID], reentrantmutex.ThreadID](), id: id, parents: parents, children: advancedset.New[*Conflict[ConflictID, ResourceID]](), @@ -39,9 +40,9 @@ func New[ConflictID, ResourceID IDType](id ConflictID, parents *advancedset.Adva } c.conflictingConflicts = NewSortedSet[ConflictID, ResourceID](c) - c.conflictingConflicts.HeaviestPreferredMemberUpdated.Hook(func(eventConflict *Conflict[ConflictID, ResourceID], visitedConflicts TriggerContext[ConflictID]) { - fmt.Println(c.ID(), "prefers", eventConflict.ID()) - c.PreferredInsteadUpdated.Trigger(eventConflict, visitedConflicts) + c.conflictingConflicts.HeaviestPreferredMemberUpdated.Hook(func(eventConflict *Conflict[ConflictID, ResourceID], threadID reentrantmutex.ThreadID) { + fmt.Println(c.ID(), "prefers", eventConflict.ID(), threadID) + c.PreferredInsteadUpdated.Trigger(eventConflict, threadID) }) // add existing conflicts first, so we can correctly determine the preferred instead flag @@ -98,15 +99,12 @@ func (c *Conflict[ConflictID, ResourceID]) Compare(other *Conflict[ConflictID, R return bytes.Compare(lo.PanicOnErr(c.id.Bytes()), lo.PanicOnErr(other.id.Bytes())) } -func (c *Conflict[ConflictID, ResourceID]) PreferredInstead() *Conflict[ConflictID, ResourceID] { - c.mutex.RLock() - defer c.mutex.RUnlock() - - return c.conflictingConflicts.HeaviestPreferredConflict() +func (c *Conflict[ConflictID, ResourceID]) PreferredInstead(optThreadID ...reentrantmutex.ThreadID) *Conflict[ConflictID, ResourceID] { + return c.conflictingConflicts.HeaviestPreferredConflict(optThreadID...) } -func (c *Conflict[ConflictID, ResourceID]) IsPreferred() bool { - return c.PreferredInstead() == c +func (c *Conflict[ConflictID, ResourceID]) IsPreferred(optThreadID ...reentrantmutex.ThreadID) bool { + return c.PreferredInstead(optThreadID...) == c } func (c *Conflict[ConflictID, ResourceID]) String() string { diff --git a/packages/protocol/engine/ledger/mempool/newconflictdag/conflict/sortedset.go b/packages/protocol/engine/ledger/mempool/newconflictdag/conflict/sortedset.go index ff518efa60..513e32ea5a 100644 --- a/packages/protocol/engine/ledger/mempool/newconflictdag/conflict/sortedset.go +++ b/packages/protocol/engine/ledger/mempool/newconflictdag/conflict/sortedset.go @@ -2,13 +2,12 @@ package conflict import ( "fmt" - "math/rand" "sync" "sync/atomic" + "github.com/iotaledger/goshimmer/packages/protocol/engine/ledger/mempool/newconflictdag/reentrantmutex" "github.com/iotaledger/goshimmer/packages/protocol/engine/ledger/mempool/newconflictdag/weight" "github.com/iotaledger/hive.go/ds/shrinkingmap" - "github.com/iotaledger/hive.go/ds/types" "github.com/iotaledger/hive.go/runtime/event" "github.com/iotaledger/hive.go/runtime/syncutils" "github.com/iotaledger/hive.go/stringify" @@ -17,7 +16,7 @@ import ( // SortedSet is a set of Conflicts that is sorted by their weight. type SortedSet[ConflictID, ResourceID IDType] struct { // HeaviestPreferredMemberUpdated is triggered when the heaviest preferred member of the SortedSet changes. - HeaviestPreferredMemberUpdated *event.Event2[*Conflict[ConflictID, ResourceID], TriggerContext[ConflictID]] + HeaviestPreferredMemberUpdated *event.Event2[*Conflict[ConflictID, ResourceID], reentrantmutex.ThreadID] // owner is the Conflict that owns this SortedSet. owner *Conflict[ConflictID, ResourceID] @@ -47,17 +46,18 @@ type SortedSet[ConflictID, ResourceID IDType] struct { isShutdown atomic.Bool // mutex is used to synchronize access to the SortedSet. - mutex sync.RWMutex + mutex *reentrantmutex.ReEntrantMutex } // NewSortedSet creates a new SortedSet that is owned by the given Conflict. func NewSortedSet[ConflictID, ResourceID IDType](owner *Conflict[ConflictID, ResourceID]) *SortedSet[ConflictID, ResourceID] { s := &SortedSet[ConflictID, ResourceID]{ - HeaviestPreferredMemberUpdated: event.New2[*Conflict[ConflictID, ResourceID], TriggerContext[ConflictID]](), + HeaviestPreferredMemberUpdated: event.New2[*Conflict[ConflictID, ResourceID], reentrantmutex.ThreadID](), owner: owner, members: shrinkingmap.New[ConflictID, *sortedSetMember[ConflictID, ResourceID]](), pendingWeightUpdates: shrinkingmap.New[ConflictID, *sortedSetMember[ConflictID, ResourceID]](), pendingWeightUpdatesCounter: syncutils.NewCounter(), + mutex: reentrantmutex.New(owner.ID().String()), } s.pendingWeightUpdatesSignal = sync.NewCond(&s.pendingWeightUpdatesMutex) @@ -70,9 +70,13 @@ func NewSortedSet[ConflictID, ResourceID IDType](owner *Conflict[ConflictID, Res } // Add adds the given Conflict to the SortedSet. -func (s *SortedSet[ConflictID, ResourceID]) Add(conflict *Conflict[ConflictID, ResourceID]) { - s.mutex.Lock() - defer s.mutex.Unlock() +func (s *SortedSet[ConflictID, ResourceID]) Add(conflict *Conflict[ConflictID, ResourceID], optThreadID ...reentrantmutex.ThreadID) { + if len(optThreadID) == 0 { + optThreadID = []reentrantmutex.ThreadID{reentrantmutex.NewThreadID()} + } + + s.mutex.Lock(optThreadID[0]) + defer s.mutex.UnLock(optThreadID[0]) newMember, isNew := s.members.GetOrCreate(conflict.id, func() *sortedSetMember[ConflictID, ResourceID] { return newSortedSetMember[ConflictID, ResourceID](s, conflict) @@ -119,17 +123,21 @@ func (s *SortedSet[ConflictID, ResourceID]) Add(conflict *Conflict[ConflictID, R } } - if conflict.IsPreferred() && newMember.Compare(s.heaviestPreferredMember) == weight.Heavier { + if conflict.IsPreferred(optThreadID[0]) && newMember.Compare(s.heaviestPreferredMember) == weight.Heavier { s.heaviestPreferredMember = newMember - s.HeaviestPreferredMemberUpdated.Trigger(conflict, NewTriggerContext(conflict.ID())) + s.HeaviestPreferredMemberUpdated.Trigger(conflict, optThreadID[0]) } } // ForEach iterates over all Conflicts of the SortedSet and calls the given callback for each of them. -func (s *SortedSet[ConflictID, ResourceID]) ForEach(callback func(*Conflict[ConflictID, ResourceID]) error) error { - s.mutex.RLock() - defer s.mutex.RUnlock() +func (s *SortedSet[ConflictID, ResourceID]) ForEach(callback func(*Conflict[ConflictID, ResourceID]) error, optThreadID ...reentrantmutex.ThreadID) error { + if len(optThreadID) == 0 { + optThreadID = []reentrantmutex.ThreadID{reentrantmutex.NewThreadID()} + } + + s.mutex.RLock(optThreadID[0]) + defer s.mutex.RUnlock(optThreadID[0]) for currentMember := s.heaviestMember; currentMember != nil; currentMember = currentMember.lighterMember { if err := callback(currentMember.Conflict); err != nil { @@ -141,9 +149,13 @@ func (s *SortedSet[ConflictID, ResourceID]) ForEach(callback func(*Conflict[Conf } // HeaviestConflict returns the heaviest Conflict of the SortedSet. -func (s *SortedSet[ConflictID, ResourceID]) HeaviestConflict() *Conflict[ConflictID, ResourceID] { - s.mutex.RLock() - defer s.mutex.RUnlock() +func (s *SortedSet[ConflictID, ResourceID]) HeaviestConflict(optThreadID ...reentrantmutex.ThreadID) *Conflict[ConflictID, ResourceID] { + if len(optThreadID) == 0 { + optThreadID = []reentrantmutex.ThreadID{reentrantmutex.NewThreadID()} + } + + s.mutex.RLock(optThreadID[0]) + defer s.mutex.RUnlock(optThreadID[0]) if s.heaviestMember == nil { return nil @@ -153,14 +165,16 @@ func (s *SortedSet[ConflictID, ResourceID]) HeaviestConflict() *Conflict[Conflic } // HeaviestPreferredConflict returns the heaviest preferred Conflict of the SortedSet. -func (s *SortedSet[ConflictID, ResourceID]) HeaviestPreferredConflict() *Conflict[ConflictID, ResourceID] { - a := rand.Float64() +func (s *SortedSet[ConflictID, ResourceID]) HeaviestPreferredConflict(optThreadID ...reentrantmutex.ThreadID) *Conflict[ConflictID, ResourceID] { + if len(optThreadID) == 0 { + optThreadID = []reentrantmutex.ThreadID{reentrantmutex.NewThreadID()} + } - fmt.Println("HeaviestPreferreConflict", s.owner.ID(), a) - defer fmt.Println("unlocked HeaviestPreferreConflict", s.owner.ID(), a) + fmt.Println("HeaviestPreferreConflict", s.owner.ID(), optThreadID[0]) + defer fmt.Println("unlocked HeaviestPreferreConflict", s.owner.ID(), optThreadID[0]) - s.mutex.RLock() - defer s.mutex.RUnlock() + s.mutex.RLock(optThreadID[0]) + defer s.mutex.RUnlock(optThreadID[0]) if s.heaviestPreferredMember == nil { return nil @@ -200,17 +214,17 @@ func (s *SortedSet[ConflictID, ResourceID]) notifyPendingWeightUpdate(member *so } // notifyPreferredInsteadUpdate notifies the SortedSet about a member that changed its preferred instead flag. -func (s *SortedSet[ConflictID, ResourceID]) notifyPreferredInsteadUpdate(member *sortedSetMember[ConflictID, ResourceID], preferred bool, visitedConflicts TriggerContext[ConflictID]) { - fmt.Println("Write-Lock", s.owner.ID(), "notifyPreferredInsteadUpdate(", member.ID(), ",", preferred, ",", visitedConflicts, ")") - defer fmt.Println("Write-Unlock", s.owner.ID(), "notifyPreferredInsteadUpdate(", member.ID(), ",", preferred, ",", visitedConflicts, ")") +func (s *SortedSet[ConflictID, ResourceID]) notifyPreferredInsteadUpdate(member *sortedSetMember[ConflictID, ResourceID], preferred bool, threadID reentrantmutex.ThreadID) { + fmt.Println("Write-Lock", s.owner.ID(), "notifyPreferredInsteadUpdate(", member.ID(), ",", preferred, ",", threadID, ")") + defer fmt.Println("Write-Unlock", s.owner.ID(), "notifyPreferredInsteadUpdate(", member.ID(), ",", preferred, ",", threadID, ")") - s.mutex.Lock() - defer s.mutex.Unlock() + s.mutex.Lock(threadID) + defer s.mutex.UnLock(threadID) if preferred { if member.Compare(s.heaviestPreferredMember) == weight.Heavier { s.heaviestPreferredMember = member - s.HeaviestPreferredMemberUpdated.Trigger(member.Conflict, visitedConflicts) + s.HeaviestPreferredMemberUpdated.Trigger(member.Conflict, threadID) } return @@ -221,12 +235,12 @@ func (s *SortedSet[ConflictID, ResourceID]) notifyPreferredInsteadUpdate(member } currentMember := member.lighterMember - for currentMember.Conflict != s.owner && !currentMember.IsPreferred() && currentMember.PreferredInstead() != member.Conflict { + for currentMember.Conflict != s.owner && !currentMember.IsPreferred(threadID) && currentMember.PreferredInstead(threadID) != member.Conflict { currentMember = currentMember.lighterMember } s.heaviestPreferredMember = currentMember - s.HeaviestPreferredMemberUpdated.Trigger(currentMember.Conflict, visitedConflicts) + s.HeaviestPreferredMemberUpdated.Trigger(currentMember.Conflict, threadID) } // nextPendingWeightUpdate returns the next member that needs to be updated (or nil if the shutdown flag is set). @@ -260,13 +274,15 @@ func (s *SortedSet[ConflictID, ResourceID]) fixMemberPositionWorker() { // fixMemberPosition fixes the position of the given member in the SortedSet. func (s *SortedSet[ConflictID, ResourceID]) fixMemberPosition(member *sortedSetMember[ConflictID, ResourceID]) { + threadID := reentrantmutex.NewThreadID() + fmt.Println("Write-Lock", s.owner.ID(), "fixMemberPosition(", member.ID(), ")") defer fmt.Println("Write-Unlock", s.owner.ID(), "fixMemberPosition(", member.ID(), ")") - s.mutex.Lock() - defer s.mutex.Unlock() + s.mutex.Lock(threadID) + defer s.mutex.UnLock(threadID) - preferredMember := s.preferredInstead(member) + preferredMember := member.PreferredInstead(threadID) // the member needs to be moved up in the list for currentMember := member.heavierMember; currentMember != nil && currentMember.Compare(member) == weight.Lighter; currentMember = member.heavierMember { @@ -274,7 +290,8 @@ func (s *SortedSet[ConflictID, ResourceID]) fixMemberPosition(member *sortedSetM if currentMember.ID() == preferredMember.ID() { s.heaviestPreferredMember = member - s.HeaviestPreferredMemberUpdated.Trigger(member.Conflict, NewTriggerContext(s.owner.ID())) + fmt.Println("TRIGGER1", threadID) + s.HeaviestPreferredMemberUpdated.Trigger(member.Conflict, threadID) } } @@ -283,27 +300,16 @@ func (s *SortedSet[ConflictID, ResourceID]) fixMemberPosition(member *sortedSetM for currentMember := member.lighterMember; currentMember != nil && currentMember.Compare(member) == weight.Heavier; currentMember = member.lighterMember { s.swapNeighbors(currentMember, member) - if memberIsHeaviestPreferred && s.isPreferred(currentMember) { + if memberIsHeaviestPreferred && currentMember.IsPreferred(threadID) { s.heaviestPreferredMember = currentMember - s.HeaviestPreferredMemberUpdated.Trigger(currentMember.Conflict, TriggerContext[ConflictID]{s.owner.ID(): types.Void}) + fmt.Println("TRIGGER2", threadID) + s.HeaviestPreferredMemberUpdated.Trigger(currentMember.Conflict, threadID) memberIsHeaviestPreferred = false } } } -func (s *SortedSet[ConflictID, ResourceID]) preferredInstead(member *sortedSetMember[ConflictID, ResourceID]) *Conflict[ConflictID, ResourceID] { - if member.Conflict == s.owner { - return s.heaviestPreferredMember.Conflict - } - - return member.PreferredInstead() -} - -func (s *SortedSet[ConflictID, ResourceID]) isPreferred(member *sortedSetMember[ConflictID, ResourceID]) bool { - return s.preferredInstead(member) == member.Conflict -} - // swapNeighbors swaps the given members in the SortedSet. func (s *SortedSet[ConflictID, ResourceID]) swapNeighbors(heavierMember, lighterMember *sortedSetMember[ConflictID, ResourceID]) { if heavierMember.lighterMember != nil { diff --git a/packages/protocol/engine/ledger/mempool/newconflictdag/conflict/sortedsetmember.go b/packages/protocol/engine/ledger/mempool/newconflictdag/conflict/sortedsetmember.go index ef03c6096f..b6383475ae 100644 --- a/packages/protocol/engine/ledger/mempool/newconflictdag/conflict/sortedsetmember.go +++ b/packages/protocol/engine/ledger/mempool/newconflictdag/conflict/sortedsetmember.go @@ -2,11 +2,10 @@ package conflict import ( "bytes" - "fmt" "sync" + "github.com/iotaledger/goshimmer/packages/protocol/engine/ledger/mempool/newconflictdag/reentrantmutex" "github.com/iotaledger/goshimmer/packages/protocol/engine/ledger/mempool/newconflictdag/weight" - "github.com/iotaledger/hive.go/ds/types" "github.com/iotaledger/hive.go/lo" "github.com/iotaledger/hive.go/runtime/event" ) @@ -35,7 +34,7 @@ type sortedSetMember[ConflictID, ResourceID IDType] struct { onUpdateHook *event.Hook[func(weight.Value)] // onPreferredUpdatedHook is the hook that is triggered when the preferredInstead value of the Conflict is updated. - onPreferredUpdatedHook *event.Hook[func(*Conflict[ConflictID, ResourceID], TriggerContext[ConflictID])] + onPreferredUpdatedHook *event.Hook[func(*Conflict[ConflictID, ResourceID], reentrantmutex.ThreadID)] // Conflict is the wrapped Conflict. *Conflict[ConflictID, ResourceID] @@ -53,8 +52,8 @@ func newSortedSetMember[ConflictID, ResourceID IDType](set *SortedSet[ConflictID // do not attach to event from ourselves if set.owner != conflict { - s.onPreferredUpdatedHook = conflict.PreferredInsteadUpdated.Hook(func(newPreferredConflict *Conflict[ConflictID, ResourceID], visitedConflicts TriggerContext[ConflictID]) { - s.notifyPreferredInsteadUpdate(newPreferredConflict, visitedConflicts) + s.onPreferredUpdatedHook = conflict.PreferredInsteadUpdated.Hook(func(newPreferredConflict *Conflict[ConflictID, ResourceID], threadID reentrantmutex.ThreadID) { + s.notifyPreferredInsteadUpdate(newPreferredConflict, threadID) }) } @@ -113,13 +112,6 @@ func (s *sortedSetMember[ConflictID, ResourceID]) weightUpdateApplied() bool { } // notifyPreferredInsteadUpdate notifies the sortedSet that the preferred instead flag of the Conflict was updated. -func (s *sortedSetMember[ConflictID, ResourceID]) notifyPreferredInsteadUpdate(newPreferredConflict *Conflict[ConflictID, ResourceID], visitedConflicts TriggerContext[ConflictID]) { - if _, exists := visitedConflicts[s.sortedSet.owner.ID()]; !exists { - visitedConflicts[s.ID()] = types.Void - fmt.Println("notify", s.sortedSet.owner.ID(), "that", s.ID(), "prefers", newPreferredConflict.ID(), "with visited conflicts", visitedConflicts) - - s.sortedSet.notifyPreferredInsteadUpdate(s, newPreferredConflict == s.Conflict, visitedConflicts) - } else { - fmt.Println("do not notify", s.sortedSet.owner.ID(), "that", s.ID(), "prefers", newPreferredConflict.ID(), "with visited conflicts", visitedConflicts) - } +func (s *sortedSetMember[ConflictID, ResourceID]) notifyPreferredInsteadUpdate(newPreferredConflict *Conflict[ConflictID, ResourceID], threadID reentrantmutex.ThreadID) { + s.sortedSet.notifyPreferredInsteadUpdate(s, newPreferredConflict == s.Conflict, threadID) } diff --git a/packages/protocol/engine/ledger/mempool/newconflictdag/reentrantmutex/reentrantmutex.go b/packages/protocol/engine/ledger/mempool/newconflictdag/reentrantmutex/reentrantmutex.go new file mode 100644 index 0000000000..485ca89a2a --- /dev/null +++ b/packages/protocol/engine/ledger/mempool/newconflictdag/reentrantmutex/reentrantmutex.go @@ -0,0 +1,139 @@ +package reentrantmutex + +import ( + "fmt" + "runtime" + "strconv" + "sync" +) + +type ReEntrantMutex struct { + debugID string + wLockThreadID *ThreadID + wLockCounter int + rLockCounters map[ThreadID]int + mutexUnlocked sync.Cond + mutex sync.Mutex +} + +func New(debugID string) *ReEntrantMutex { + r := new(ReEntrantMutex) + r.debugID = debugID + r.rLockCounters = make(map[ThreadID]int) + r.mutexUnlocked.L = &r.mutex + + return r +} +func (m *ReEntrantMutex) Origin() string { + _, fileName, lineNumber, _ := runtime.Caller(2) + + return fileName + ":" + strconv.Itoa(lineNumber) +} + +func (m *ReEntrantMutex) RLock(threadID ThreadID) { + fmt.Printf("[RLOCKING]\t%s\t%04d\t%s\n", m.debugID, threadID, m.Origin()) + defer fmt.Printf("[RLOCKED]\t%s\t%04d\t%s\n", m.debugID, threadID, m.Origin()) + + m.mutex.Lock() + defer m.mutex.Unlock() + + for !m.isRLockable(threadID) { + m.mutexUnlocked.Wait() + } + + m.rLockCounters[threadID]++ +} + +func (m *ReEntrantMutex) RUnlock(threadID ThreadID) { + fmt.Printf("[RUNLOCKING]\t%s\t%04d\t%s\n", m.debugID, threadID, m.Origin()) + defer fmt.Printf("[RUNLOCKED]\t%s\t%04d\t%s\n", m.debugID, threadID, m.Origin()) + + if m.rUnlock(threadID) { + m.mutexUnlocked.Broadcast() + } +} + +func (m *ReEntrantMutex) Lock(threadID ThreadID) { + fmt.Printf("[LOCKING]\t%s\t%04d\t%s\n", m.debugID, threadID, m.Origin()) + defer fmt.Printf("[LOCKED]\t%s\t%04d\t%s\n", m.debugID, threadID, m.Origin()) + + m.mutex.Lock() + defer m.mutex.Unlock() + + for !m.isLockable(threadID) { + m.mutexUnlocked.Wait() + } + + m.wLockThreadID = &threadID + m.wLockCounter++ +} + +func (m *ReEntrantMutex) UnLock(threadID ThreadID) { + fmt.Printf("[UNLOCKING]\t%s\t%04d\t%s\n", m.debugID, threadID, m.Origin()) + defer fmt.Printf("[UNLOCKED]\t%s\t%04d\t%s\n", m.debugID, threadID, m.Origin()) + + if m.unlock(threadID) { + m.mutexUnlocked.Broadcast() + } +} + +func (m *ReEntrantMutex) isLockable(threadID ThreadID) bool { + if m.wLockThreadID != nil { + return *m.wLockThreadID == threadID + } + + if len(m.rLockCounters) == 0 { + return true + } + + if len(m.rLockCounters) == 1 { + for rLockThreadID := range m.rLockCounters { + return rLockThreadID == threadID + } + } + + return false +} + +func (m *ReEntrantMutex) unlock(threadID ThreadID) bool { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.wLockThreadID == nil || *m.wLockThreadID != threadID { + panic("threadID does not match current threadID") + } + + m.wLockCounter-- + + if m.wLockCounter == 0 { + m.wLockThreadID = nil + return true + } + + return false +} + +func (m *ReEntrantMutex) isRLockable(threadID ThreadID) bool { + if m.wLockThreadID != nil { + return *m.wLockThreadID == threadID + } + + return true +} + +func (m *ReEntrantMutex) rUnlock(threadID ThreadID) bool { + m.mutex.Lock() + defer m.mutex.Unlock() + + if counter, exists := m.rLockCounters[threadID]; !exists || counter == 0 { + panic("trying to RUnlock a threadID that was not locked before") + } + + m.rLockCounters[threadID]-- + + if m.rLockCounters[threadID] == 0 { + delete(m.rLockCounters, threadID) + } + + return true +} diff --git a/packages/protocol/engine/ledger/mempool/newconflictdag/reentrantmutex/reentrantmutex_test.go b/packages/protocol/engine/ledger/mempool/newconflictdag/reentrantmutex/reentrantmutex_test.go new file mode 100644 index 0000000000..ed3410cc3d --- /dev/null +++ b/packages/protocol/engine/ledger/mempool/newconflictdag/reentrantmutex/reentrantmutex_test.go @@ -0,0 +1,68 @@ +package reentrantmutex + +import ( + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestLock(t *testing.T) { + m := New() + m.Lock(1) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + m.Lock(2) + defer m.UnLock(2) + + fmt.Print("DONE 2") + + wg.Done() + }() + + m.Lock(1) + m.UnLock(1) + + time.Sleep(1 * time.Second) + + m.UnLock(1) + + fmt.Print("DONE 1") + + wg.Wait() +} + +func TestRLock(t *testing.T) { + m := New() + m.RLock(1) + m.RLock(2) + m.RLock(3) + m.RUnlock(2) + + var lockAcquired bool + var wg sync.WaitGroup + wg.Add(1) + go func() { + m.Lock(1) + m.UnLock(1) + + lockAcquired = true + + wg.Done() + }() + + time.Sleep(1 * time.Second) + + require.False(t, lockAcquired) + + m.RUnlock(3) + m.RUnlock(1) + + wg.Wait() + + require.True(t, lockAcquired) +} diff --git a/packages/protocol/engine/ledger/mempool/newconflictdag/reentrantmutex/threadid.go b/packages/protocol/engine/ledger/mempool/newconflictdag/reentrantmutex/threadid.go new file mode 100644 index 0000000000..4a6c8cb2bf --- /dev/null +++ b/packages/protocol/engine/ledger/mempool/newconflictdag/reentrantmutex/threadid.go @@ -0,0 +1,19 @@ +package reentrantmutex + +import "sync" + +type ThreadID uint64 + +var ( + threadIDCounter ThreadID + threadIDCounterMutex sync.Mutex +) + +func NewThreadID() ThreadID { + threadIDCounterMutex.Lock() + defer threadIDCounterMutex.Unlock() + + threadIDCounter++ + + return threadIDCounter +}