From 8eb27e0985b9d87834d822676a104b74faafb9a1 Mon Sep 17 00:00:00 2001 From: Carlos Segarra Date: Wed, 28 Feb 2024 17:20:05 +0000 Subject: [PATCH] transport(ptp): move from protobuf to fixed-size c-struct This commit moves the abstraction of a PTP message from a protobuf object to a fixed-size C-struct with a heap pointer. The rationale is that these PTP messages move through the system, and even when careful it is challenging to keep track of the copies allocations that protobuf is doing under the hood. In exchange, we are very explicity about the copies and allocations we do. --- .../faabric/transport/PointToPointBroker.h | 21 +- .../faabric/transport/PointToPointClient.h | 7 +- .../faabric/transport/PointToPointMessage.h | 45 +++ src/mpi/MpiWorld.cpp | 39 ++- src/proto/faabric.proto | 8 - src/scheduler/Scheduler.cpp | 26 +- src/transport/CMakeLists.txt | 1 + src/transport/MessageEndpointClient.cpp | 1 + src/transport/PointToPointBroker.cpp | 272 +++++++++++------- src/transport/PointToPointClient.cpp | 36 ++- src/transport/PointToPointMessage.cpp | 62 ++++ src/transport/PointToPointServer.cpp | 51 ++-- tests/dist/transport/functions.cpp | 51 +++- tests/dist/transport/test_point_to_point.cpp | 1 - tests/test/transport/test_point_to_point.cpp | 94 ++++-- .../transport/test_point_to_point_groups.cpp | 29 +- .../transport/test_point_to_point_message.cpp | 95 ++++++ 17 files changed, 599 insertions(+), 240 deletions(-) create mode 100644 include/faabric/transport/PointToPointMessage.h create mode 100644 src/transport/PointToPointMessage.cpp create mode 100644 tests/test/transport/test_point_to_point_message.cpp diff --git a/include/faabric/transport/PointToPointBroker.h b/include/faabric/transport/PointToPointBroker.h index 95f6cba17..87a47ca3b 100644 --- a/include/faabric/transport/PointToPointBroker.h +++ b/include/faabric/transport/PointToPointBroker.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -120,27 +121,16 @@ class PointToPointBroker void updateHostForIdx(int groupId, int groupIdx, std::string newHost); - void sendMessage(int groupId, - int sendIdx, - int recvIdx, - const uint8_t* buffer, - size_t bufferSize, + void sendMessage(const PointToPointMessage& msg, std::string hostHint, bool mustOrderMsg = false); - void sendMessage(int groupId, - int sendIdx, - int recvIdx, - const uint8_t* buffer, - size_t bufferSize, + void sendMessage(const PointToPointMessage& msg, bool mustOrderMsg = false, int sequenceNum = NO_SEQUENCE_NUM, std::string hostHint = ""); - std::vector recvMessage(int groupId, - int sendIdx, - int recvIdx, - bool mustOrderMsg = false); + void recvMessage(PointToPointMessage& msg, bool mustOrderMsg = false); void clearGroup(int groupId); @@ -163,7 +153,8 @@ class PointToPointBroker std::shared_ptr getGroupFlag(int groupId); - Message doRecvMessage(int groupId, int sendIdx, int recvIdx); + // Returns the message response code and the sequence number + std::pair doRecvMessage(PointToPointMessage& msg); void initSequenceCounters(int groupId); diff --git a/include/faabric/transport/PointToPointClient.h b/include/faabric/transport/PointToPointClient.h index 634b41579..5e5add933 100644 --- a/include/faabric/transport/PointToPointClient.h +++ b/include/faabric/transport/PointToPointClient.h @@ -3,18 +3,19 @@ #include #include #include +#include namespace faabric::transport { std::vector> getSentMappings(); -std::vector> +std::vector> getSentPointToPointMessages(); std::vector> + PointToPointMessage>> getSentLockMessages(); void clearSentMessages(); @@ -26,7 +27,7 @@ class PointToPointClient : public faabric::transport::MessageEndpointClient void sendMappings(faabric::PointToPointMappings& mappings); - void sendMessage(faabric::PointToPointMessage& msg, + void sendMessage(const PointToPointMessage& msg, int sequenceNum = NO_SEQUENCE_NUM); void groupLock(int appId, diff --git a/include/faabric/transport/PointToPointMessage.h b/include/faabric/transport/PointToPointMessage.h new file mode 100644 index 000000000..e61e2c509 --- /dev/null +++ b/include/faabric/transport/PointToPointMessage.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include + +namespace faabric::transport { + +/* Simple fixed-size C-struct to capture the state of a PTP message moving + * through Faabric. + * + * We require fixed-size, and no unique pointers to be able to use + * high-throughput ring-buffers to send the messages around. This also means + * that we manually malloc/free the data pointer. The message size is: + * 4 * int32_t = 4 * 4 bytes = 16 bytes + * 1 * size_t = 1 * 8 bytes = 8 bytes + * 1 * void* = 1 * 8 bytes = 8 bytes + * total = 32 bytes = 4 * 8 so the struct is naturally 8 byte-aligned + */ +struct PointToPointMessage +{ + int32_t appId; + int32_t groupId; + int32_t sendIdx; + int32_t recvIdx; + size_t dataSize; + void* dataPtr; +}; +static_assert((sizeof(PointToPointMessage) % 8) == 0, + "PTP message mus be 8-aligned!"); + +// The wire format for a PTP message is very simple: the fixed-size struct, +// followed by dataSize bytes containing the payload. +void serializePtpMsg(std::span buffer, const PointToPointMessage& msg); + +// This parsing function mallocs space for the message payload. This is to +// keep the PTP message at fixed-size, and be able to efficiently move it +// around in-memory queues +void parsePtpMsg(std::span bytes, PointToPointMessage* msg); + +// Alternative signature for parsing PTP messages for when the caller can +// provide an already-allocated buffer to write into +void parsePtpMsg(std::span bytes, + PointToPointMessage* msg, + std::span preAllocBuffer); +} diff --git a/src/mpi/MpiWorld.cpp b/src/mpi/MpiWorld.cpp index d50344c40..e300ea006 100644 --- a/src/mpi/MpiWorld.cpp +++ b/src/mpi/MpiWorld.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -60,14 +61,16 @@ void MpiWorld::sendRemoteMpiMessage(std::string dstHost, throw std::runtime_error("Error serialising message"); } try { - broker.sendMessage( - thisRankMsg->groupid(), - sendRank, - recvRank, - reinterpret_cast(serialisedBuffer.data()), - serialisedBuffer.size(), - dstHost, - true); + // It is safe to send a pointer to a stack-allocated object + // because the broker will make an additional copy (and so will NNG!) + faabric::transport::PointToPointMessage msg( + { .groupId = thisRankMsg->groupid(), + .sendIdx = sendRank, + .recvIdx = recvRank, + .dataSize = serialisedBuffer.size(), + .dataPtr = (void*)serialisedBuffer.data() }); + + broker.sendMessage(msg, dstHost, true); } catch (std::runtime_error& e) { SPDLOG_ERROR("{}:{}:{} Timed out with: MPI - send {} -> {}", thisRankMsg->appid(), @@ -82,10 +85,12 @@ void MpiWorld::sendRemoteMpiMessage(std::string dstHost, std::shared_ptr MpiWorld::recvRemoteMpiMessage(int sendRank, int recvRank) { - std::vector msg; + faabric::transport::PointToPointMessage msg( + { .groupId = thisRankMsg->groupid(), + .sendIdx = sendRank, + .recvIdx = recvRank }); try { - msg = - broker.recvMessage(thisRankMsg->groupid(), sendRank, recvRank, true); + broker.recvMessage(msg, true); } catch (std::runtime_error& e) { SPDLOG_ERROR("{}:{}:{} Timed out with: MPI - recv (remote) {} -> {}", thisRankMsg->appid(), @@ -95,7 +100,12 @@ std::shared_ptr MpiWorld::recvRemoteMpiMessage(int sendRank, recvRank); throw e; } - PARSE_MSG(MPIMessage, msg.data(), msg.size()); + + // Parsing into the protobuf makes a copy of the message, so we can + // free the heap pointer after + PARSE_MSG(MPIMessage, msg.dataPtr, msg.dataSize); + faabric::util::free(msg.dataPtr); + return std::make_shared(parsedMsg); } @@ -599,7 +609,10 @@ void MpiWorld::doRecv(std::shared_ptr& m, // Assert message integrity // Note - this checks won't happen in Release builds if (m->messagetype() != messageType) { - SPDLOG_ERROR("Different message types (got: {}, expected: {})", + SPDLOG_ERROR("{}:{}:{} Different message types (got: {}, expected: {})", + m->worldid(), + m->sender(), + m->destination(), m->messagetype(), messageType); } diff --git a/src/proto/faabric.proto b/src/proto/faabric.proto index 5daa2b5cb..8ed729a8e 100644 --- a/src/proto/faabric.proto +++ b/src/proto/faabric.proto @@ -199,14 +199,6 @@ message StateAppendedResponse { // POINT-TO-POINT // --------------------------------------------- -message PointToPointMessage { - int32 appId = 1; - int32 groupId = 2; - int32 sendIdx = 3; - int32 recvIdx = 4; - bytes data = 5; -} - message PointToPointMappings { int32 appId = 1; int32 groupId = 2; diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 63a87c49d..6ce6f4afd 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -459,12 +459,32 @@ Scheduler::checkForMigrationOpportunities(faabric::Message& msg, auto groupIdxs = broker.getIdxsRegisteredForGroup(groupId); groupIdxs.erase(0); for (const auto& recvIdx : groupIdxs) { - broker.sendMessage( - groupId, 0, recvIdx, BYTES_CONST(&newGroupId), sizeof(int)); + // It is safe to send a pointer to the stack, because the + // transport layer will perform an additional copy of the PTP + // message to put it in the message body + // TODO(no-inproc): this may not be true once we move the inproc + // sockets to in-memory queues + faabric::transport::PointToPointMessage msg( + { .groupId = groupId, + .sendIdx = 0, + .recvIdx = recvIdx, + .dataSize = sizeof(int), + .dataPtr = &newGroupId }); + + broker.sendMessage(msg); } } else if (overwriteNewGroupId == 0) { - std::vector bytes = broker.recvMessage(groupId, 0, groupIdx); + faabric::transport::PointToPointMessage msg( + { .groupId = groupId, .sendIdx = 0, .recvIdx = groupIdx }); + // TODO(no-order): when we remove the need to order ptp messages we + // should be able to call recv giving it a pre-allocated buffer, + // avoiding the hassle of malloc-ing and free-ing + broker.recvMessage(msg); + std::vector bytes((uint8_t*)msg.dataPtr, + (uint8_t*)msg.dataPtr + msg.dataSize); newGroupId = faabric::util::bytesToInt(bytes); + // The previous call makes a copy, so safe to free now + faabric::util::free(msg.dataPtr); } else { // In some settings, like tests, we already know the new group id, so // we can set it here (and in fact, we need to do so when faking two diff --git a/src/transport/CMakeLists.txt b/src/transport/CMakeLists.txt index e8b7c339d..e68fa72bd 100644 --- a/src/transport/CMakeLists.txt +++ b/src/transport/CMakeLists.txt @@ -9,6 +9,7 @@ faabric_lib(transport MessageEndpointServer.cpp PointToPointBroker.cpp PointToPointClient.cpp + PointToPointMessage.cpp PointToPointServer.cpp ) diff --git a/src/transport/MessageEndpointClient.cpp b/src/transport/MessageEndpointClient.cpp index 7984c951b..4bee05763 100644 --- a/src/transport/MessageEndpointClient.cpp +++ b/src/transport/MessageEndpointClient.cpp @@ -36,6 +36,7 @@ void MessageEndpointClient::asyncSend(int header, sequenceNum); } +// TODO: consider making an iovec-style scatter/gather alternative signature void MessageEndpointClient::asyncSend(int header, const uint8_t* buffer, size_t bufferSize, diff --git a/src/transport/PointToPointBroker.cpp b/src/transport/PointToPointBroker.cpp index 9581fc27a..d2c5a0cc3 100644 --- a/src/transport/PointToPointBroker.cpp +++ b/src/transport/PointToPointBroker.cpp @@ -53,7 +53,8 @@ thread_local std::vector sentMsgCount; thread_local std::vector recvMsgCount; -thread_local std::vector> outOfOrderMsgs; +thread_local std::vector>> + outOfOrderMsgs; static std::shared_ptr getClient(const std::string& host) { @@ -202,8 +203,12 @@ void PointToPointGroup::lock(int groupIdx, bool recursive) groupId, recursive); - ptpBroker.recvMessage( - groupId, POINT_TO_POINT_MAIN_IDX, groupIdx); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.recvMessage(msg); } else { // Notify remote locker that they've acquired the lock SPDLOG_TRACE( @@ -217,10 +222,6 @@ void PointToPointGroup::lock(int groupIdx, bool recursive) } } else { auto cli = getClient(mainHost); - faabric::PointToPointMessage msg; - msg.set_groupid(groupId); - msg.set_sendidx(groupIdx); - msg.set_recvidx(POINT_TO_POINT_MAIN_IDX); SPDLOG_TRACE("Remote lock {}:{}:{} to {}", groupId, @@ -232,7 +233,12 @@ void PointToPointGroup::lock(int groupIdx, bool recursive) // acquired cli->groupLock(appId, groupId, groupIdx, recursive); - ptpBroker.recvMessage(groupId, POINT_TO_POINT_MAIN_IDX, groupIdx); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.recvMessage(msg); } } @@ -285,10 +291,6 @@ void PointToPointGroup::unlock(int groupIdx, bool recursive) } } else { auto cli = getClient(host); - faabric::PointToPointMessage msg; - msg.set_groupid(groupId); - msg.set_sendidx(groupIdx); - msg.set_recvidx(POINT_TO_POINT_MAIN_IDX); SPDLOG_TRACE("Remote unlock {}:{}:{} to {}", groupId, @@ -308,9 +310,13 @@ void PointToPointGroup::localUnlock() void PointToPointGroup::notifyLocked(int groupIdx) { std::vector data(1, 0); - - ptpBroker.sendMessage( - groupId, POINT_TO_POINT_MAIN_IDX, groupIdx, data.data(), data.size()); + PointToPointMessage msg = { .appId = 0, + .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }; + ptpBroker.sendMessage(msg); } void PointToPointGroup::barrier(int groupIdx) @@ -324,23 +330,40 @@ void PointToPointGroup::barrier(int groupIdx) if (groupIdx == POINT_TO_POINT_MAIN_IDX) { // Receive from all for (int i = 1; i < groupSize; i++) { - ptpBroker.recvMessage(groupId, i, POINT_TO_POINT_MAIN_IDX); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = i, + .recvIdx = POINT_TO_POINT_MAIN_IDX, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.recvMessage(msg); } // Reply to all std::vector data(1, 0); for (int i = 1; i < groupSize; i++) { - ptpBroker.sendMessage( - groupId, POINT_TO_POINT_MAIN_IDX, i, data.data(), data.size()); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = i, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.sendMessage(msg); } } else { // Do the send - std::vector data(1, 0); - ptpBroker.sendMessage( - groupId, groupIdx, POINT_TO_POINT_MAIN_IDX, data.data(), data.size()); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = groupIdx, + .recvIdx = POINT_TO_POINT_MAIN_IDX, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.sendMessage(msg); // Await the response - ptpBroker.recvMessage(groupId, POINT_TO_POINT_MAIN_IDX, groupIdx); + PointToPointMessage response({ .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.recvMessage(response); } } @@ -351,15 +374,23 @@ void PointToPointGroup::notify(int groupIdx) SPDLOG_TRACE( "Master group {} waiting for notify from index {}", groupId, i); - ptpBroker.recvMessage(groupId, i, POINT_TO_POINT_MAIN_IDX); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = i, + .recvIdx = POINT_TO_POINT_MAIN_IDX, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.recvMessage(msg); SPDLOG_TRACE("Master group {} notified by index {}", groupId, i); } } else { - std::vector data(1, 0); SPDLOG_TRACE("Notifying group {} from index {}", groupId, groupIdx); - ptpBroker.sendMessage( - groupId, groupIdx, POINT_TO_POINT_MAIN_IDX, data.data(), data.size()); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = groupIdx, + .recvIdx = POINT_TO_POINT_MAIN_IDX, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.sendMessage(msg); } } @@ -581,22 +612,11 @@ void PointToPointBroker::updateHostForIdx(int groupId, mappings[key] = newHost; } -void PointToPointBroker::sendMessage(int groupId, - int sendIdx, - int recvIdx, - const uint8_t* buffer, - size_t bufferSize, +void PointToPointBroker::sendMessage(const PointToPointMessage& msg, std::string hostHint, bool mustOrderMsg) { - sendMessage(groupId, - sendIdx, - recvIdx, - buffer, - bufferSize, - mustOrderMsg, - NO_SEQUENCE_NUM, - hostHint); + sendMessage(msg, mustOrderMsg, NO_SEQUENCE_NUM, hostHint); } // Gets or creates a pair of inproc endpoints (recv&send) in the endpoints map. @@ -634,11 +654,7 @@ auto getEndpointPtrs(const std::string& label) return endpointPtrs; } -void PointToPointBroker::sendMessage(int groupId, - int sendIdx, - int recvIdx, - const uint8_t* buffer, - size_t bufferSize, +void PointToPointBroker::sendMessage(const PointToPointMessage& msg, bool mustOrderMsg, int sequenceNum, std::string hostHint) @@ -647,19 +663,21 @@ void PointToPointBroker::sendMessage(int groupId, // sender thread, and another time from the point-to-point server to route // it to the receiver thread - waitForMappingsOnThisHost(groupId); + waitForMappingsOnThisHost(msg.groupId); // If the application code knows which host does the receiver live in // (cached for performance) we allow it to provide a hint to avoid // acquiring a shared lock here - std::string host = - hostHint.empty() ? getHostForReceiver(groupId, recvIdx) : hostHint; + std::string host = hostHint.empty() + ? getHostForReceiver(msg.groupId, msg.recvIdx) + : hostHint; // Set the sequence number if we need ordering and one is not provided bool mustSetSequenceNum = mustOrderMsg && sequenceNum == NO_SEQUENCE_NUM; if (host == conf.endpointHost) { - std::string label = getPointToPointKey(groupId, sendIdx, recvIdx); + std::string label = + getPointToPointKey(msg.groupId, msg.sendIdx, msg.recvIdx); auto endpointPtrs = getEndpointPtrs(label); auto& endpoint = @@ -671,46 +689,49 @@ void PointToPointBroker::sendMessage(int groupId, // the sender thread we add a sequence number (if needed) int localSendSeqNum = sequenceNum; if (mustSetSequenceNum) { - localSendSeqNum = getAndIncrementSentMsgCount(groupId, recvIdx); + localSendSeqNum = + getAndIncrementSentMsgCount(msg.groupId, msg.recvIdx); } SPDLOG_TRACE("Local point-to-point message {}:{}:{} (seq: {}) to {}", - groupId, - sendIdx, - recvIdx, + msg.groupId, + msg.sendIdx, + msg.recvIdx, localSendSeqNum, endpoint.getAddress()); try { - endpoint.send(NO_HEADER, buffer, bufferSize, localSendSeqNum); + // TODO(no-inproc): once we convert the inproc endpoints to a queue + // we should be able to just push the whole message to the queue + std::vector buffer(sizeof(PointToPointMessage) + + msg.dataSize); + serializePtpMsg(buffer, msg); + endpoint.send( + NO_HEADER, buffer.data(), buffer.size(), localSendSeqNum); } catch (std::runtime_error& e) { SPDLOG_ERROR("Timed-out with local point-to-point message {}:{}:{} " "(seq: {}) to {}", - groupId, - sendIdx, - recvIdx, + msg.groupId, + msg.sendIdx, + msg.recvIdx, localSendSeqNum, endpoint.getAddress()); throw e; } } else { auto cli = getClient(host); - faabric::PointToPointMessage msg; - msg.set_groupid(groupId); - msg.set_sendidx(sendIdx); - msg.set_recvidx(recvIdx); - msg.set_data(buffer, bufferSize); // When sending a remote message, we set a sequence number if required int remoteSendSeqNum = NO_SEQUENCE_NUM; if (mustSetSequenceNum) { - remoteSendSeqNum = getAndIncrementSentMsgCount(groupId, recvIdx); + remoteSendSeqNum = + getAndIncrementSentMsgCount(msg.groupId, msg.recvIdx); } SPDLOG_TRACE("Remote point-to-point message {}:{}:{} (seq: {}) to {}", - groupId, - sendIdx, - recvIdx, + msg.groupId, + msg.sendIdx, + msg.recvIdx, remoteSendSeqNum, host); @@ -719,59 +740,81 @@ void PointToPointBroker::sendMessage(int groupId, } catch (std::runtime_error& e) { SPDLOG_TRACE("Timed-out with remote point-to-point message " "{}:{}:{} (seq: {}) to {}", - groupId, - sendIdx, - recvIdx, + msg.groupId, + msg.sendIdx, + msg.recvIdx, remoteSendSeqNum, host); } } } -Message PointToPointBroker::doRecvMessage(int groupId, int sendIdx, int recvIdx) +std::pair PointToPointBroker::doRecvMessage( + PointToPointMessage& msg) { - std::string label = getPointToPointKey(groupId, sendIdx, recvIdx); + std::string label = + getPointToPointKey(msg.groupId, msg.sendIdx, msg.recvIdx); auto endpointPtrs = getEndpointPtrs(label); auto& endpoint = *std::get>( *endpointPtrs); - return endpoint.recv(); + // TODO(no-inproc): this will become a pop from a queue, not a read from + // an in-proc socket + Message bytes = endpoint.recv(); + + // WARNING: this call mallocs + parsePtpMsg(bytes.udata(), &msg); + + /* TODO(no-order): for the moment always parse and malloc memory, as it is + * not easy to track when did we malloc or not. This is gonna become + * simpler once we remove the need to order messages in the PTP layer + * + if (hasPreAllocBuffer) { + std::span msgDataSpan((uint8_t*) msg.dataPtr, msg.dataSize); + parsePtpMsg(bytes.udata(), &msg, msgDataSpan); + } else { + parsePtpMsg(bytes.udata(), &msg); + } + */ + + assert(getPointToPointKey(msg.groupId, msg.sendIdx, msg.recvIdx) == label); + + return std::make_pair(bytes.getResponseCode(), + bytes.getSequenceNum()); } -std::vector PointToPointBroker::recvMessage(int groupId, - int sendIdx, - int recvIdx, - bool mustOrderMsg) +void PointToPointBroker::recvMessage(PointToPointMessage& msg, + bool mustOrderMsg) { // If we don't need to receive messages in order, return here if (!mustOrderMsg) { - // TODO - can we avoid this copy? - return doRecvMessage(groupId, sendIdx, recvIdx).dataCopy(); + doRecvMessage(msg); + return; } // Get the sequence number we expect to receive - int expectedSeqNum = getExpectedSeqNum(groupId, sendIdx); + int expectedSeqNum = getExpectedSeqNum(msg.groupId, msg.sendIdx); // We first check if we have already received the message. We only need to // check this once. - auto foundIterator = - std::find_if(outOfOrderMsgs.at(sendIdx).begin(), - outOfOrderMsgs.at(sendIdx).end(), - [expectedSeqNum](const Message& msg) { - return msg.getSequenceNum() == expectedSeqNum; - }); - if (foundIterator != outOfOrderMsgs.at(sendIdx).end()) { + auto foundIterator = std::find_if( + outOfOrderMsgs.at(msg.sendIdx).begin(), + outOfOrderMsgs.at(msg.sendIdx).end(), + [expectedSeqNum](const std::pair& pair) { + return pair.first == expectedSeqNum; + }); + if (foundIterator != outOfOrderMsgs.at(msg.sendIdx).end()) { SPDLOG_TRACE("Retrieved the expected message ({}:{} seq: {}) from the " "out-of-order buffer", - sendIdx, - recvIdx, + msg.sendIdx, + msg.recvIdx, expectedSeqNum); - incrementRecvMsgCount(groupId, sendIdx); - Message returnMsg = std::move(*foundIterator); - outOfOrderMsgs.at(sendIdx).erase(foundIterator); - return returnMsg.dataCopy(); + incrementRecvMsgCount(msg.groupId, msg.sendIdx); + msg = foundIterator->second; + outOfOrderMsgs.at(msg.sendIdx).erase(foundIterator); + return; } // Given that we don't have the message, we query the transport layer until @@ -779,47 +822,52 @@ std::vector PointToPointBroker::recvMessage(int groupId, while (true) { SPDLOG_TRACE( "Entering loop to query transport layer for msg ({}:{} seq: {})", - sendIdx, - recvIdx, + msg.sendIdx, + msg.recvIdx, expectedSeqNum); - // Receive from the transport layer - Message recvMsg = doRecvMessage(groupId, sendIdx, recvIdx); + + // Receive from the transport layer with the same group id and + // send/recv indexes + PointToPointMessage tmpMsg({ .groupId = msg.groupId, + .sendIdx = msg.sendIdx, + .recvIdx = msg.recvIdx }); + auto [responseCode, seqNum] = doRecvMessage(tmpMsg); // If the receive was not successful, exit the loop - if (recvMsg.getResponseCode() != - faabric::transport::MessageResponseCode::SUCCESS) { + if (responseCode != faabric::transport::MessageResponseCode::SUCCESS) { SPDLOG_WARN( "Error {} ({}) when awaiting a message ({}:{} seq: {} label: {})", - static_cast(recvMsg.getResponseCode()), - MessageResponseCodeText.at(recvMsg.getResponseCode()), - sendIdx, - recvIdx, + static_cast(responseCode), + MessageResponseCodeText.at(responseCode), + msg.sendIdx, + msg.recvIdx, expectedSeqNum, - getPointToPointKey(groupId, sendIdx, recvIdx)); + getPointToPointKey(msg.groupId, msg.sendIdx, msg.recvIdx)); throw std::runtime_error("Error when awaiting a PTP message"); } // If the sequence numbers match, exit the loop - int seqNum = recvMsg.getSequenceNum(); if (seqNum == expectedSeqNum) { SPDLOG_TRACE("Received the expected message ({}:{} seq: {})", - sendIdx, - recvIdx, + msg.sendIdx, + msg.recvIdx, expectedSeqNum); - incrementRecvMsgCount(groupId, sendIdx); - return recvMsg.dataCopy(); + incrementRecvMsgCount(msg.groupId, msg.sendIdx); + + msg = tmpMsg; + return; } // If not, we must insert the received message in the out of order // received messages SPDLOG_TRACE("Received out-of-order message ({}:{} seq: {}) (expected: " "{} - got: {})", - sendIdx, - recvIdx, + tmpMsg.sendIdx, + tmpMsg.recvIdx, seqNum, expectedSeqNum, seqNum); - outOfOrderMsgs.at(sendIdx).emplace_back(std::move(recvMsg)); + outOfOrderMsgs.at(tmpMsg.sendIdx).emplace_back(seqNum, tmpMsg); } } @@ -874,10 +922,10 @@ void PointToPointBroker::resetThreadLocalCache() void PointToPointBroker::postMigrationHook(int groupId, int groupIdx) { + /* int postMigrationOkCode = 1337; int recvCode = 0; - // TODO: implement this as a broadcast in the PTP broker int mainIdx = 0; if (groupIdx == mainIdx) { auto groupIdxs = getIdxsRegisteredForGroup(groupId); @@ -902,6 +950,8 @@ void PointToPointBroker::postMigrationHook(int groupId, int groupIdx) recvCode); throw std::runtime_error("Error in post-migration hook"); } + */ + PointToPointGroup::getGroup(groupId)->barrier(groupIdx); SPDLOG_DEBUG("{}:{} exiting post-migration hook", groupId, groupIdx); } diff --git a/src/transport/PointToPointClient.cpp b/src/transport/PointToPointClient.cpp index d0b7188f8..506fc9874 100644 --- a/src/transport/PointToPointClient.cpp +++ b/src/transport/PointToPointClient.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -13,12 +14,11 @@ static std::mutex mockMutex; static std::vector> sentMappings; -static std::vector> - sentMessages; +static std::vector> sentMessages; static std::vector> + PointToPointMessage>> sentLockMessages; std::vector> @@ -27,7 +27,7 @@ getSentMappings() return sentMappings; } -std::vector> +std::vector> getSentPointToPointMessages() { return sentMessages; @@ -35,7 +35,7 @@ getSentPointToPointMessages() std::vector> + PointToPointMessage>> getSentLockMessages() { return sentLockMessages; @@ -64,13 +64,18 @@ void PointToPointClient::sendMappings(faabric::PointToPointMappings& mappings) } } -void PointToPointClient::sendMessage(faabric::PointToPointMessage& msg, +void PointToPointClient::sendMessage(const PointToPointMessage& msg, int sequenceNum) { if (faabric::util::isMockMode()) { sentMessages.emplace_back(host, msg); } else { - asyncSend(PointToPointCall::MESSAGE, &msg, sequenceNum); + // TODO(FIXME): consider how we can avoid serialising once, and then + // copying again into NNG's buffer + std::vector buffer(sizeof(msg) + msg.dataSize); + serializePtpMsg(buffer, msg); + asyncSend( + PointToPointCall::MESSAGE, buffer.data(), buffer.size(), sequenceNum); } } @@ -80,11 +85,12 @@ void PointToPointClient::makeCoordinationRequest( int groupIdx, faabric::transport::PointToPointCall call) { - faabric::PointToPointMessage req; - req.set_appid(appId); - req.set_groupid(groupId); - req.set_sendidx(groupIdx); - req.set_recvidx(POINT_TO_POINT_MAIN_IDX); + PointToPointMessage req({ .appId = appId, + .groupId = groupId, + .sendIdx = groupIdx, + .recvIdx = POINT_TO_POINT_MAIN_IDX, + .dataSize = 0, + .dataPtr = nullptr }); switch (call) { case (faabric::transport::PointToPointCall::LOCK_GROUP): { @@ -115,7 +121,11 @@ void PointToPointClient::makeCoordinationRequest( faabric::util::UniqueLock lock(mockMutex); sentLockMessages.emplace_back(host, call, req); } else { - asyncSend(call, &req); + // TODO(FIXME): consider how we can avoid serialising once, and then + // copying again into NNG's buffer + std::vector buffer(sizeof(PointToPointMessage) + req.dataSize); + serializePtpMsg(buffer, req); + asyncSend(call, buffer.data(), buffer.size()); } } diff --git a/src/transport/PointToPointMessage.cpp b/src/transport/PointToPointMessage.cpp new file mode 100644 index 000000000..e9415db11 --- /dev/null +++ b/src/transport/PointToPointMessage.cpp @@ -0,0 +1,62 @@ +#include +#include + +#include +#include +#include + +namespace faabric::transport { + +void serializePtpMsg(std::span buffer, const PointToPointMessage& msg) +{ + assert(buffer.size() == sizeof(PointToPointMessage) + msg.dataSize); + std::memcpy(buffer.data(), &msg, sizeof(PointToPointMessage)); + + if (msg.dataSize > 0 && msg.dataPtr != nullptr) { + std::memcpy(buffer.data() + sizeof(PointToPointMessage), + msg.dataPtr, + msg.dataSize); + } +} + +// Parse all the fixed-size parts of the struct +static void parsePtpMsgCommon(std::span bytes, + PointToPointMessage* msg) +{ + assert(msg != nullptr); + assert(bytes.size() >= sizeof(PointToPointMessage)); + std::memcpy(msg, bytes.data(), sizeof(PointToPointMessage)); + size_t thisDataSize = bytes.size() - sizeof(PointToPointMessage); + assert(thisDataSize == msg->dataSize); + + if (thisDataSize == 0) { + msg->dataPtr = nullptr; + } +} + +void parsePtpMsg(std::span bytes, PointToPointMessage* msg) +{ + parsePtpMsgCommon(bytes, msg); + + if (msg->dataSize == 0) { + return; + } + + // malloc memory for the PTP message payload + msg->dataPtr = faabric::util::malloc(msg->dataSize); + std::memcpy( + msg->dataPtr, bytes.data() + sizeof(PointToPointMessage), msg->dataSize); +} + +void parsePtpMsg(std::span bytes, + PointToPointMessage* msg, + std::span preAllocBuffer) +{ + parsePtpMsgCommon(bytes, msg); + + assert(msg->dataSize == preAllocBuffer.size()); + msg->dataPtr = preAllocBuffer.data(); + std::memcpy( + msg->dataPtr, bytes.data() + sizeof(PointToPointMessage), msg->dataSize); +} +} diff --git a/src/transport/PointToPointServer.cpp b/src/transport/PointToPointServer.cpp index 173fec0bf..6224eed84 100644 --- a/src/transport/PointToPointServer.cpp +++ b/src/transport/PointToPointServer.cpp @@ -1,12 +1,14 @@ #include #include #include +#include #include #include #include #include #include #include +#include namespace faabric::transport { @@ -25,9 +27,11 @@ void PointToPointServer::doAsyncRecv(transport::Message& message) int sequenceNum = message.getSequenceNum(); switch (header) { case (faabric::transport::PointToPointCall::MESSAGE): { - PARSE_MSG(faabric::PointToPointMessage, - message.udata().data(), - message.udata().size()) + // Here we are copying the message from the transport layer (NNG) + // into our PTP message structure + // NOTE: this mallocs + PointToPointMessage parsedMsg; + parsePtpMsg(message.udata(), &parsedMsg); // If the sequence number is set, we must also set the ordering // flag @@ -35,13 +39,15 @@ void PointToPointServer::doAsyncRecv(transport::Message& message) // Send the message locally to the downstream socket, add the // sequence number for in-order reception - broker.sendMessage(parsedMsg.groupid(), - parsedMsg.sendidx(), - parsedMsg.recvidx(), - BYTES_CONST(parsedMsg.data().c_str()), - parsedMsg.data().size(), - mustOrderMsg, - sequenceNum); + broker.sendMessage(parsedMsg, mustOrderMsg, sequenceNum); + + // TODO(no-inproc): for the moment, the downstream (inproc) + // socket makes a copy of this message, so we can free it now + // after sending. This will not be the case once we move to + // in-memory queues + if (parsedMsg.dataPtr != nullptr) { + faabric::util::free(parsedMsg.dataPtr); + } break; } case faabric::transport::PointToPointCall::LOCK_GROUP: { @@ -101,28 +107,33 @@ std::unique_ptr PointToPointServer::doRecvMappings( void PointToPointServer::recvGroupLock(std::span buffer, bool recursive) { - PARSE_MSG(faabric::PointToPointMessage, buffer.data(), buffer.size()) + PointToPointMessage parsedMsg; + parsePtpMsg(buffer, &parsedMsg); + assert(parsedMsg.dataPtr == nullptr && parsedMsg.dataSize == 0); + SPDLOG_TRACE("Receiving lock on {} for idx {} (recursive {})", - parsedMsg.groupid(), - parsedMsg.sendidx(), + parsedMsg.groupId, + parsedMsg.sendIdx, recursive); - PointToPointGroup::getGroup(parsedMsg.groupid()) - ->lock(parsedMsg.sendidx(), recursive); + PointToPointGroup::getGroup(parsedMsg.groupId) + ->lock(parsedMsg.sendIdx, recursive); } void PointToPointServer::recvGroupUnlock(std::span buffer, bool recursive) { - PARSE_MSG(faabric::PointToPointMessage, buffer.data(), buffer.size()) + PointToPointMessage parsedMsg; + parsePtpMsg(buffer, &parsedMsg); + assert(parsedMsg.dataPtr == nullptr && parsedMsg.dataSize == 0); SPDLOG_TRACE("Receiving unlock on {} for idx {} (recursive {})", - parsedMsg.groupid(), - parsedMsg.sendidx(), + parsedMsg.groupId, + parsedMsg.sendIdx, recursive); - PointToPointGroup::getGroup(parsedMsg.groupid()) - ->unlock(parsedMsg.sendidx(), recursive); + PointToPointGroup::getGroup(parsedMsg.groupId) + ->unlock(parsedMsg.sendIdx, recursive); } void PointToPointServer::onWorkerStop() diff --git a/tests/dist/transport/functions.cpp b/tests/dist/transport/functions.cpp index 8c99b05b2..1485f5f47 100644 --- a/tests/dist/transport/functions.cpp +++ b/tests/dist/transport/functions.cpp @@ -4,9 +4,9 @@ #include "faabric_utils.h" #include "init.h" -#include #include #include +#include #include #include #include @@ -43,12 +43,25 @@ int handlePointToPointFunction( std::vector expectedRecvData(10, recvFromIdx); // Do the sending - broker.sendMessage( - groupId, groupIdx, sendToIdx, sendData.data(), sendData.size()); + PointToPointMessage sendMsg({ .groupId = groupId, + .sendIdx = groupIdx, + .recvIdx = sendToIdx, + .dataSize = sendData.size(), + .dataPtr = sendData.data() }); + broker.sendMessage(sendMsg); // Do the receiving - std::vector actualRecvData = - broker.recvMessage(groupId, recvFromIdx, groupIdx); + PointToPointMessage recvMsg({ .groupId = groupId, + .sendIdx = recvFromIdx, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }); + broker.recvMessage(recvMsg); + std::vector actualRecvData(recvMsg.dataSize); + std::memcpy(actualRecvData.data(), recvMsg.dataPtr, recvMsg.dataSize); + // TODO(no-order): we will be able to change the signature of recvMessage + // to take in a pre-allocated buffer to read into + faabric::util::free(recvMsg.dataPtr); // Check data is as expected if (actualRecvData != expectedRecvData) { @@ -82,19 +95,31 @@ int handleManyPointToPointMsgFunction( // Send loop for (int i = 0; i < numMsg; i++) { std::vector sendData(5, i); - broker.sendMessage(groupId, - sendIdx, - recvIdx, - sendData.data(), - sendData.size(), - true); + PointToPointMessage sendMsg({ .groupId = groupId, + .sendIdx = sendIdx, + .recvIdx = recvIdx, + .dataSize = sendData.size(), + .dataPtr = sendData.data() }); + broker.sendMessage(sendMsg, true); } } else if (groupIdx == recvIdx) { // Recv loop for (int i = 0; i < numMsg; i++) { std::vector expectedData(5, i); - auto actualData = - broker.recvMessage(groupId, sendIdx, recvIdx, true); + + PointToPointMessage recvMsg({ .groupId = groupId, + .sendIdx = sendIdx, + .recvIdx = recvIdx, + .dataSize = 0, + .dataPtr = nullptr }); + broker.recvMessage(recvMsg, true); + + std::vector actualData(recvMsg.dataSize); + std::memcpy(actualData.data(), recvMsg.dataPtr, recvMsg.dataSize); + // TODO(no-order): we will be able to change the signature of + // recvMessage to take in a pre-allocated buffer to read into + faabric::util::free(recvMsg.dataPtr); + if (actualData != expectedData) { SPDLOG_ERROR( "Out-of-order message reception (got: {}, expected: {})", diff --git a/tests/dist/transport/test_point_to_point.cpp b/tests/dist/transport/test_point_to_point.cpp index 28ab5b6c7..8cd0f6e49 100644 --- a/tests/dist/transport/test_point_to_point.cpp +++ b/tests/dist/transport/test_point_to_point.cpp @@ -5,7 +5,6 @@ #include "init.h" #include -#include #include #include #include diff --git a/tests/test/transport/test_point_to_point.cpp b/tests/test/transport/test_point_to_point.cpp index 98b16b9f7..6a23b90c5 100644 --- a/tests/test/transport/test_point_to_point.cpp +++ b/tests/test/transport/test_point_to_point.cpp @@ -120,9 +120,7 @@ TEST_CASE_METHOD(PointToPointClientServerFixture, std::vector sentDataA = { 0, 1, 2, 3 }; std::vector receivedDataA; std::vector sentDataB = { 3, 4, 5 }; - std::vector receivedDataB; std::vector sentDataC = { 6, 7, 8 }; - std::vector receivedDataC; std::shared_ptr msgLatch = std::make_shared(2, 1000); @@ -131,34 +129,60 @@ TEST_CASE_METHOD(PointToPointClientServerFixture, PointToPointBroker& broker = getPointToPointBroker(); // Receive the first message - receivedDataA = broker.recvMessage(groupId, idxA, idxB); + PointToPointMessage msgAB( + { .groupId = groupId, .sendIdx = idxA, .recvIdx = idxB }); + broker.recvMessage(msgAB); + receivedDataA.resize(msgAB.dataSize); + std::memcpy(receivedDataA.data(), msgAB.dataPtr, msgAB.dataSize); + faabric::util::free(msgAB.dataPtr); msgLatch->wait(); // Send a message back - broker.sendMessage( - groupId, idxB, idxA, sentDataB.data(), sentDataB.size()); + PointToPointMessage msgBA({ .groupId = groupId, + .sendIdx = idxB, + .recvIdx = idxA, + .dataSize = sentDataB.size(), + .dataPtr = sentDataB.data() }); + broker.sendMessage(msgBA); // Lastly, send another message specifying the recepient host to avoid // an extra check in the broker - broker.sendMessage(groupId, - idxB, - idxA, - sentDataC.data(), - sentDataC.size(), - std::string(LOCALHOST)); + PointToPointMessage msgBA2({ .groupId = groupId, + .sendIdx = idxB, + .recvIdx = idxA, + .dataSize = sentDataC.size(), + .dataPtr = sentDataC.data() }); + broker.sendMessage(msgBA2, std::string(LOCALHOST)); broker.resetThreadLocalCache(); }); // Only send the message after the thread creates a receiving socket to // avoid deadlock - broker.sendMessage(groupId, idxA, idxB, sentDataA.data(), sentDataA.size()); + PointToPointMessage msgAB({ .groupId = groupId, + .sendIdx = idxA, + .recvIdx = idxB, + .dataSize = sentDataA.size(), + .dataPtr = sentDataA.data() }); + broker.sendMessage(msgAB); // Wait for the thread to handle the message msgLatch->wait(); // Receive the two messages sent back - receivedDataB = broker.recvMessage(groupId, idxB, idxA); - receivedDataC = broker.recvMessage(groupId, idxB, idxA); + + PointToPointMessage msgBA1( + { .groupId = groupId, .sendIdx = idxB, .recvIdx = idxA }); + broker.recvMessage(msgBA1); + std::vector receivedDataB( + (uint8_t*)msgBA1.dataPtr, (uint8_t*)msgBA1.dataPtr + msgBA1.dataSize); + faabric::util::free(msgBA1.dataPtr); + + PointToPointMessage msgBA2( + { .groupId = groupId, .sendIdx = idxB, .recvIdx = idxA }); + broker.recvMessage(msgBA2); + std::vector receivedDataC( + (uint8_t*)msgBA2.dataPtr, (uint8_t*)msgBA2.dataPtr + msgBA2.dataSize); + faabric::util::free(msgBA2.dataPtr); if (t.joinable()) { t.join(); @@ -236,22 +260,28 @@ TEST_CASE_METHOD(PointToPointClientServerFixture, std::vector recvData; for (int i = 0; i < numMsg; i++) { - recvData = - broker.recvMessage(groupId, idxA, idxB, isMessageOrderingOn); sendData = std::vector(3, i); + PointToPointMessage msg( + { .groupId = groupId, .sendIdx = idxB, .recvIdx = idxA }); + broker.recvMessage(msg, isMessageOrderingOn); + recvData.resize(msg.dataSize); + // TODO(no-order): when we remove the need to order PTP messages + // we will be able to provide a buffer to receive the message into + std::memcpy(recvData.data(), msg.dataPtr, msg.dataSize); REQUIRE(recvData == sendData); + faabric::util::free(msg.dataPtr); } msgLatch->wait(); for (int i = 0; i < numMsg; i++) { sendData = std::vector(3, i); - broker.sendMessage(groupId, - idxB, - idxA, - sendData.data(), - sendData.size(), - isMessageOrderingOn); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = idxB, + .recvIdx = idxA, + .dataSize = sendData.size(), + .dataPtr = sendData.data() }); + broker.sendMessage(msg, isMessageOrderingOn); } broker.resetThreadLocalCache(); @@ -262,20 +292,26 @@ TEST_CASE_METHOD(PointToPointClientServerFixture, for (int i = 0; i < numMsg; i++) { sendData = std::vector(3, i); - broker.sendMessage(groupId, - idxA, - idxB, - sendData.data(), - sendData.size(), - isMessageOrderingOn); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = idxB, + .recvIdx = idxA, + .dataSize = sendData.size(), + .dataPtr = sendData.data() }); + broker.sendMessage(msg, isMessageOrderingOn); } msgLatch->wait(); for (int i = 0; i < numMsg; i++) { sendData = std::vector(3, i); - recvData = broker.recvMessage(groupId, idxB, idxA, isMessageOrderingOn); + PointToPointMessage msg( + { .groupId = groupId, .sendIdx = idxB, .recvIdx = idxA }); + broker.recvMessage(msg, isMessageOrderingOn); + recvData.resize(msg.dataSize); + // REQUIRE(msg.dataSize == recvData.size()); + std::memcpy(recvData.data(), msg.dataPtr, msg.dataSize); REQUIRE(sendData == recvData); + faabric::util::free(msg.dataPtr); } if (t.joinable()) { diff --git a/tests/test/transport/test_point_to_point_groups.cpp b/tests/test/transport/test_point_to_point_groups.cpp index 8d9761335..fa583e70b 100644 --- a/tests/test/transport/test_point_to_point_groups.cpp +++ b/tests/test/transport/test_point_to_point_groups.cpp @@ -135,8 +135,12 @@ TEST_CASE_METHOD(PointToPointGroupFixture, op = PointToPointCall::LOCK_GROUP; // Prepare response - broker.sendMessage( - groupId, POINT_TO_POINT_MAIN_IDX, groupIdx, data.data(), data.size()); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }); + broker.sendMessage(msg); group->lock(groupIdx, false); } @@ -147,8 +151,12 @@ TEST_CASE_METHOD(PointToPointGroupFixture, recursive = true; // Prepare response - broker.sendMessage( - groupId, POINT_TO_POINT_MAIN_IDX, groupIdx, data.data(), data.size()); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }); + broker.sendMessage(msg); group->lock(groupIdx, recursive); } @@ -166,8 +174,7 @@ TEST_CASE_METHOD(PointToPointGroupFixture, group->unlock(groupIdx, recursive); } - std::vector< - std::tuple> + std::vector> actualRequests = getSentLockMessages(); REQUIRE(actualRequests.size() == 1); @@ -176,11 +183,11 @@ TEST_CASE_METHOD(PointToPointGroupFixture, PointToPointCall actualOp = std::get<1>(actualRequests.at(0)); REQUIRE(actualOp == op); - faabric::PointToPointMessage req = std::get<2>(actualRequests.at(0)); - REQUIRE(req.appid() == appId); - REQUIRE(req.groupid() == groupId); - REQUIRE(req.sendidx() == groupIdx); - REQUIRE(req.recvidx() == POINT_TO_POINT_MAIN_IDX); + PointToPointMessage req = std::get<2>(actualRequests.at(0)); + REQUIRE(req.appId == appId); + REQUIRE(req.groupId == groupId); + REQUIRE(req.sendIdx == groupIdx); + REQUIRE(req.recvIdx == POINT_TO_POINT_MAIN_IDX); } TEST_CASE_METHOD(PointToPointGroupFixture, diff --git a/tests/test/transport/test_point_to_point_message.cpp b/tests/test/transport/test_point_to_point_message.cpp new file mode 100644 index 000000000..51b1cb87c --- /dev/null +++ b/tests/test/transport/test_point_to_point_message.cpp @@ -0,0 +1,95 @@ +#include + +#include +#include + +#include + +using namespace faabric::transport; + +namespace tests { + +bool arePtpMsgEqual(const PointToPointMessage& msgA, const PointToPointMessage& msgB) +{ + // First, compare the message body (excluding the pointer, which we + // know is at the end) + if (std::memcmp(&msgA, &msgB, sizeof(PointToPointMessage) - sizeof(void*)) != 0) { + return false; + } + + // Check that if one buffer points to null, so must do the other + if (msgA.dataPtr == nullptr || msgB.dataPtr == nullptr) { + return msgA.dataPtr == msgB.dataPtr; + } + + return std::memcmp(msgA.dataPtr, msgB.dataPtr, msgA.dataSize) == 0; +} + +TEST_CASE("Test (de)serialising a PTP message", "[ptp]") +{ + PointToPointMessage msg({ .appId = 1, + .groupId = 2, + .sendIdx = 3, + .recvIdx = 4, + .dataSize = 0, + .dataPtr = nullptr }); + + SECTION("Empty message") + { + msg.dataSize = 0; + msg.dataPtr = nullptr; + } + + SECTION("Non-empty message") + { + std::vector nums = { 1, 2, 3, 4, 5, 6, 6 }; + msg.dataSize = nums.size() * sizeof(int); + msg.dataPtr = faabric::util::malloc(msg.dataSize); + std::memcpy(msg.dataPtr, nums.data(), msg.dataSize); + } + + // Serialise and de-serialise + std::vector buffer(sizeof(PointToPointMessage) + msg.dataSize); + serializePtpMsg(buffer, msg); + + PointToPointMessage parsedMsg; + parsePtpMsg(buffer, &parsedMsg); + + REQUIRE(arePtpMsgEqual(msg, parsedMsg)); + + if (msg.dataPtr != nullptr) { + faabric::util::free(msg.dataPtr); + } + if (parsedMsg.dataPtr != nullptr) { + faabric::util::free(parsedMsg.dataPtr); + } +} + +TEST_CASE("Test (de)serialising a PTP message into prealloc buffer", "[ptp]") +{ + PointToPointMessage msg({ .appId = 1, + .groupId = 2, + .sendIdx = 3, + .recvIdx = 4, + .dataSize = 0, + .dataPtr = nullptr }); + + std::vector nums = { 1, 2, 3, 4, 5, 6, 6 }; + msg.dataSize = nums.size() * sizeof(int); + msg.dataPtr = faabric::util::malloc(msg.dataSize); + std::memcpy(msg.dataPtr, nums.data(), msg.dataSize); + + // Serialise and de-serialise + std::vector buffer(sizeof(PointToPointMessage) + msg.dataSize); + serializePtpMsg(buffer, msg); + + std::vector preAllocBuffer(msg.dataSize); + PointToPointMessage parsedMsg; + parsePtpMsg(buffer, &parsedMsg, preAllocBuffer); + + REQUIRE(arePtpMsgEqual(msg, parsedMsg)); + REQUIRE(parsedMsg.dataPtr == preAllocBuffer.data()); + + faabric::util::free(msg.dataPtr); +} +}