Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Free-form text rejection message sent back to the caller #1170

Draft
wants to merge 14 commits into
base: master
Choose a base branch
from
22 changes: 18 additions & 4 deletions srtcore/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ SRTSOCKET CUDTUnited::newSocket(CUDTSocket** pps)
}

int CUDTUnited::newConnection(const SRTSOCKET listen, const sockaddr_any& peer, const CPacket& hspkt,
CHandShake& w_hs, SRT_REJECT_REASON& w_error)
CHandShake& w_hs, int& w_error, string& w_streaminfo)
{
CUDTSocket* ns = NULL;

Expand Down Expand Up @@ -549,6 +549,8 @@ int CUDTUnited::newConnection(const SRTSOCKET listen, const sockaddr_any& peer,
return -1;
}

ns->m_pUDT->m_RejectReason = SRT_REJ_UNKNOWN; // pre-set a universal value

try
{
ns->m_SocketID = generateSocketID();
Expand Down Expand Up @@ -586,6 +588,7 @@ int CUDTUnited::newConnection(const SRTSOCKET listen, const sockaddr_any& peer,
// CUDT::open() may only throw original std::bad_alloc from new.
// This is only to make the library extra safe (when your machine lacks
// memory, it will continue to work, but fail to accept connection).

try
{
// This assignment must happen b4 the call to CUDT::connect() because
Expand All @@ -607,6 +610,11 @@ int CUDTUnited::newConnection(const SRTSOCKET listen, const sockaddr_any& peer,
{
if (!ls->m_pUDT->runAcceptHook(ns->m_pUDT, peer.get(), w_hs, hspkt))
{
w_error = ns->m_pUDT->m_RejectReason;

// Save the STREAMID contents in case when a user changed it
// in the listener callback
w_streaminfo = ns->m_pUDT->m_sStreamName;
error = 1;
goto ERR_ROLLBACK;
}
Expand Down Expand Up @@ -777,11 +785,12 @@ int CUDTUnited::newConnection(const SRTSOCKET listen, const sockaddr_any& peer,
#if ENABLE_LOGGING
static const char* why [] = {
"UNKNOWN ERROR",
"CONNECTION REJECTED",
"EXPLICIT REJECTION",
"IPE when mapping a socket",
"IPE when inserting a socket"
};
LOGC(mglog.Error, log << CONID(ns->m_SocketID) << "newConnection: connection rejected due to: " << why[error]);
LOGC(mglog.Error, log << CONID(ns->m_SocketID) << "newConnection: connection rejected due to: "
<< why[error] << " - " << RequestTypeStr(URQFailure(w_error)));
#endif
SRTSOCKET id = ns->m_SocketID;
ns->makeClosed();
Expand Down Expand Up @@ -4089,9 +4098,14 @@ SRT_API std::string getstreamid(SRTSOCKET u)
return CUDT::getstreamid(u);
}

SRT_REJECT_REASON getrejectreason(SRTSOCKET u)
int getrejectreason(SRTSOCKET u)
{
return CUDT::rejectReason(u);
}

int setrejectreason(SRTSOCKET u, int value)
{
return CUDT::rejectReason(u, value);
}

} // namespace UDT
2 changes: 1 addition & 1 deletion srtcore/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ friend class CRendezvousQueue;
/// @return If the new connection is successfully created: 1 success, 0 already exist, -1 error.

int newConnection(const SRTSOCKET listen, const sockaddr_any& peer, const CPacket& hspkt,
CHandShake& w_hs, SRT_REJECT_REASON& w_error);
CHandShake& w_hs, int& w_error, std::string& w_streaminfo);

int installAcceptHook(const SRTSOCKET lsn, srt_listen_callback_fn* hook, void* opaq);

Expand Down
11 changes: 9 additions & 2 deletions srtcore/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -680,9 +680,16 @@ extern const char* const srt_rejectreason_msg [] = {
"Group settings collision"
};

const char* srt_rejectreason_str(SRT_REJECT_REASON rid)
const char* srt_rejectreason_str(int id)
{
int id = rid;
if (id > SRT_REJC_SERVER)
{
if (id > SRT_REJC_USER)
return "USER ERROR";

return "SERVER ERROR";
}

static const size_t ra_size = Size(srt_rejectreason_msg);
if (size_t(id) >= ra_size)
return srt_rejectreason_msg[0];
Expand Down
187 changes: 162 additions & 25 deletions srtcore/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2668,6 +2668,79 @@ int CUDT::processSrtMsg_HSRSP(const uint32_t *srtdata, size_t len, uint32_t ts,
return SRT_CMD_NONE;
}

void CUDT::interpretRejectionMessage(const CHandShake& hs, const CPacket& hspkt)
{
// This is simply part of the same procedure as done in interpretSrtHandshake(),
// it just extracts the contents of a prospective SRT_CMD_SID extension.
// If not present, do nothing.

int ext_flags = SrtHSRequest::SRT_HSTYPE_HSFLAGS::unwrap(hs.m_iType);
if (!IsSet(ext_flags, CHandShake::HS_EXT_CONFIG) || hspkt.getLength() <= CHandShake::m_iContentSize)
return;

uint32_t* p = reinterpret_cast<uint32_t*>(hspkt.m_pcData + CHandShake::m_iContentSize);
size_t size = hspkt.getLength() - CHandShake::m_iContentSize; // Due to previous cond check we grant it's >0

// XXX Probably common parts with interpretSrtHandshake() can be done for this loop

uint32_t *begin = p;
uint32_t *next = 0;
size_t length = size / sizeof(uint32_t);
size_t blocklen = 0;

for (;;) // This is one shot loop, unless REPEATED by 'continue'.
{
int cmd = FindExtensionBlock(begin, length, (blocklen), (next));

HLOGC(mglog.Debug,
log << "interpretRejectionMessage: found extension: (" << cmd << ") " << MessageTypeStr(UMSG_EXT, cmd));

const size_t bytelen = blocklen * sizeof(uint32_t);
if (cmd == SRT_CMD_SID)
{
if (!bytelen || bytelen > MAX_SID_LENGTH)
{
LOGC(mglog.Error,
log << "interpretRejectionMessage: STREAMID length " << bytelen << " is 0 or > " << +MAX_SID_LENGTH
<< " - PROTOCOL ERROR, REJECTING");
return;
}
// Copied through a cleared array. This is because the length is aligned to 4
// where the padding is filled by zero bytes. For the case when the string is
// exactly of a 4-divisible length, we make a big array with maximum allowed size
// filled with zeros. Copying to this array should then copy either only the valid
// characters of the string (if the lenght is divisible by 4), or the string with
// padding zeros. In all these cases in the resulting array we should have all
// subsequent characters of the string plus at least one '\0' at the end. This will
// make it a perfect NUL-terminated string, to be used to initialize a string.
char target[MAX_SID_LENGTH + 1];
memset((target), 0, MAX_SID_LENGTH + 1);
memcpy((target), begin + 1, bytelen);

// Un-swap on big endian machines
ItoHLA((uint32_t *)target, (uint32_t *)target, blocklen);

m_sStreamName = target;
HLOGC(mglog.Debug,
log << "REJECTION MESSAGE SID [" << m_sStreamName << "] (bytelen=" << bytelen
<< " blocklen=" << blocklen << ")");
}
else if (cmd == SRT_CMD_NONE)
{
break;
}
else
{
// Found some block that is not interesting here. Skip this and get the next one.
HLOGC(mglog.Debug, log << "interpretRejectionMessage: ... skipping " << MessageTypeStr(UMSG_EXT, cmd));
}

if (!NextExtensionBlock((begin), next, (length)))
break;
}

}

// This function is called only when the URQ_CONCLUSION handshake has been received from the peer.
bool CUDT::interpretSrtHandshake(const CHandShake& hs,
const CPacket& hspkt,
Expand All @@ -2690,22 +2763,26 @@ bool CUDT::interpretSrtHandshake(const CHandShake& hs,
if (hs.m_iVersion < HS_VERSION_SRT1)
return true; // do nothing

// Anyway, check if the handshake contains any extra data.
if (hspkt.getLength() <= CHandShake::m_iContentSize)
// It's not necessary to check size to fit in the basic handshake,
// it's already done during handshake serialization.

// Now check the obligatory HS flags. HSX is required, KMX and CONFIG optional.
int ext_flags = SrtHSRequest::SRT_HSTYPE_HSFLAGS::unwrap(hs.m_iType);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just inverted order, while the size check is for not just the plain handshake, but + the HSREQ extension. Basic size doesn't have to be checked - it was done already when deserializing the handshake. But then the flag should be checked first because with the flag it should grant that it also contains appropriate extensions.

if (ext_flags == 0)
{
m_RejectReason = SRT_REJ_ROGUE;
// This would mean that the handshake was at least HSv5, but somehow no extras were added.
// Dismiss it then, however this has to be logged.
LOGC(mglog.Error, log << "HS VERSION=" << hs.m_iVersion << " but no handshake extension found!");
LOGC(mglog.Error, log << "HS VERSION=" << hs.m_iVersion << " but no handshake extension flags are set!");
return false;
}

// We still believe it should work, let's check the flags.
int ext_flags = SrtHSRequest::SRT_HSTYPE_HSFLAGS::unwrap(hs.m_iType);
if (ext_flags == 0)
// Anyway, check if the handshake contains any extra data.
// The size must enclose at least the obligatory HSREQ extension, plus the header
if (hspkt.getLength() < CHandShake::m_iContentSize + (SRT_HS__SIZE + 1) * sizeof(int32_t))
{
m_RejectReason = SRT_REJ_ROGUE;
LOGC(mglog.Error, log << "HS VERSION=" << hs.m_iVersion << " but no handshake extension flags are set!");
// This would mean that the handshake was at least HSv5, but somehow no extras were added.
// Dismiss it then, however this has to be logged.
LOGC(mglog.Error, log << "HS VERSION=" << hs.m_iVersion << " but no handshake extension found (size=" << hspkt.getLength() << ")!");
return false;
}

Expand Down Expand Up @@ -4664,6 +4741,8 @@ EConnectStatus CUDT::processConnectResponse(const CPacket& response, CUDTExcepti
if (m_ConnRes.m_iReqType > URQ_FAILURE_TYPES)
{
m_RejectReason = RejectReasonForURQ(m_ConnRes.m_iReqType);
// Extract STREAMID extension, if present, and set it back on a socket.
interpretRejectionMessage(m_ConnRes, response);
return CONN_REJECT;
}

Expand Down Expand Up @@ -10033,25 +10112,56 @@ int32_t CUDT::bake(const sockaddr_any& addr, int32_t current_cookie, int correct
}
}

// XXX This is quite a mystery, why this function has a return value
// and what the purpose for it was. There's just one call of this
// function in the whole code and in that call the return value is
// ignored. Actually this call happens in the CRcvQueue::worker thread,
// where it makes a response for incoming UDP packet that might be
// a connection request. Should any error occur in this process, there
// is no way to "report error" that happened here. Basing on that
// these values in original UDT code were quite like the values
// for m_iReqType, they have been changed to URQ_* symbols, which
// may mean that the intent for the return value was to send this
// value back as a control packet back to the connector.
//
// This function does the same as the fragment that prepares the handshake
// in CUDT::createSrtHandshake. This should only copy one type of extension
// with contents possible to be specified as a string. Many extensions have
// specific ways of extracting the data into the extension, hence this is
// not refactored into single functions.
size_t CUDT::addHandshakeExtension(char* data, int cmd, size_t hs_size, std::string contents)
{
size_t ra_size = hs_size / sizeof(int32_t);

// Now attach the SRT handshake for HSREQ
size_t offset = ra_size;
uint32_t *p = reinterpret_cast<uint32_t *>(data);
// NOTE: since this point, ra_size has a size in int32_t elements, NOT BYTES.

// The first 4-byte item is the CMD/LENGTH spec.
uint32_t *pcmdspec = p + offset; // Remember the location to be filled later, when we know the length
++offset;

// Now prepare the string with 4-byte alignment. The string size is limited
// to half the payload size. Just a sanity check to not pack too much into
// the conclusion packet.
size_t size_limit = m_iMaxSRTPayloadSize / 2;

if (contents.size() >= size_limit)
{
contents = "#!::="; // size error
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just came to my mind: probably this isn't possible here, as the size limit should be checked when setting streamid.

}

size_t wordsize = (contents.size() + 3) / 4;
size_t aligned_bytesize = wordsize * 4;

memset((p + offset), 0, aligned_bytesize);
memcpy((p + offset), contents.data(), contents.size());
// Preswap to little endian (in place due to possible padding zeros)
HtoILA((uint32_t *)(p + offset), (uint32_t *)(p + offset), wordsize);

ra_size = wordsize;
*pcmdspec = HS_CMDSPEC_CMD::wrap(cmd) | HS_CMDSPEC_SIZE::wrap(ra_size);

return hs_size + sizeof (*pcmdspec) + aligned_bytesize;
}


// This function is run when the CRcvQueue object is reading packets
// from the multiplexer (@c CRcvQueue::worker_RetrieveUnit) and the
// target socket ID is 0.
//
// XXX Make this function return EConnectStatus enum type (extend if needed),
// and this will be directly passed to the caller.
SRT_REJECT_REASON CUDT::processConnectRequest(const sockaddr_any& addr, CPacket& packet)
int CUDT::processConnectRequest(const sockaddr_any& addr, CPacket& packet)
{
// XXX ASSUMPTIONS:
// [[using assert(packet.m_iID == 0)]]
Expand Down Expand Up @@ -10248,8 +10358,9 @@ SRT_REJECT_REASON CUDT::processConnectRequest(const sockaddr_any& addr, CPacket&
}
else
{
SRT_REJECT_REASON error = SRT_REJ_UNKNOWN;
int result = s_UDTUnited.newConnection(m_SocketID, addr, packet, (hs), (error));
int error = SRT_REJ_UNKNOWN;
string streaminfo_msg;
int result = s_UDTUnited.newConnection(m_SocketID, addr, packet, (hs), (error), (streaminfo_msg));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you get the contents of STREAM_ID. To put it into the response handshake?
How was it done before, if it was?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It wasn't. That's a new thing here. Previously the failure response handshake contained no extensions. This is copying it back in order to pass it in case when the user has set some contents here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is something that we decided not to do at the moment.
There is no precedure for this defined yet.


// This is listener - m_RejectReason need not be set
// because listener has no functionality of giving the app
Expand Down Expand Up @@ -10296,8 +10407,21 @@ SRT_REJECT_REASON CUDT::processConnectRequest(const sockaddr_any& addr, CPacket&
HLOGC(mglog.Debug,
log << CONID() << "processConnectRequest: sending ABNORMAL handshake info req="
<< RequestTypeStr(hs.m_iReqType));
bool has_message =! streaminfo_msg.empty();

if (has_message)
{
hs.m_iType |= CHandShake::HS_EXT_CONFIG;
}

size_t size = CHandShake::m_iContentSize;
hs.store_to((packet.m_pcData), (size));

if (has_message)
{
size = addHandshakeExtension((packet.m_pcData), SRT_CMD_SID, size, streaminfo_msg);
}
packet.setLength(size);
packet.m_iID = id;
setPacketTS(packet, steady_clock::now());
HLOGC(mglog.Debug, log << "processConnectRequest: SENDING HS (a): " << hs.show());
Expand Down Expand Up @@ -10767,7 +10891,7 @@ int CUDT::getsndbuffer(SRTSOCKET u, size_t *blocks, size_t *bytes)
return std::abs(timespan);
}

SRT_REJECT_REASON CUDT::rejectReason(SRTSOCKET u)
int CUDT::rejectReason(SRTSOCKET u)
{
CUDTSocket* s = s_UDTUnited.locateSocket(u);
if (!s || !s->m_pUDT)
Expand All @@ -10776,6 +10900,19 @@ SRT_REJECT_REASON CUDT::rejectReason(SRTSOCKET u)
return s->m_pUDT->m_RejectReason;
}

int CUDT::rejectReason(SRTSOCKET u, int value)
{
CUDTSocket* s = s_UDTUnited.locateSocket(u);
if (!s || !s->m_pUDT)
return APIError(MJ_NOTSUP, MN_SIDINVAL);

if (value < SRT_REJC_SERVER)
return APIError(MJ_NOTSUP, MN_INVAL);

s->m_pUDT->m_RejectReason = value;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SRT_REJC_SERVER = 1000 (see srt.h)

srt_setrejectreason(..) should return error if the value >=2000 is passed.
The values below are SRT internal.

if (value < SRT_REJC_SERVER) does not protect from this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or I am probably confused by SRT API and SRT protocol error ranges. 🤔

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a confusion, so let me explain:

The rejection code is passed through the URQ field in the handshake. For this number the ranges are valid codes, below 1000. Since 1000 on there start values of rejection codes, system ones this time, and those are these collected in SRT_REJECT_REASON. The code is then 1000 + the reject reason. Therefore for this field the server and user codes will be shifted by 2000 and 3000 respectively.

Translation from URQ into the rejection code is URQ - 1000, however:

  • Codes from the system region (0 to 999) result in the system error code, unless the value is outside the system rejection values, in which case it renders to UNKNOWN
  • Otherwise pass the code as is, just doing URQ - 1000.

return 0;
}

bool CUDT::runAcceptHook(CUDT *acore, const sockaddr* peer, const CHandShake& hs, const CPacket& hspkt)
{
// Prepare the information for the hook.
Expand Down
11 changes: 8 additions & 3 deletions srtcore/core.h
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,8 @@ class CUDT
static bool setstreamid(SRTSOCKET u, const std::string& sid);
static std::string getstreamid(SRTSOCKET u);
static int getsndbuffer(SRTSOCKET u, size_t* blocks, size_t* bytes);
static SRT_REJECT_REASON rejectReason(SRTSOCKET s);
static int rejectReason(SRTSOCKET s);
static int rejectReason(SRTSOCKET s, int value);

public: // internal API
// This is public so that it can be used directly in API implementation functions.
Expand Down Expand Up @@ -831,6 +832,7 @@ class CUDT
SRT_ATR_NODISCARD int processSrtMsg_HSREQ(const uint32_t* srtdata, size_t len, uint32_t ts, int hsv);
SRT_ATR_NODISCARD int processSrtMsg_HSRSP(const uint32_t* srtdata, size_t len, uint32_t ts, int hsv);
SRT_ATR_NODISCARD bool interpretSrtHandshake(const CHandShake& hs, const CPacket& hspkt, uint32_t* out_data, size_t* out_len);
void interpretRejectionMessage(const CHandShake& hs, const CPacket& pkt);
SRT_ATR_NODISCARD bool checkApplyFilterConfig(const std::string& cs);

static CUDTGroup& newGroup(const int); // defined EXCEPTIONALLY in api.cpp for convenience reasons
Expand Down Expand Up @@ -1087,7 +1089,7 @@ class CUDT
volatile bool m_bShutdown; // If the peer side has shutdown the connection
volatile bool m_bBroken; // If the connection has been broken
volatile bool m_bPeerHealth; // If the peer status is normal
volatile SRT_REJECT_REASON m_RejectReason;
volatile int m_RejectReason;
bool m_bOpened; // If the UDT entity has been opened
int m_iBrokenCounter; // a counter (number of GC checks) to let the GC tag this socket as disconnected

Expand Down Expand Up @@ -1313,7 +1315,10 @@ class CUDT

int processData(CUnit* unit);
void processClose();
SRT_REJECT_REASON processConnectRequest(const sockaddr_any& addr, CPacket& packet);

/// Returns: URQ code, possibly containing reject reason
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doxygen style requires:

/// @returns URQ code, possibly containing reject reason

int processConnectRequest(const sockaddr_any& addr, CPacket& packet);
size_t addHandshakeExtension(char *data, int cmd, size_t hs_size, std::string contents);
static void addLossRecord(std::vector<int32_t>& lossrecord, int32_t lo, int32_t hi);
int32_t bake(const sockaddr_any& addr, int32_t previous_cookie = 0, int correction = 0);
int32_t ackDataUpTo(int32_t seq);
Expand Down
Loading