Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Fix a data race #2984

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions srtcore/queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,7 @@ srt::CRcvQueue::CRcvQueue()
, m_szPayloadSize()
, m_bClosing(false)
, m_LSLock()
, m_pListener(NULL)
, m_pListener()
, m_pRendezvousQueue(NULL)
, m_vNewEntry()
, m_IDLock()
Expand Down Expand Up @@ -1405,10 +1405,11 @@ srt::EConnectStatus srt::CRcvQueue::worker_ProcessConnectionRequest(CUnit* unit,
bool have_listener = false;
{
ScopedLock cg(m_LSLock);
if (m_pListener)
m_pListener.lockRead();
if (m_pListener.m_pObj)
{
LOGC(cnlog.Debug, log << "PASSING request from: " << addr.str() << " to listener:" << m_pListener->socketID());
listener_ret = m_pListener->processConnectRequest(addr, unit->m_Packet);
LOGC(cnlog.Debug, log << "PASSING request from: " << addr.str() << " to listener:" << m_pListener.m_pObj->socketID());
listener_ret = m_pListener.m_pObj->processConnectRequest(addr, unit->m_Packet);

// This function does return a code, but it's hard to say as to whether
// anything can be done about it. In case when it's stated possible, the
Expand All @@ -1418,6 +1419,7 @@ srt::EConnectStatus srt::CRcvQueue::worker_ProcessConnectionRequest(CUnit* unit,

have_listener = true;
}
m_pListener.unlockRead();
}

// NOTE: Rendezvous sockets do bind(), but not listen(). It means that the socket is
Expand Down Expand Up @@ -1690,21 +1692,28 @@ int srt::CRcvQueue::recvfrom(int32_t id, CPacket& w_packet)

int srt::CRcvQueue::setListener(CUDT* u)
{
ScopedLock lslock(m_LSLock);
m_pListener.lockWrite();

if (NULL != m_pListener)
if (NULL != m_pListener.m_pObj)
{
m_pListener.unlockWrite();
return -1;
}

m_pListener.m_pObj = u;
m_pListener.unlockWrite();

m_pListener = u;
return 0;
}

void srt::CRcvQueue::removeListener(const CUDT* u)
{
ScopedLock lslock(m_LSLock);
m_pListener.lockWrite();

if (u == m_pListener)
m_pListener = NULL;
if (u == m_pListener.m_pObj)
m_pListener.m_pObj = NULL;

m_pListener.unlockWrite();
}

void srt::CRcvQueue::registerConnector(const SRTSOCKET& id,
Expand Down
6 changes: 3 additions & 3 deletions srtcore/queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -554,9 +554,9 @@ class CRcvQueue
void storePktClone(int32_t id, const CPacket& pkt);

private:
sync::Mutex m_LSLock;
CUDT* m_pListener; // pointer to the (unique, if any) listening UDT entity
CRendezvousQueue* m_pRendezvousQueue; // The list of sockets in rendezvous mode
sync::Mutex m_LSLock;
sync::CSharedObject<CUDT> m_pListener; // pointer to the (unique, if any) listening UDT entity
CRendezvousQueue* m_pRendezvousQueue; // The list of sockets in rendezvous mode

std::vector<CUDT*> m_vNewEntry; // newly added entries, to be inserted
sync::Mutex m_IDLock;
Expand Down
105 changes: 105 additions & 0 deletions srtcore/sync.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ namespace sync {

srt::sync::CEvent g_Sync;



} // namespace sync
} // namespace srt

Expand Down Expand Up @@ -357,3 +359,106 @@ int srt::sync::genRandomInt(int minVal, int maxVal)
#endif // HAVE_CXX11
}


////////////////////////////////////////////////////////////////////////////////
//
// Shared Mutex
//
////////////////////////////////////////////////////////////////////////////////

srt::sync::SharedMutex::SharedMutex()
:m_LockWriteCond()
,m_LockReadCond()
,m_Mutex()
,m_iCountRead(0)
,m_bWriterLocked(false)
{
m_iCountRead = 0;
m_bWriterLocked = false;
setupCond(m_LockReadCond, "SharedMutex::m_pLockReadCond");
setupCond(m_LockWriteCond, "SharedMutex::m_pLockWriteCond");
setupMutex(m_Mutex, "SharedMutex::m_Mutex");
}

srt::sync::SharedMutex::~SharedMutex()
{
releaseMutex(m_Mutex);
releaseCond(m_LockWriteCond);
releaseCond(m_LockReadCond);
}

void srt::sync::SharedMutex::lock()
{
UniqueLock l1(m_Mutex);
if (m_bWriterLocked)
m_LockWriteCond.wait(l1);

m_bWriterLocked = true;

if (m_iCountRead)
m_LockReadCond.wait(l1);
}

bool srt::sync::SharedMutex::try_lock()
{
UniqueLock l1(m_Mutex);
if (m_bWriterLocked || m_iCountRead > 0)
return false;
else
{
m_bWriterLocked = true;
return true;
}
}

void srt::sync::SharedMutex::unlock()
{
UniqueLock lk(m_Mutex);
m_bWriterLocked = false;

lk.unlock();
m_LockWriteCond.notify_all();
}

void srt::sync::SharedMutex::lock_shared()
{
UniqueLock lk(m_Mutex);
if (m_bWriterLocked)
m_LockWriteCond.wait(lk);

m_iCountRead++;
}

bool srt::sync::SharedMutex::try_lock_shared()
{
UniqueLock lk(m_Mutex);
if (m_bWriterLocked)
return false;
else
{
m_iCountRead++;
return true;
}
}

void srt::sync::SharedMutex::unlock_shared()
{
ScopedLock lk(m_Mutex);

m_iCountRead--;
if (m_iCountRead < 0)
m_iCountRead = 0;

if (m_bWriterLocked && m_iCountRead == 0)
m_LockReadCond.notify_one();
else if (m_iCountRead > 0)
m_LockWriteCond.notify_one();

}

int srt::sync::SharedMutex::getReaderCount()
{
return m_iCountRead;
}


72 changes: 72 additions & 0 deletions srtcore/sync.h
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,78 @@ CUDTException& GetThreadLocalError();
/// @param[in] maxVal maximum allowed value of the resulting random number.
int genRandomInt(int minVal, int maxVal);


// Implementation of a read-write mutex.
// This allows multiple readers at a time, or a single writer
class SharedMutex
{
public:
SharedMutex();
~SharedMutex();

private:
Condition m_LockWriteCond;
Condition m_LockReadCond;

Mutex m_Mutex;

int m_iCountRead;
bool m_bWriterLocked;

// Acquire the lock for writting purposes. Only one thread can acquire this lock at a time
// Once it is locked, no reader can acquire it
public:
void lock();
bool try_lock();
void unlock();

// Acquire the lock if no writter already has it. For read purpose only
// Several readers can lock this at the same time.
void lock_shared();
bool try_lock_shared();
void unlock_shared();

int getReaderCount();


};

template <class T>
class CSharedObject
{
public:
T* m_pObj;
sync::SharedMutex m_Mtx;

public:
CSharedObject<T>()
:m_pObj()
,m_Mtx()
{
}

void lockWrite()
{
m_Mtx.lock();
}

void unlockWrite()
{
m_Mtx.unlock();
}

void lockRead()
{
m_Mtx.lock_shared();
}

void unlockRead()
{
m_Mtx.unlock_shared();
}

};

} // namespace sync
} // namespace srt

Expand Down
85 changes: 85 additions & 0 deletions test/test_sync.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,91 @@ TEST(SyncThread, Joinable)
EXPECT_FALSE(foo.joinable());
}

/*****************************************************************************/
/*
* SharedMutex
*/
/*****************************************************************************/
TEST(SharedMutex, LockWriteRead)
{
SharedMutex mut;

mut.lock();
EXPECT_FALSE(mut.try_lock_shared());

}

TEST(SharedMutex, LockReadWrite)
{
SharedMutex mut;

mut.lock_shared();
EXPECT_FALSE(mut.try_lock());

}

TEST(SharedMutex, LockReadTwice)
{
SharedMutex mut;

mut.lock_shared();
mut.lock_shared();
EXPECT_TRUE(mut.try_lock_shared());
}

TEST(SharedMutex, LockWriteTwice)
{
SharedMutex mut;

mut.lock();
EXPECT_FALSE(mut.try_lock());
}

TEST(SharedMutex, LockUnlockWrite)
{
SharedMutex mut;
mut.lock();
EXPECT_FALSE(mut.try_lock());
mut.unlock();
EXPECT_TRUE(mut.try_lock());
}

TEST(SharedMutex, LockUnlockRead)
{
SharedMutex mut;

mut.lock_shared();
EXPECT_FALSE(mut.try_lock());

mut.unlock_shared();
EXPECT_TRUE(mut.try_lock());
}

TEST(SharedMutex, LockedReadCount)
{
SharedMutex mut;
int count = 0;

mut.lock_shared();
count++;
ASSERT_EQ(mut.getReaderCount(), count);

mut.lock_shared();
count++;
ASSERT_EQ(mut.getReaderCount(), count);

mut.unlock_shared();
count--;
ASSERT_EQ(mut.getReaderCount(), count);

mut.unlock_shared();
count--;
ASSERT_EQ(mut.getReaderCount(), count);

EXPECT_TRUE(mut.try_lock());
}


/*****************************************************************************/
/*
* FormatTime
Expand Down
Loading