diff --git a/srtcore/sync.cpp b/srtcore/sync.cpp index a7cebb909..bfe153657 100644 --- a/srtcore/sync.cpp +++ b/srtcore/sync.cpp @@ -357,3 +357,98 @@ int srt::sync::genRandomInt(int minVal, int maxVal) #endif // HAVE_CXX11 } + +//////////////////////////////////////////////////////////////////////////////// +// +// Shared Mutex +// +//////////////////////////////////////////////////////////////////////////////// + +srt::sync::SharedMutex::SharedMutex() + : m_LockWriteCond() + , m_LockReadCond() + , m_Mutex() + , m_iCountRead(0) + , m_bWriterLocked(false) +{ + setupCond(m_LockReadCond, "SharedMutex::m_pLockReadCond"); + setupCond(m_LockWriteCond, "SharedMutex::m_pLockWriteCond"); + setupMutex(m_Mutex, "SharedMutex::m_pMutex"); +} + +srt::sync::SharedMutex::~SharedMutex() +{ + releaseMutex(m_Mutex); + releaseCond(m_LockWriteCond); + releaseCond(m_LockReadCond); +} + +void srt::sync::SharedMutex::lock() +{ + UniqueLock l1(m_Mutex); + while (m_bWriterLocked) + m_LockWriteCond.wait(l1); + + m_bWriterLocked = true; + + while (m_iCountRead) + m_LockReadCond.wait(l1); +} + +bool srt::sync::SharedMutex::try_lock() +{ + UniqueLock l1(m_Mutex); + if (m_bWriterLocked || m_iCountRead > 0) + return false; + + m_bWriterLocked = true; + return true; +} + +void srt::sync::SharedMutex::unlock() +{ + ScopedLock lk(m_Mutex); + m_bWriterLocked = false; + + m_LockWriteCond.notify_all(); +} + +void srt::sync::SharedMutex::lock_shared() +{ + UniqueLock lk(m_Mutex); + while (m_bWriterLocked) + m_LockWriteCond.wait(lk); + + m_iCountRead++; +} + +bool srt::sync::SharedMutex::try_lock_shared() +{ + UniqueLock lk(m_Mutex); + if (m_bWriterLocked) + return false; + + m_iCountRead++; + return true; +} + +void srt::sync::SharedMutex::unlock_shared() +{ + ScopedLock lk(m_Mutex); + + m_iCountRead--; + + SRT_ASSERT(m_iCountRead >= 0); + if (m_iCountRead < 0) + m_iCountRead = 0; + + if (m_bWriterLocked && m_iCountRead == 0) + m_LockReadCond.notify_one(); + +} + +int srt::sync::SharedMutex::getReaderCount() const +{ + ScopedLock lk(m_Mutex); + return m_iCountRead; +} \ No newline at end of file diff --git a/srtcore/sync.h b/srtcore/sync.h index fb6d56432..8fee25831 100644 --- a/srtcore/sync.h +++ b/srtcore/sync.h @@ -943,6 +943,43 @@ CUDTException& GetThreadLocalError(); /// @param[in] maxVal maximum allowed value of the resulting random number. int genRandomInt(int minVal, int maxVal); + +/// Implementation of a read-write mutex. +/// This allows multiple readers at a time, or a single writer. +/// TODO: The class can be improved if needed to give writer a preference +/// by adding additional m_iWritersWaiting member variable (counter). +/// TODO: The m_iCountRead could be made atomic to make unlok_shared() faster and lock-free. +class SharedMutex +{ +public: + SharedMutex(); + ~SharedMutex(); + +private: + Condition m_LockWriteCond; + Condition m_LockReadCond; + + mutable Mutex m_Mutex; + + int m_iCountRead; + bool m_bWriterLocked; + + /// Acquire the lock for writting purposes. Only one thread can acquire this lock at a time + /// Once it is locked, no reader can acquire it +public: + void lock(); + bool try_lock(); + void unlock(); + + /// Acquire the lock if no writter already has it. For read purpose only + /// Several readers can lock this at the same time. + void lock_shared(); + bool try_lock_shared(); + void unlock_shared(); + + int getReaderCount() const; +}; + } // namespace sync } // namespace srt diff --git a/test/test_sync.cpp b/test/test_sync.cpp index e0454a581..9e2064cc1 100644 --- a/test/test_sync.cpp +++ b/test/test_sync.cpp @@ -609,6 +609,91 @@ TEST(SyncThread, Joinable) EXPECT_FALSE(foo.joinable()); } +/*****************************************************************************/ +/* + * SharedMutex + */ + /*****************************************************************************/ +TEST(SharedMutex, LockWriteRead) +{ + SharedMutex mut; + + mut.lock(); + EXPECT_FALSE(mut.try_lock_shared()); + +} + +TEST(SharedMutex, LockReadWrite) +{ + SharedMutex mut; + + mut.lock_shared(); + EXPECT_FALSE(mut.try_lock()); + +} + +TEST(SharedMutex, LockReadTwice) +{ + SharedMutex mut; + + mut.lock_shared(); + mut.lock_shared(); + EXPECT_TRUE(mut.try_lock_shared()); +} + +TEST(SharedMutex, LockWriteTwice) +{ + SharedMutex mut; + + mut.lock(); + EXPECT_FALSE(mut.try_lock()); +} + +TEST(SharedMutex, LockUnlockWrite) +{ + SharedMutex mut; + mut.lock(); + EXPECT_FALSE(mut.try_lock()); + mut.unlock(); + EXPECT_TRUE(mut.try_lock()); +} + +TEST(SharedMutex, LockUnlockRead) +{ + SharedMutex mut; + + mut.lock_shared(); + EXPECT_FALSE(mut.try_lock()); + + mut.unlock_shared(); + EXPECT_TRUE(mut.try_lock()); +} + +TEST(SharedMutex, LockedReadCount) +{ + SharedMutex mut; + int count = 0; + + mut.lock_shared(); + count++; + ASSERT_EQ(mut.getReaderCount(), count); + + mut.lock_shared(); + count++; + ASSERT_EQ(mut.getReaderCount(), count); + + mut.unlock_shared(); + count--; + ASSERT_EQ(mut.getReaderCount(), count); + + mut.unlock_shared(); + count--; + ASSERT_EQ(mut.getReaderCount(), count); + + EXPECT_TRUE(mut.try_lock()); +} + + /*****************************************************************************/ /* * FormatTime