diff --git a/srtcore/queue.cpp b/srtcore/queue.cpp index 98999a81f..7127ae675 100644 --- a/srtcore/queue.cpp +++ b/srtcore/queue.cpp @@ -1130,7 +1130,7 @@ srt::CRcvQueue::CRcvQueue() , m_szPayloadSize() , m_bClosing(false) , m_LSLock() - , m_pListener(NULL) + , m_pListener() , m_pRendezvousQueue(NULL) , m_vNewEntry() , m_IDLock() @@ -1405,10 +1405,11 @@ srt::EConnectStatus srt::CRcvQueue::worker_ProcessConnectionRequest(CUnit* unit, bool have_listener = false; { ScopedLock cg(m_LSLock); - if (m_pListener) + m_pListener.lockRead(); + if (m_pListener.m_pObj) { - LOGC(cnlog.Debug, log << "PASSING request from: " << addr.str() << " to listener:" << m_pListener->socketID()); - listener_ret = m_pListener->processConnectRequest(addr, unit->m_Packet); + LOGC(cnlog.Debug, log << "PASSING request from: " << addr.str() << " to listener:" << m_pListener.m_pObj->socketID()); + listener_ret = m_pListener.m_pObj->processConnectRequest(addr, unit->m_Packet); // This function does return a code, but it's hard to say as to whether // anything can be done about it. In case when it's stated possible, the @@ -1418,6 +1419,7 @@ srt::EConnectStatus srt::CRcvQueue::worker_ProcessConnectionRequest(CUnit* unit, have_listener = true; } + m_pListener.unlockRead(); } // NOTE: Rendezvous sockets do bind(), but not listen(). It means that the socket is @@ -1690,21 +1692,28 @@ int srt::CRcvQueue::recvfrom(int32_t id, CPacket& w_packet) int srt::CRcvQueue::setListener(CUDT* u) { - ScopedLock lslock(m_LSLock); + m_pListener.lockWrite(); - if (NULL != m_pListener) + if (NULL != m_pListener.m_pObj) + { + m_pListener.unlockWrite(); return -1; + } + + m_pListener.m_pObj = u; + m_pListener.unlockWrite(); - m_pListener = u; return 0; } void srt::CRcvQueue::removeListener(const CUDT* u) { - ScopedLock lslock(m_LSLock); + m_pListener.lockWrite(); - if (u == m_pListener) - m_pListener = NULL; + if (u == m_pListener.m_pObj) + m_pListener.m_pObj = NULL; + + m_pListener.unlockWrite(); } void srt::CRcvQueue::registerConnector(const SRTSOCKET& id, diff --git a/srtcore/queue.h b/srtcore/queue.h index dd68a7721..70becd28d 100644 --- a/srtcore/queue.h +++ b/srtcore/queue.h @@ -554,9 +554,9 @@ class CRcvQueue void storePktClone(int32_t id, const CPacket& pkt); private: - sync::Mutex m_LSLock; - CUDT* m_pListener; // pointer to the (unique, if any) listening UDT entity - CRendezvousQueue* m_pRendezvousQueue; // The list of sockets in rendezvous mode + sync::Mutex m_LSLock; + sync::CSharedObject m_pListener; // pointer to the (unique, if any) listening UDT entity + CRendezvousQueue* m_pRendezvousQueue; // The list of sockets in rendezvous mode std::vector m_vNewEntry; // newly added entries, to be inserted sync::Mutex m_IDLock; diff --git a/srtcore/sync.cpp b/srtcore/sync.cpp index a7cebb909..c25fc2c03 100644 --- a/srtcore/sync.cpp +++ b/srtcore/sync.cpp @@ -172,6 +172,8 @@ namespace sync { srt::sync::CEvent g_Sync; + + } // namespace sync } // namespace srt @@ -357,3 +359,106 @@ int srt::sync::genRandomInt(int minVal, int maxVal) #endif // HAVE_CXX11 } + +//////////////////////////////////////////////////////////////////////////////// +// +// Shared Mutex +// +//////////////////////////////////////////////////////////////////////////////// + +srt::sync::SharedMutex::SharedMutex() + :m_LockWriteCond() + ,m_LockReadCond() + ,m_Mutex() + ,m_iCountRead(0) + ,m_bWriterLocked(false) +{ + m_iCountRead = 0; + m_bWriterLocked = false; + setupCond(m_LockReadCond, "SharedMutex::m_pLockReadCond"); + setupCond(m_LockWriteCond, "SharedMutex::m_pLockWriteCond"); + setupMutex(m_Mutex, "SharedMutex::m_Mutex"); +} + +srt::sync::SharedMutex::~SharedMutex() +{ + releaseMutex(m_Mutex); + releaseCond(m_LockWriteCond); + releaseCond(m_LockReadCond); +} + +void srt::sync::SharedMutex::lock() +{ + UniqueLock l1(m_Mutex); + if (m_bWriterLocked) + m_LockWriteCond.wait(l1); + + m_bWriterLocked = true; + + if (m_iCountRead) + m_LockReadCond.wait(l1); +} + +bool srt::sync::SharedMutex::try_lock() +{ + UniqueLock l1(m_Mutex); + if (m_bWriterLocked || m_iCountRead > 0) + return false; + else + { + m_bWriterLocked = true; + return true; + } +} + +void srt::sync::SharedMutex::unlock() +{ + UniqueLock lk(m_Mutex); + m_bWriterLocked = false; + + lk.unlock(); + m_LockWriteCond.notify_all(); +} + +void srt::sync::SharedMutex::lock_shared() +{ + UniqueLock lk(m_Mutex); + if (m_bWriterLocked) + m_LockWriteCond.wait(lk); + + m_iCountRead++; +} + +bool srt::sync::SharedMutex::try_lock_shared() +{ + UniqueLock lk(m_Mutex); + if (m_bWriterLocked) + return false; + else + { + m_iCountRead++; + return true; + } +} + +void srt::sync::SharedMutex::unlock_shared() +{ + ScopedLock lk(m_Mutex); + + m_iCountRead--; + if (m_iCountRead < 0) + m_iCountRead = 0; + + if (m_bWriterLocked && m_iCountRead == 0) + m_LockReadCond.notify_one(); + else if (m_iCountRead > 0) + m_LockWriteCond.notify_one(); + +} + +int srt::sync::SharedMutex::getReaderCount() +{ + return m_iCountRead; +} + + diff --git a/srtcore/sync.h b/srtcore/sync.h index fb6d56432..5fb2b7714 100644 --- a/srtcore/sync.h +++ b/srtcore/sync.h @@ -943,6 +943,78 @@ CUDTException& GetThreadLocalError(); /// @param[in] maxVal maximum allowed value of the resulting random number. int genRandomInt(int minVal, int maxVal); + +// Implementation of a read-write mutex. +// This allows multiple readers at a time, or a single writer +class SharedMutex +{ + public: + SharedMutex(); + ~SharedMutex(); + + private: + Condition m_LockWriteCond; + Condition m_LockReadCond; + + Mutex m_Mutex; + + int m_iCountRead; + bool m_bWriterLocked; + + // Acquire the lock for writting purposes. Only one thread can acquire this lock at a time + // Once it is locked, no reader can acquire it + public: + void lock(); + bool try_lock(); + void unlock(); + + // Acquire the lock if no writter already has it. For read purpose only + // Several readers can lock this at the same time. + void lock_shared(); + bool try_lock_shared(); + void unlock_shared(); + + int getReaderCount(); + + +}; + +template +class CSharedObject +{ + public: + T* m_pObj; + sync::SharedMutex m_Mtx; + + public: + CSharedObject() + :m_pObj() + ,m_Mtx() + { + } + + void lockWrite() + { + m_Mtx.lock(); + } + + void unlockWrite() + { + m_Mtx.unlock(); + } + + void lockRead() + { + m_Mtx.lock_shared(); + } + + void unlockRead() + { + m_Mtx.unlock_shared(); + } + +}; + } // namespace sync } // namespace srt diff --git a/test/test_sync.cpp b/test/test_sync.cpp index e0454a581..9e2064cc1 100644 --- a/test/test_sync.cpp +++ b/test/test_sync.cpp @@ -609,6 +609,91 @@ TEST(SyncThread, Joinable) EXPECT_FALSE(foo.joinable()); } +/*****************************************************************************/ +/* + * SharedMutex + */ + /*****************************************************************************/ +TEST(SharedMutex, LockWriteRead) +{ + SharedMutex mut; + + mut.lock(); + EXPECT_FALSE(mut.try_lock_shared()); + +} + +TEST(SharedMutex, LockReadWrite) +{ + SharedMutex mut; + + mut.lock_shared(); + EXPECT_FALSE(mut.try_lock()); + +} + +TEST(SharedMutex, LockReadTwice) +{ + SharedMutex mut; + + mut.lock_shared(); + mut.lock_shared(); + EXPECT_TRUE(mut.try_lock_shared()); +} + +TEST(SharedMutex, LockWriteTwice) +{ + SharedMutex mut; + + mut.lock(); + EXPECT_FALSE(mut.try_lock()); +} + +TEST(SharedMutex, LockUnlockWrite) +{ + SharedMutex mut; + mut.lock(); + EXPECT_FALSE(mut.try_lock()); + mut.unlock(); + EXPECT_TRUE(mut.try_lock()); +} + +TEST(SharedMutex, LockUnlockRead) +{ + SharedMutex mut; + + mut.lock_shared(); + EXPECT_FALSE(mut.try_lock()); + + mut.unlock_shared(); + EXPECT_TRUE(mut.try_lock()); +} + +TEST(SharedMutex, LockedReadCount) +{ + SharedMutex mut; + int count = 0; + + mut.lock_shared(); + count++; + ASSERT_EQ(mut.getReaderCount(), count); + + mut.lock_shared(); + count++; + ASSERT_EQ(mut.getReaderCount(), count); + + mut.unlock_shared(); + count--; + ASSERT_EQ(mut.getReaderCount(), count); + + mut.unlock_shared(); + count--; + ASSERT_EQ(mut.getReaderCount(), count); + + EXPECT_TRUE(mut.try_lock()); +} + + /*****************************************************************************/ /* * FormatTime