Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Fix a data race by adding a shared mutex #2961

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions srtcore/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,7 @@ int srt::CUDTUnited::newConnection(const SRTSOCKET listen,

try
{
ScopedLock col(ls->core().m_ConnectionLock);
ns = new CUDTSocket(*ls);
// No need to check the peer, this is the address from which the request has come.
ns->m_PeerAddr = peer;
Expand Down
33 changes: 21 additions & 12 deletions srtcore/queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,7 @@
, m_szPayloadSize()
, m_bClosing(false)
, m_LSLock()
, m_pListener(NULL)
, m_pListener()
, m_pRendezvousQueue(NULL)
, m_vNewEntry()
, m_IDLock()
Expand Down Expand Up @@ -1405,10 +1405,11 @@
bool have_listener = false;
{
ScopedLock cg(m_LSLock);
if (m_pListener)
m_pListener.lockRead();
if (m_pListener.udt)
{
LOGC(cnlog.Debug, log << "PASSING request from: " << addr.str() << " to listener:" << m_pListener->socketID());
listener_ret = m_pListener->processConnectRequest(addr, unit->m_Packet);
LOGC(cnlog.Debug, log << "PASSING request from: " << addr.str() << " to listener:" << m_pListener.udt->socketID());
listener_ret = m_pListener.udt->processConnectRequest(addr, unit->m_Packet);

// This function does return a code, but it's hard to say as to whether
// anything can be done about it. In case when it's stated possible, the
Expand All @@ -1418,6 +1419,7 @@

have_listener = true;
}
m_pListener.unlockRead();
}

// NOTE: Rendezvous sockets do bind(), but not listen(). It means that the socket is
Expand Down Expand Up @@ -1690,21 +1692,28 @@

int srt::CRcvQueue::setListener(CUDT* u)
{
ScopedLock lslock(m_LSLock);

if (NULL != m_pListener)
m_pListener.lockWrite();
if (NULL != m_pListener.udt)
{
m_pListener.unlockWrite();
return -1;
}

m_pListener.udt = u;
m_pListener.unlockWrite();

m_pListener = u;
return 0;
}

void srt::CRcvQueue::removeListener(const CUDT* u)
{
ScopedLock lslock(m_LSLock);

if (u == m_pListener)
m_pListener = NULL;
//ScopedLock lslock(m_LSLock);

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.
m_pListener.lockWrite();
if (u == m_pListener.udt)
{
m_pListener.udt = NULL;
}
m_pListener.unlockWrite();
}

void srt::CRcvQueue::registerConnector(const SRTSOCKET& id,
Expand Down
32 changes: 31 additions & 1 deletion srtcore/queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,36 @@ namespace srt
{
class CChannel;
class CUDT;
class CUDTWrapper;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The class definition follows below. This declaration is excessive.


class CUDTWrapper {
public:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
public:
private:

CUDT *udt;
sync::SharedMutex mut;

public:
CUDTWrapper()
:udt(NULL)
,mut()
{
}
void lockRead()
{
return mut.lockRead();
}
void lockWrite()
{
return mut.lockWrite();
}
void unlockRead()
{
return mut.unlockRead();

}
void unlockWrite(){
return mut.unlockWrite();
}
};
Comment on lines +72 to +99
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite like the name. It does not express what the class does, except for wrapping CUDT.
In fact, it can be a template class holding a pointer to a resource and a shared mutex, e.g.

template <class T>
class CSharedObject
{
private:
    T* m_pObj;
    sync::SharedMutex m_mtx;
public:
    // ...
};


struct CUnit
{
Expand Down Expand Up @@ -555,7 +585,7 @@ class CRcvQueue

private:
sync::Mutex m_LSLock;
CUDT* m_pListener; // pointer to the (unique, if any) listening UDT entity
CUDTWrapper m_pListener; // pointer to the (unique, if any) listening UDT entity
CRendezvousQueue* m_pRendezvousQueue; // The list of sockets in rendezvous mode

std::vector<CUDT*> m_vNewEntry; // newly added entries, to be inserted
Expand Down
156 changes: 156 additions & 0 deletions srtcore/sync.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#define INC_SRT_SYNC_H

#include "platform_sys.h"
#include <limits.h>

#include <cstdlib>
#include <limits>
Expand Down Expand Up @@ -943,9 +944,164 @@
/// @param[in] maxVal maximum allowed value of the resulting random number.
int genRandomInt(int minVal, int maxVal);

class SharedMutex
{
private:
Condition m_pLockWriteCond;
Condition m_pLockReadCond;

Mutex m_pMutex;
Mutex m_pMutex2;

int m_pCountRead;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"p" mean pointer, while the variable is not a pointer.
Please review the whole class.

bool m_pWriterLocked;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'p' means "pointer". This variable is of a type "Boolean", so should be m_bWriterLocked.



public:
SharedMutex()
:m_pLockWriteCond()
,m_pLockReadCond()
,m_pMutex()
,m_pMutex2()
,m_pCountRead(0)
,m_pWriterLocked(false)
{
m_pCountRead = 0;
m_pWriterLocked = false;

}

void lockWrite()
{
UniqueLock l1(m_pMutex);
if(m_pWriterLocked)
m_pLockWriteCond.wait(l1);
m_pWriterLocked = true;
if(m_pCountRead)
m_pLockReadCond.wait(l1);


}

void unlockWrite()
{
UniqueLock l2(m_pMutex);
m_pWriterLocked = false;
l2.unlock();
std::cout << "NOTIFY ALL" << std::endl;
m_pLockWriteCond.notify_all();
std::cout << "WRITER NOTIFIED" << std::endl;

}

void lockRead()
{
std::cout << "TRY LOCK READ " << this->m_pCountRead << this->m_pWriterLocked << std::endl;
UniqueLock l3(m_pMutex);
if(m_pWriterLocked)
m_pLockWriteCond.wait(l3);
m_pCountRead++;
std::cout << "LOCKED READ" << std::endl;
}

void unlockRead()
{
std::cout << "UNLOCK READ" << std::endl;
ScopedLock l4(m_pMutex);
m_pCountRead--;
if(m_pWriterLocked && m_pCountRead == 0)
m_pLockReadCond.notify_one();
else if (m_pCountRead > 0)
m_pLockWriteCond.notify_one();
std::cout << "READ UNLOCKED" << std::endl;


}

};

/* REFERENCE IMPLEMENTATION
class shared_mutex
{
Mutex mut_;
Condition gate1_;
Condition gate2_;
unsigned state_;

static const unsigned write_entered_ = 1U << (sizeof(unsigned)*CHAR_BIT - 1);
static const unsigned n_readers_ = ~write_entered_;

public:

shared_mutex() : state_(0) {}


// Exclusive ownership

void
lock()
{
UniqueLock lk(mut_);
std::cout << "LOCK WRITE " << std::endl;
while (state_ & write_entered_)
gate1_.wait(lk);
state_ |= write_entered_;
while (state_ & n_readers_)
gate2_.wait(lk);
std::cout << "LOCK WRITE DONE" << std::endl;

}

void
unlock()
{
{
ScopedLock _(mut_);
state_ = 0;
}
std::cout << "UNLOCK WRITE " << std::endl;
gate1_.notify_all();
std::cout << "UNLOCK WRITE DONE" << std::endl;

}

// Shared ownership

void
lock_shared()
{
UniqueLock lk(mut_);
while ((state_ & write_entered_) || (state_ & n_readers_) == n_readers_)
gate1_.wait(lk);
unsigned num_readers = (state_ & n_readers_) + 1;
state_ &= ~n_readers_;
state_ |= num_readers;
}

void
unlock_shared()
{
ScopedLock _(mut_);
unsigned num_readers = (state_ & n_readers_) - 1;
state_ &= ~n_readers_;
state_ |= num_readers;
if (state_ & write_entered_)
{
if (num_readers == 0)
gate2_.notify_one();
}
else
{
if (num_readers == n_readers_ - 1)
gate1_.notify_one();
}
}
};*/
Comment on lines +1023 to +1099

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.

} // namespace sync
} // namespace srt


#include "atomic_clock.h"

#endif // INC_SRT_SYNC_H
Loading