From bc667e0cd4d244f125bad475a5c81e792b0b75dc Mon Sep 17 00:00:00 2001 From: Nikunj Yadav Date: Wed, 21 Oct 2015 14:10:06 -0700 Subject: [PATCH] Wdt State machine refactor Summary: Get rid of ThreadData and move different functionalities into the appropriate modules Reviewed By: @uddipta Differential Revision: D2445473 --- CMakeLists.txt | 7 +- ErrorCodes.h | 2 +- FileCreator.cpp | 6 +- Protocol.cpp | 3 +- README.md | 26 +- Receiver.cpp | 1290 ++++----------------------------- Receiver.h | 426 +---------- ReceiverThread.cpp | 905 +++++++++++++++++++++++ ReceiverThread.h | 341 +++++++++ Reporting.cpp | 29 + Reporting.h | 53 +- Sender.cpp | 1013 +++----------------------- Sender.h | 354 ++------- SenderThread.cpp | 633 ++++++++++++++++ SenderThread.h | 278 +++++++ ThreadTransferHistory.cpp | 282 +++++++ ThreadTransferHistory.h | 168 +++++ ThreadsController.cpp | 241 ++++++ ThreadsController.h | 371 ++++++++++ ThreadsControllerTest.cpp | 122 ++++ TransferLogManager.cpp | 6 +- WdtBase.cpp | 9 + WdtBase.h | 39 +- WdtConfig.h | 6 +- WdtThread.cpp | 48 ++ WdtThread.h | 79 ++ wdt_global_checkpoint_test.sh | 38 + 27 files changed, 3976 insertions(+), 2799 deletions(-) create mode 100644 ReceiverThread.cpp create mode 100644 ReceiverThread.h create mode 100644 SenderThread.cpp create mode 100644 SenderThread.h create mode 100644 ThreadTransferHistory.cpp create mode 100644 ThreadTransferHistory.h create mode 100644 ThreadsController.cpp create mode 100644 ThreadsController.h create mode 100644 ThreadsControllerTest.cpp create mode 100644 WdtThread.cpp create mode 100644 WdtThread.h create mode 100644 wdt_global_checkpoint_test.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index e6467528..a4c1eecd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,7 @@ cmake_minimum_required(VERSION 3.2) # There is no C per se in WDT but if you use CXX only here many checks fail # Version is Major.Minor.YYMMDDX for up to 10 releases per day # Minor currently is also the protocol 
version - has to match with Protocol.cpp -project("WDT" LANGUAGES C CXX VERSION 1.21.1510120) +project("WDT" LANGUAGES C CXX VERSION 1.22.1510210) # On MacOS this requires the latest (master) CMake (and/or CMake 3.1.1/3.2) set(CMAKE_CXX_STANDARD 11) @@ -79,8 +79,13 @@ ErrorCodes.cpp FileByteSource.cpp FileCreator.cpp Protocol.cpp +WdtThread.cpp +ThreadsController.cpp +ReceiverThread.cpp Receiver.cpp Reporting.cpp +ThreadTransferHistory.cpp +SenderThread.cpp Sender.cpp ServerSocket.cpp SocketUtils.cpp diff --git a/ErrorCodes.h b/ErrorCodes.h index 1eef6654..7d7fd90a 100644 --- a/ErrorCodes.h +++ b/ErrorCodes.h @@ -9,7 +9,7 @@ #pragma once #include - +#include #include namespace facebook { diff --git a/FileCreator.cpp b/FileCreator.cpp index 2495335e..725d3a66 100644 --- a/FileCreator.cpp +++ b/FileCreator.cpp @@ -164,8 +164,7 @@ int FileCreator::openExistingFile(const string &relPathStr) { WDT_CHECK(relPathStr[0] != '/'); WDT_CHECK(relPathStr.back() != '/'); - string path(rootDir_); - path.append(relPathStr); + const string path = rootDir_ + relPathStr; int openFlags = O_WRONLY; START_PERF_TIMER @@ -184,8 +183,7 @@ int FileCreator::createFile(const string &relPathStr) { CHECK(relPathStr[0] != '/'); CHECK(relPathStr.back() != '/'); - std::string path(rootDir_); - path.append(relPathStr); + const string path = rootDir_ + relPathStr; int p = relPathStr.size(); while (p && relPathStr[p - 1] != '/') { diff --git a/Protocol.cpp b/Protocol.cpp index 018e60b9..f25a0a0d 100644 --- a/Protocol.cpp +++ b/Protocol.cpp @@ -49,7 +49,8 @@ int Protocol::negotiateProtocol(int requestedProtocolVersion, } std::ostream &operator<<(std::ostream &os, const Checkpoint &checkpoint) { - os << "num-blocks: " << checkpoint.numBlocks + os << "checkpoint-port: " << checkpoint.port + << " num-blocks: " << checkpoint.numBlocks << " seq-id: " << checkpoint.lastBlockSeqId << " block-offset: " << checkpoint.lastBlockOffset << " received-bytes: " << checkpoint.lastBlockReceivedBytes; diff --git 
a/README.md b/README.md index 0352759f..5de55502 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,10 @@ caller get the mutable object of options and set different options accordingly. When wdt is run in a standalone mode, behavior is changed through gflags in wdtCmdLine.cpp +* WdtThread.{h|cpp} +Common functionality and settings between SenderThread and ReceiverThread. +Both of these kinds of threads inherit from this base class. + * WdtBase.{h|cpp} Common functionality and settings between Sender and Receiver @@ -156,11 +160,20 @@ directory, sorted by decreasing size (as they are discovered, you can start pulling from the queue even before all the files are found, it will return the current largest file) +* ThreadTransferHistory.{h|cpp} -* Sender.{h|cpp} +Every thread maintains a transfer history so that when a connection breaks +it can talk to the receiver to find out up to where in the history has been +sent. This class encapsulates all the logic for that bookkeeping + +* SenderThread.{h|cpp} + +Implements the functionality of one sender thread, which binds to a certain port +and sends files over. -Formerly wdtlib.cpp - main code sending files +* Sender.{h|cpp} +Spawns multiple SenderThread threads and sends the data across to receiver ### Consuming / Receiving @@ -168,10 +181,15 @@ Formerly wdtlib.cpp - main code sending files Creates file and directories necessary for said file (mkdir -p like) -* Receiver.{h|cpp} +* ReceiverThread.{h|cpp} -Formerly wdtlib.cpp - main code receiving files +Implements the functionality of the receiver threads, responsible for listening on +a port and receiving files over the network. 
+ +* Receiver.{h|cpp} +Parent receiver class that spawns multiple ReceiverThread threads and receives +data from a remote host ### Low level building blocks diff --git a/Receiver.cpp b/Receiver.cpp index 0934e186..48cf5196 100644 --- a/Receiver.cpp +++ b/Receiver.cpp @@ -10,7 +10,6 @@ #include "ServerSocket.h" #include "FileWriter.h" #include "SocketUtils.h" -#include "DirectorySourceQueue.h" #include #include @@ -29,82 +28,10 @@ using std::vector; namespace facebook { namespace wdt { - -const static int kTimeoutBufferMillis = 1000; -const static int kWaitTimeoutFactor = 5; - -std::ostream &operator<<(std::ostream &os, const Receiver::ThreadData &data) { - os << "Thread[" << data.threadIndex_ << ", port: " << data.socket_.getPort() - << "] "; - return os; -} - -int64_t readAtLeast(ServerSocket &s, char *buf, int64_t max, int64_t atLeast, - int64_t len) { - VLOG(4) << "readAtLeast len " << len << " max " << max << " atLeast " - << atLeast << " from " << s.getFd(); - CHECK_GE(len, 0); - CHECK_GT(atLeast, 0); - CHECK_LE(atLeast, max); - int count = 0; - while (len < atLeast) { - // because we want to process data as soon as it arrives, tryFull option for - // read is false - int64_t n = s.read(buf + len, max - len, false); - if (n < 0) { - PLOG(ERROR) << "Read error on " << s.getPort() << " after " << count; - if (len) { - return len; - } else { - return n; - } - } - if (n == 0) { - VLOG(2) << "Eof on " << s.getPort() << " after " << count << " reads " - << "got " << len; - return len; - } - len += n; - count++; - } - VLOG(3) << "Took " << count << " reads to get " << len - << " from fd : " << s.getFd(); - return len; -} - -int64_t readAtMost(ServerSocket &s, char *buf, int64_t max, int64_t atMost) { - const int64_t target = atMost < max ? 
atMost : max; - VLOG(3) << "readAtMost target " << target; - // because we want to process data as soon as it arrives, tryFull option for - // read is false - int64_t n = s.read(buf, target, false); - if (n < 0) { - PLOG(ERROR) << "Read error on " << s.getPort() << " with target " << target; - return n; - } - if (n == 0) { - LOG(WARNING) << "Eof on " << s.getFd(); - return n; - } - VLOG(3) << "readAtMost " << n << " / " << atMost << " from " << s.getFd(); - return n; -} - -const Receiver::StateFunction Receiver::stateMap_[] = { - &Receiver::listen, &Receiver::acceptFirstConnection, - &Receiver::acceptWithTimeout, &Receiver::sendLocalCheckpoint, - &Receiver::readNextCmd, &Receiver::processFileCmd, - &Receiver::processSettingsCmd, &Receiver::processDoneCmd, - &Receiver::processSizeCmd, &Receiver::sendFileChunks, - &Receiver::sendGlobalCheckpoint, &Receiver::sendDoneCmd, - &Receiver::sendAbortCmd, &Receiver::waitForFinishOrNewCheckpoint, - &Receiver::waitForFinishWithThreadError}; - void Receiver::addCheckpoint(Checkpoint checkpoint) { LOG(INFO) << "Adding global checkpoint " << checkpoint.port << " " << checkpoint.numBlocks << " " << checkpoint.lastBlockReceivedBytes; checkpoints_.emplace_back(checkpoint); - conditionAllFinished_.notify_all(); } std::vector Receiver::getNewCheckpoints(int startIndex) { @@ -126,11 +53,7 @@ Receiver::Receiver(const WdtTransferRequest &transferRequest) { } setProtocolVersion(transferRequest.protocolVersion); setDir(transferRequest.directory); - const auto &options = WdtOptions::get(); - for (int32_t portNum : transferRequest.ports) { - threadServerSockets_.emplace_back(portNum, options.backlog, - &abortCheckerCallback_); - } + ports_ = transferRequest.ports; } Receiver::Receiver(int port, int numSockets, const std::string &destDir) @@ -151,32 +74,78 @@ void Receiver::traverseDestinationDir( return; } -WdtTransferRequest Receiver::init() { - vector successfulSockets; - for (size_t i = 0; i < threadServerSockets_.size(); i++) { - 
ServerSocket socket = std::move(threadServerSockets_[i]); - int max_retries = WdtOptions::get().max_retries; - for (int retries = 0; retries < max_retries; retries++) { - if (socket.listen() == OK) { - break; - } +void Receiver::startNewGlobalSession(const std::string &peerIp) { + if (throttler_) { + // If throttler is configured/set then register this session + // in the throttler. This is guaranteed to work in either of the + // modes long running or not. We will de register from the throttler + // when the current session ends + throttler_->registerTransfer(); + } + startTime_ = Clock::now(); + const auto &options = WdtOptions::get(); + if (options.enable_download_resumption) { + bool verifySuccessful = transferLogManager_.verifySenderIp(peerIp); + if (!verifySuccessful) { + fileChunksInfo_.clear(); } - if (socket.listen() == OK) { - successfulSockets.push_back(std::move(socket)); - } else { - LOG(ERROR) << "Couldn't listen on port " << socket.getPort(); + } + hasNewTransferStarted_.store(true); + LOG(INFO) << "Starting new transfer, peerIp " << peerIp << " , transfer id " + << transferId_; +} + +bool Receiver::hasNewTransferStarted() const { + return hasNewTransferStarted_.load(); +} + +void Receiver::endCurGlobalSession() { + LOG(INFO) << "Ending the transfer " << transferId_; + if (throttler_) { + throttler_->deRegisterTransfer(); + } + checkpoints_.clear(); + fileCreator_->clearAllocationMap(); + // TODO might consider moving closing the transfer log here + hasNewTransferStarted_.store(false); +} + +WdtTransferRequest Receiver::init() { + const auto &options = WdtOptions::get(); + backlog_ = options.backlog; + bufferSize_ = options.buffer_size; + if (bufferSize_ < Protocol::kMaxHeader) { + // round up to even k + bufferSize_ = 2 * 1024 * ((Protocol::kMaxHeader - 1) / (2 * 1024) + 1); + LOG(INFO) << "Specified -buffer_size " << options.buffer_size + << " smaller than " << Protocol::kMaxHeader << " using " + << bufferSize_ << " instead"; + } + auto numThreads = 
ports_.size(); + fileCreator_.reset( + new FileCreator(destDir_, numThreads, transferLogManager_)); + threadsController_ = new ThreadsController(numThreads); + threadsController_->setNumFunnels(ReceiverThread::NUM_FUNNELS); + threadsController_->setNumBarriers(ReceiverThread::NUM_BARRIERS); + threadsController_->setNumConditions(ReceiverThread::NUM_CONDITIONS); + receiverThreads_ = threadsController_->makeThreads( + this, ports_.size(), ports_); + size_t numSuccessfulInitThreads = 0; + for (auto &receiverThread : receiverThreads_) { + ErrorCode code = receiverThread->init(); + if (code == OK) { + ++numSuccessfulInitThreads; } } - LOG(INFO) << "Registered " << successfulSockets.size() << " sockets"; + LOG(INFO) << "Registered " << numSuccessfulInitThreads + << " successful sockets"; ErrorCode code = OK; - if (threadServerSockets_.size() != successfulSockets.size()) { + if (numSuccessfulInitThreads != ports_.size()) { code = FEWER_PORTS; - if (successfulSockets.size() == 0) { + if (numSuccessfulInitThreads == 0) { code = ERROR; } } - threadServerSockets_ = std::move(successfulSockets); - // TODO: mutate input request instead, post validation WdtTransferRequest transferRequest(getPorts()); transferRequest.protocolVersion = protocolVersion_; transferRequest.transferId = transferId_; @@ -191,7 +160,7 @@ WdtTransferRequest Receiver::init() { code = ERROR; } } - transferRequest.directory = destDir_; + transferRequest.directory = getDir(); transferRequest.errorCode = code; return transferRequest; } @@ -201,6 +170,14 @@ void Receiver::setDir(const std::string &destDir) { transferLogManager_.setRootDir(destDir_); } +TransferLogManager &Receiver::getTransferLogManager() { + return transferLogManager_; +} + +std::unique_ptr &Receiver::getFileCreator() { + return fileCreator_; +} + const std::string &Receiver::getDir() { return destDir_; } @@ -221,12 +198,16 @@ Receiver::~Receiver() { vector Receiver::getPorts() const { vector ports; - for (const auto &socket : 
threadServerSockets_) { - ports.push_back(socket.getPort()); + for (const auto &receiverThread : receiverThreads_) { + ports.push_back(receiverThread->getPort()); } return ports; } +const std::vector &Receiver::getFileChunksInfo() const { + return fileChunksInfo_; +} + int64_t Receiver::getTransferConfig() const { auto &options = WdtOptions::get(); int64_t config = 0; @@ -265,10 +246,9 @@ std::unique_ptr Receiver::finish() { LOG(WARNING) << "The receiver is not joinable. The threads will never" << " finish and this method will never return"; } - for (size_t i = 0; i < receiverThreads_.size(); i++) { - receiverThreads_[i].join(); + for (auto &receiverThread : receiverThreads_) { + receiverThread->finish(); } - // A very important step to mark the transfer finished // No other transferAsync, or runForever can be called on this // instance unless the current transfer has finished @@ -279,50 +259,37 @@ std::unique_ptr Receiver::finish() { progressTrackerThread_.join(); } std::unique_ptr report = getTransferReport(); - + auto &summary = report->getSummary(); bool transferSuccess = (report->getSummary().getCombinedErrorCode() == OK); fixAndCloseTransferLog(transferSuccess); - - if (progressReporter_ && totalSenderBytes_ >= 0) { - report->setTotalFileSize(totalSenderBytes_); + auto totalSenderBytes = summary.getTotalSenderBytes(); + if (progressReporter_ && totalSenderBytes >= 0) { + report->setTotalFileSize(totalSenderBytes); report->setTotalTime(durationSeconds(Clock::now() - startTime_)); progressReporter_->end(report); } if (options.enable_perf_stat_collection) { PerfStatReport globalPerfReport; - for (auto &perfReport : perfReports_) { - globalPerfReport += perfReport; + for (auto &receiverThread : receiverThreads_) { + globalPerfReport += receiverThread->getPerfReport(); } LOG(INFO) << globalPerfReport; } LOG(WARNING) << "WDT receiver's transfer has been finished"; LOG(INFO) << *report; - receiverThreads_.clear(); - threadServerSockets_.clear(); - 
threadStats_.clear(); areThreadsJoined_ = true; return report; } std::unique_ptr Receiver::getTransferReport() { - std::unique_ptr report = - folly::make_unique(threadStats_); - const TransferStats &summary = report->getSummary(); - - if (numBlocksSend_ == -1 || numBlocksSend_ != summary.getNumBlocks()) { - // either none of the threads finished properly or not all of the blocks - // were transferred - report->setErrorCode(ERROR); - } else if (totalSenderBytes_ != -1 && - totalSenderBytes_ != summary.getEffectiveDataBytes()) { - // did not receive all the bytes - LOG(ERROR) << "Number of bytes sent and received do not match " - << totalSenderBytes_ << " " << summary.getEffectiveDataBytes(); - report->setErrorCode(ERROR); - } else { - report->setErrorCode(OK); + TransferStats globalStats; + for (const auto &receiverThread : receiverThreads_) { + globalStats += receiverThread->getTransferStats(); } + globalStats.validate(); + std::unique_ptr report = + folly::make_unique(std::move(globalStats)); return report; } @@ -385,11 +352,10 @@ void Receiver::progressTracker() { std::chrono::time_point lastUpdateTime = Clock::now(); int intervalsSinceLastUpdate = 0; double currentThroughput = 0; - LOG(INFO) << "Progress reporter updating every " << progressReportIntervalMillis << " ms"; auto waitingTime = std::chrono::milliseconds(progressReportIntervalMillis); - int64_t totalSenderBytes; + int64_t totalSenderBytes = -1; while (true) { { std::unique_lock lock(mutex_); @@ -397,14 +363,18 @@ void Receiver::progressTracker() { if (transferFinished_ || getCurAbortCode() != OK) { break; } - if (totalSenderBytes_ == -1) { - continue; - } - totalSenderBytes = totalSenderBytes_; } double totalTime = durationSeconds(Clock::now() - startTime_); + TransferStats globalStats; + for (const auto &receiverThread : receiverThreads_) { + globalStats += receiverThread->getTransferStats(); + } + totalSenderBytes = globalStats.getTotalSenderBytes(); + if (totalSenderBytes == -1) { + continue; + } 
auto transferReport = folly::make_unique( - threadStats_, totalTime, totalSenderBytes); + std::move(globalStats), totalTime, totalSenderBytes); intervalsSinceLastUpdate++; if (intervalsSinceLastUpdate >= throughputUpdateInterval) { auto curTime = Clock::now(); @@ -417,7 +387,6 @@ void Receiver::progressTracker() { intervalsSinceLastUpdate = 0; } transferReport->setCurrentThroughput(currentThroughput); - progressReporter_->progress(transferReport); } } @@ -428,21 +397,11 @@ void Receiver::start() { LOG(WARNING) << "There is an existing transfer in progress on this object"; } areThreadsJoined_ = false; - numActiveThreads_ = threadServerSockets_.size(); LOG(INFO) << "Starting (receiving) server on ports [ " << getPorts() << "] Target dir : " << destDir_; markTransferFinished(false); const auto &options = WdtOptions::get(); - int64_t bufferSize = options.buffer_size; - if (bufferSize < Protocol::kMaxHeader) { - // round up to even k - bufferSize = 2 * 1024 * ((Protocol::kMaxHeader - 1) / (2 * 1024) + 1); - LOG(INFO) << "Specified -buffer_size " << options.buffer_size - << " smaller than " << Protocol::kMaxHeader << " using " - << bufferSize << " instead"; - } - fileCreator_.reset(new FileCreator(destDir_, threadServerSockets_.size(), - transferLogManager_)); + // TODO do the init stuff here if (options.enable_download_resumption) { WDT_CHECK(!options.skip_writes) << "Can not skip transfers with download resumption turned on"; @@ -458,21 +417,27 @@ void Receiver::start() { traverseDestinationDir(fileChunksInfo_); } } - perfReports_.resize(threadServerSockets_.size()); - const int64_t numSockets = threadServerSockets_.size(); - for (int64_t i = 0; i < numSockets; i++) { - threadStats_.emplace_back(true); - } if (!throttler_) { configureThrottler(); } else { LOG(INFO) << "Throttler set externally. 
Throttler : " << *throttler_; } - - for (int64_t i = 0; i < numSockets; i++) { - receiverThreads_.emplace_back(&Receiver::receiveOne, this, i, - std::ref(threadServerSockets_[i]), bufferSize, - std::ref(threadStats_[i])); + while (true) { + for (auto &receiverThread : receiverThreads_) { + receiverThread->startThread(); + } + if (!isJoinable_) { + // If it is long running mode, finish the threads + // processing the current transfer and re spawn them again + // with the same sockets + for (auto &receiverThread : receiverThreads_) { + receiverThread->finish(); + receiverThread->reset(); + } + threadsController_->reset(); + continue; + } + break; } if (isJoinable_) { if (progressReporter_) { @@ -483,780 +448,6 @@ void Receiver::start() { } } -bool Receiver::areAllThreadsFinished(bool checkpointAdded) { - const int64_t numSockets = threadServerSockets_.size(); - bool finished = (failedThreadCount_ + waitingThreadCount_ + - waitingWithErrorThreadCount_) == numSockets; - if (checkpointAdded) { - // The thread has added a global checkpoint. So, - // even if all the threads are waiting, the session does no end. However, - // if all the threads are waiting with an error, then we must end the - // session. because none of the waiting threads can send the global - // checkpoint back to the sender - finished &= (waitingThreadCount_ == 0); - } - return finished; -} - -void Receiver::endCurGlobalSession() { - WDT_CHECK(transferFinishedCount_ + 1 == transferStartedCount_); - LOG(INFO) << "Received done for all threads. 
Transfer session " - << transferStartedCount_ << " finished"; - if (throttler_) { - throttler_->deRegisterTransfer(); - } - transferFinishedCount_++; - waitingThreadCount_ = 0; - waitingWithErrorThreadCount_ = 0; - checkpoints_.clear(); - fileCreator_->clearAllocationMap(); - conditionAllFinished_.notify_all(); -} - -void Receiver::incrFailedThreadCountAndCheckForSessionEnd(ThreadData &data) { - std::unique_lock lock(mutex_); - failedThreadCount_++; - if (areAllThreadsFinished(false) && - transferStartedCount_ > transferFinishedCount_) { - endCurGlobalSession(); - } -} - -bool Receiver::hasNewSessionStarted(ThreadData &data) { - bool started = transferStartedCount_ > data.transferStartedCount_; - if (started) { - WDT_CHECK(transferStartedCount_ == data.transferStartedCount_ + 1); - } - return started; -} - -void Receiver::startNewGlobalSession(ThreadData &data) { - WDT_CHECK(transferStartedCount_ == transferFinishedCount_); - const auto &options = WdtOptions::get(); - auto &socket = data.socket_; - if (throttler_) { - // If throttler is configured/set then register this session - // in the throttler. This is guaranteed to work in either of the - // modes long running or not. 
We will de register from the throttler - // when the current session ends - throttler_->registerTransfer(); - } - transferStartedCount_++; - startTime_ = Clock::now(); - - if (options.enable_download_resumption) { - bool verifySuccessful = - transferLogManager_.verifySenderIp(socket.getPeerIp()); - if (!verifySuccessful) { - fileChunksInfo_.clear(); - } - } - - LOG(INFO) << "New transfer started " << transferStartedCount_; -} - -bool Receiver::hasCurSessionFinished(ThreadData &data) { - return transferFinishedCount_ > data.transferFinishedCount_; -} - -void Receiver::startNewThreadSession(ThreadData &data) { - WDT_CHECK(data.transferStartedCount_ == data.transferFinishedCount_); - data.transferStartedCount_++; -} - -void Receiver::endCurThreadSession(ThreadData &data) { - WDT_CHECK(data.transferStartedCount_ == data.transferFinishedCount_ + 1); - data.transferFinishedCount_++; -} - -/***LISTEN STATE***/ -Receiver::ReceiverState Receiver::listen(ThreadData &data) { - VLOG(1) << data << " entered LISTEN state "; - const auto &options = WdtOptions::get(); - const bool doActualWrites = !options.skip_writes; - auto &socket = data.socket_; - auto &threadStats = data.threadStats_; - - int32_t port = socket.getPort(); - VLOG(1) << "Server Thread for port " << port << " with backlog " - << socket.getBackLog() << " on " << destDir_ - << " writes= " << doActualWrites; - for (int i = 1; i < options.max_retries; ++i) { - ErrorCode code = socket.listen(); - if (code == OK) { - break; - } else if (code == CONN_ERROR) { - threadStats.setErrorCode(code); - incrFailedThreadCountAndCheckForSessionEnd(data); - return FAILED; - } - LOG(INFO) << "Sleeping after failed attempt " << i; - /* sleep override */ - usleep(options.sleep_millis * 1000); - } - // one more/last try (stays true if it worked above) - if (socket.listen() != OK) { - LOG(ERROR) << "Unable to listen/bind despite retries"; - threadStats.setErrorCode(CONN_ERROR); - incrFailedThreadCountAndCheckForSessionEnd(data); - 
return FAILED; - } - return ACCEPT_FIRST_CONNECTION; -} - -/***ACCEPT_FIRST_CONNECTION***/ -Receiver::ReceiverState Receiver::acceptFirstConnection(ThreadData &data) { - VLOG(1) << data << " entered ACCEPT_FIRST_CONNECTION state "; - const auto &options = WdtOptions::get(); - auto &socket = data.socket_; - auto &threadStats = data.threadStats_; - auto &curConnectionVerified = data.curConnectionVerified_; - - data.reset(); - socket.closeCurrentConnection(); - auto timeout = options.accept_timeout_millis; - int acceptAttempts = 0; - while (true) { - { - std::lock_guard lock(mutex_); - if (hasNewSessionStarted(data)) { - startNewThreadSession(data); - return ACCEPT_WITH_TIMEOUT; - } - } - if (isJoinable_ && acceptAttempts == options.max_accept_retries) { - LOG(ERROR) << "unable to accept after " << acceptAttempts << " attempts"; - threadStats.setErrorCode(CONN_ERROR); - incrFailedThreadCountAndCheckForSessionEnd(data); - return FAILED; - } - - if (getCurAbortCode() != OK) { - LOG(ERROR) << "Thread marked to abort while trying to accept first" - << " connection. Num attempts " << acceptAttempts; - // Even though there is a transition FAILED here - // getCurAbortCode() is going to be checked again in the receiveOne. 
- // So this is pretty much irrelevant - return FAILED; - } - - ErrorCode code = - socket.acceptNextConnection(timeout, curConnectionVerified); - if (code == OK) { - break; - } - acceptAttempts++; - } - - std::lock_guard lock(mutex_); - if (!hasNewSessionStarted(data)) { - // this thread has the first connection - startNewGlobalSession(data); - } - startNewThreadSession(data); - return READ_NEXT_CMD; -} - -/***ACCEPT_WITH_TIMEOUT STATE***/ -Receiver::ReceiverState Receiver::acceptWithTimeout(ThreadData &data) { - LOG(INFO) << data << " entered ACCEPT_WITH_TIMEOUT state "; - const auto &options = WdtOptions::get(); - auto &socket = data.socket_; - auto &threadStats = data.threadStats_; - auto &senderReadTimeout = data.senderReadTimeout_; - auto &senderWriteTimeout = data.senderWriteTimeout_; - auto &doneSendFailure = data.doneSendFailure_; - auto &curConnectionVerified = data.curConnectionVerified_; - socket.closeCurrentConnection(); - - auto timeout = options.accept_window_millis; - if (senderReadTimeout > 0) { - // transfer is in progress and we have already got sender settings - timeout = - std::max(senderReadTimeout, senderWriteTimeout) + kTimeoutBufferMillis; - } - - ErrorCode code = socket.acceptNextConnection(timeout, curConnectionVerified); - curConnectionVerified = false; - if (code != OK) { - LOG(ERROR) << "accept() failed with timeout " << timeout; - threadStats.setErrorCode(code); - if (doneSendFailure) { - // if SEND_DONE_CMD state had already been reached, we do not need to - // wait for other threads to end - return END; - } - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; - } - - if (doneSendFailure) { - // no need to reset any session variables in this case - return SEND_LOCAL_CHECKPOINT; - } - - data.numRead_ = data.off_ = 0; - data.pendingCheckpointIndex_ = data.checkpointIndex_; - ReceiverState nextState = READ_NEXT_CMD; - if (threadStats.getErrorCode() != OK) { - nextState = SEND_LOCAL_CHECKPOINT; - } - // reset thread status - 
threadStats.setErrorCode(OK); - return nextState; -} - -/***SEND_LOCAL_CHECKPOINT STATE***/ -Receiver::ReceiverState Receiver::sendLocalCheckpoint(ThreadData &data) { - LOG(INFO) << data << " entered SEND_LOCAL_CHECKPOINT state "; - auto &socket = data.socket_; - auto &threadStats = data.threadStats_; - auto &doneSendFailure = data.doneSendFailure_; - auto &checkpoint = data.checkpoint_; - int32_t protocolVersion = data.threadProtocolVersion_; - char *buf = data.getBuf(); - - std::vector checkpoints; - if (doneSendFailure) { - // in case SEND_DONE failed, a special checkpoint(-1) is sent to signal this - // condition - Checkpoint localCheckpoint(socket.getPort()); - localCheckpoint.numBlocks = -1; - checkpoints.emplace_back(localCheckpoint); - } else { - checkpoints.emplace_back(checkpoint); - } - - int64_t off = 0; - const int checkpointLen = - Protocol::getMaxLocalCheckpointLength(protocolVersion); - Protocol::encodeCheckpoints(protocolVersion, buf, off, checkpointLen, - checkpoints); - int written = socket.write(buf, checkpointLen); - if (written != checkpointLen) { - LOG(ERROR) << "unable to write local checkpoint. 
write mismatch " - << checkpointLen << " " << written; - threadStats.setErrorCode(SOCKET_WRITE_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - threadStats.addHeaderBytes(checkpointLen); - if (doneSendFailure) { - return SEND_DONE_CMD; - } - return READ_NEXT_CMD; -} - -/***READ_NEXT_CMD***/ -Receiver::ReceiverState Receiver::readNextCmd(ThreadData &data) { - VLOG(1) << data << " entered READ_NEXT_CMD state "; - auto &socket = data.socket_; - auto &threadStats = data.threadStats_; - char *buf = data.getBuf(); - auto &numRead = data.numRead_; - auto &off = data.off_; - auto &oldOffset = data.oldOffset_; - auto bufferSize = data.bufferSize_; - - oldOffset = off; - numRead = readAtLeast(socket, buf + off, bufferSize - off, - Protocol::kMinBufLength, numRead); - if (numRead < Protocol::kMinBufLength) { - LOG(ERROR) << "socket read failure " << Protocol::kMinBufLength << " " - << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf[off++]; - if (cmd == Protocol::DONE_CMD) { - return PROCESS_DONE_CMD; - } - if (cmd == Protocol::FILE_CMD) { - return PROCESS_FILE_CMD; - } - if (cmd == Protocol::SETTINGS_CMD) { - return PROCESS_SETTINGS_CMD; - } - if (cmd == Protocol::SIZE_CMD) { - return PROCESS_SIZE_CMD; - } - LOG(ERROR) << "received an unknown cmd"; - threadStats.setErrorCode(PROTOCOL_ERROR); - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; -} - -/***PROCESS_SETTINGS_CMD***/ -Receiver::ReceiverState Receiver::processSettingsCmd(ThreadData &data) { - VLOG(1) << data << " entered PROCESS_SETTINGS_CMD state "; - char *buf = data.getBuf(); - auto &off = data.off_; - auto &oldOffset = data.oldOffset_; - auto &numRead = data.numRead_; - auto &senderReadTimeout = data.senderReadTimeout_; - auto &senderWriteTimeout = data.senderWriteTimeout_; - auto &threadStats = data.threadStats_; - auto &enableChecksum = data.enableChecksum_; - auto &threadProtocolVersion = data.threadProtocolVersion_; - auto 
&curConnectionVerified = data.curConnectionVerified_; - auto &isBlockMode = data.isBlockMode_; - Settings settings; - int senderProtocolVersion; - - bool success = Protocol::decodeVersion( - buf, off, oldOffset + Protocol::kMaxVersion, senderProtocolVersion); - if (!success) { - LOG(ERROR) << "Unable to decode version " << data.threadIndex_; - threadStats.setErrorCode(PROTOCOL_ERROR); - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; - } - if (senderProtocolVersion != threadProtocolVersion) { - LOG(ERROR) << "Receiver and sender protocol version mismatch " - << senderProtocolVersion << " " << threadProtocolVersion; - int negotiatedProtocol = Protocol::negotiateProtocol(senderProtocolVersion, - threadProtocolVersion); - if (negotiatedProtocol == 0) { - LOG(WARNING) << "Can not support sender with version " - << senderProtocolVersion << ", aborting!"; - threadStats.setErrorCode(VERSION_INCOMPATIBLE); - return SEND_ABORT_CMD; - } else { - LOG_IF(INFO, threadProtocolVersion != negotiatedProtocol) - << "Changing receiver protocol version to " << negotiatedProtocol; - threadProtocolVersion = negotiatedProtocol; - if (negotiatedProtocol != senderProtocolVersion) { - threadStats.setErrorCode(VERSION_MISMATCH); - return SEND_ABORT_CMD; - } - } - } - - success = Protocol::decodeSettings( - threadProtocolVersion, buf, off, - oldOffset + Protocol::kMaxVersion + Protocol::kMaxSettings, settings); - if (!success) { - LOG(ERROR) << "Unable to decode settings cmd " << data.threadIndex_; - threadStats.setErrorCode(PROTOCOL_ERROR); - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; - } - auto senderId = settings.transferId; - if (transferId_ != senderId) { - LOG(ERROR) << "Receiver and sender id mismatch " << senderId << " " - << transferId_; - threadStats.setErrorCode(ID_MISMATCH); - return SEND_ABORT_CMD; - } - senderReadTimeout = settings.readTimeoutMillis; - senderWriteTimeout = settings.writeTimeoutMillis; - enableChecksum = settings.enableChecksum; - isBlockMode = 
!settings.blockModeDisabled; - curConnectionVerified = true; - - if (settings.sendFileChunks) { - // We only move to SEND_FILE_CHUNKS state, if download resumption is enabled - // in the sender side - numRead = off = 0; - return SEND_FILE_CHUNKS; - } - auto msgLen = off - oldOffset; - numRead -= msgLen; - return READ_NEXT_CMD; -} - -/***PROCESS_FILE_CMD***/ -Receiver::ReceiverState Receiver::processFileCmd(ThreadData &data) { - VLOG(1) << data << " entered PROCESS_FILE_CMD state "; - const auto &options = WdtOptions::get(); - auto &socket = data.socket_; - auto &threadIndex = data.threadIndex_; - auto &threadStats = data.threadStats_; - char *buf = data.getBuf(); - auto &numRead = data.numRead_; - auto &off = data.off_; - auto &oldOffset = data.oldOffset_; - auto bufferSize = data.bufferSize_; - auto &checkpointIndex = data.checkpointIndex_; - auto &pendingCheckpointIndex = data.pendingCheckpointIndex_; - auto &enableChecksum = data.enableChecksum_; - auto &protocolVersion = data.threadProtocolVersion_; - auto &checkpoint = data.checkpoint_; - auto &isBlockMode = data.isBlockMode_; - - // following block needs to be executed for the first file cmd. There is no - // harm in executing it more than once. number of blocks equal to 0 is a good - // approximation for first file cmd. 
Did not want to introduce another boolean - if (options.enable_download_resumption && threadStats.getNumBlocks() == 0) { - std::lock_guard lock(mutex_); - if (sendChunksStatus_ != SENT) { - // sender is not in resumption mode - addTransferLogHeader(isBlockMode, /* sender not resuming */ false); - sendChunksStatus_ = SENT; - } - } - - checkpoint.resetLastBlockDetails(); - BlockDetails blockDetails; - - auto guard = folly::makeGuard([&socket, &threadStats] { - if (threadStats.getErrorCode() != OK) { - threadStats.incrFailedAttempts(); - } - }); - - ErrorCode transferStatus = (ErrorCode)buf[off++]; - if (transferStatus != OK) { - // TODO: use this status information to implement fail fast mode - VLOG(1) << "sender entered into error state " - << errorCodeToStr(transferStatus); - } - int16_t headerLen = folly::loadUnaligned(buf + off); - headerLen = folly::Endian::little(headerLen); - VLOG(2) << "Processing FILE_CMD, header len " << headerLen; - - if (headerLen > numRead) { - int64_t end = oldOffset + numRead; - numRead = - readAtLeast(socket, buf + end, bufferSize - end, headerLen, numRead); - } - if (numRead < headerLen) { - LOG(ERROR) << "Unable to read full header " << headerLen << " " << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - off += sizeof(int16_t); - bool success = Protocol::decodeHeader(protocolVersion, buf, off, - numRead + oldOffset, blockDetails); - int64_t headerBytes = off - oldOffset; - // transferred header length must match decoded header length - WDT_CHECK_EQ(headerLen, headerBytes) << " " << blockDetails.fileName << " " - << blockDetails.seqId << " " - << protocolVersion; - threadStats.addHeaderBytes(headerBytes); - if (!success) { - LOG(ERROR) << "Error decoding at" - << " ooff:" << oldOffset << " off: " << off - << " numRead: " << numRead; - threadStats.setErrorCode(PROTOCOL_ERROR); - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; - } - - // received a well formed file cmd, apply the pending checkpoint 
update - checkpointIndex = pendingCheckpointIndex; - VLOG(1) << "Read id:" << blockDetails.fileName - << " size:" << blockDetails.dataSize << " ooff:" << oldOffset - << " off: " << off << " numRead: " << numRead; - - FileWriter writer(threadIndex, &blockDetails, fileCreator_.get()); - auto writtenGuard = folly::makeGuard([&] { - if (protocolVersion >= Protocol::CHECKPOINT_OFFSET_VERSION) { - // considering partially written block contents as valid, this bypasses - // checksum verification - // TODO: Make sure checksum verification work with checkpoint offsets - checkpoint.setLastBlockDetails(blockDetails.seqId, blockDetails.offset, - writer.getTotalWritten()); - threadStats.addEffectiveBytes(headerBytes, writer.getTotalWritten()); - } - }); - - if (writer.open() != OK) { - threadStats.setErrorCode(FILE_WRITE_ERROR); - return SEND_ABORT_CMD; - } - int32_t checksum = 0; - int64_t remainingData = numRead + oldOffset - off; - int64_t toWrite = remainingData; - WDT_CHECK(remainingData >= 0); - if (remainingData >= blockDetails.dataSize) { - toWrite = blockDetails.dataSize; - } - threadStats.addDataBytes(toWrite); - if (enableChecksum) { - checksum = folly::crc32c((const uint8_t *)(buf + off), toWrite, checksum); - } - if (throttler_) { - // We might be reading more than we require for this file but - // throttling should make sense for any additional bytes received - // on the network - throttler_->limit(toWrite + headerBytes); - } - ErrorCode code = writer.write(buf + off, toWrite); - if (code != OK) { - threadStats.setErrorCode(code); - return SEND_ABORT_CMD; - } - off += toWrite; - remainingData -= toWrite; - // also means no leftOver so it's ok we use buf from start - while (writer.getTotalWritten() < blockDetails.dataSize) { - if (getCurAbortCode() != OK) { - LOG(ERROR) << "Thread marked for abort while processing a file." 
- << " port : " << socket.getPort(); - return FAILED; - } - int64_t nres = readAtMost(socket, buf, bufferSize, - blockDetails.dataSize - writer.getTotalWritten()); - if (nres <= 0) { - break; - } - if (throttler_) { - // We only know how much we have read after we are done calling - // readAtMost. Call throttler with the bytes read off the wire. - throttler_->limit(nres); - } - threadStats.addDataBytes(nres); - if (enableChecksum) { - checksum = folly::crc32c((const uint8_t *)buf, nres, checksum); - } - code = writer.write(buf, nres); - if (code != OK) { - threadStats.setErrorCode(code); - return SEND_ABORT_CMD; - } - } - if (writer.getTotalWritten() != blockDetails.dataSize) { - // This can only happen if there are transmission errors - // Write errors to disk are already taken care of above - LOG(ERROR) << "could not read entire content for " << blockDetails.fileName - << " port " << socket.getPort(); - threadStats.setErrorCode(SOCKET_READ_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - writtenGuard.dismiss(); - VLOG(2) << "completed " << blockDetails.fileName << " off: " << off - << " numRead: " << numRead; - // Transfer of the file is complete here, mark the bytes effective - WDT_CHECK(remainingData >= 0) << "Negative remainingData " << remainingData; - if (remainingData > 0) { - // if we need to read more anyway, let's move the data - numRead = remainingData; - if ((remainingData < Protocol::kMaxHeader) && (off > (bufferSize / 2))) { - // rare so inefficient is ok - VLOG(3) << "copying extra " << remainingData << " leftover bytes @ " - << off; - memmove(/* dst */ buf, - /* from */ buf + off, - /* how much */ remainingData); - off = 0; - } else { - // otherwise just continue from the offset - VLOG(3) << "Using remaining extra " << remainingData - << " leftover bytes starting @ " << off; - } - } else { - numRead = off = 0; - } - if (enableChecksum) { - // have to read footer cmd - oldOffset = off; - numRead = readAtLeast(socket, buf + off, bufferSize - off, - 
Protocol::kMinBufLength, numRead); - if (numRead < Protocol::kMinBufLength) { - LOG(ERROR) << "socket read failure " << Protocol::kMinBufLength << " " - << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf[off++]; - if (cmd != Protocol::FOOTER_CMD) { - LOG(ERROR) << "Expecting footer cmd, but received " << cmd; - threadStats.setErrorCode(PROTOCOL_ERROR); - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; - } - int32_t receivedChecksum; - bool success = Protocol::decodeFooter( - buf, off, oldOffset + Protocol::kMaxFooter, receivedChecksum); - if (!success) { - LOG(ERROR) << "Unable to decode footer cmd"; - threadStats.setErrorCode(PROTOCOL_ERROR); - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; - } - if (checksum != receivedChecksum) { - LOG(ERROR) << "Checksum mismatch " << checksum << " " << receivedChecksum - << " port " << socket.getPort() << " file " - << blockDetails.fileName; - threadStats.setErrorCode(CHECKSUM_MISMATCH); - return ACCEPT_WITH_TIMEOUT; - } - int64_t msgLen = off - oldOffset; - numRead -= msgLen; - } - if (options.isLogBasedResumption()) { - transferLogManager_.addBlockWriteEntry( - blockDetails.seqId, blockDetails.offset, blockDetails.dataSize); - } - threadStats.addEffectiveBytes(headerBytes, blockDetails.dataSize); - threadStats.incrNumBlocks(); - checkpoint.incrNumBlocks(); - return READ_NEXT_CMD; -} - -Receiver::ReceiverState Receiver::processDoneCmd(ThreadData &data) { - VLOG(1) << data << " entered PROCESS_DONE_CMD state "; - auto &numRead = data.numRead_; - auto &threadStats = data.threadStats_; - auto &checkpointIndex = data.checkpointIndex_; - auto &pendingCheckpointIndex = data.pendingCheckpointIndex_; - auto &off = data.off_; - auto &oldOffset = data.oldOffset_; - int protocolVersion = data.threadProtocolVersion_; - char *buf = data.getBuf(); - - if (numRead != Protocol::kMinBufLength) { - LOG(ERROR) << "Unexpected state for done command" - << " off: " 
<< off << " numRead: " << numRead; - threadStats.setErrorCode(PROTOCOL_ERROR); - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; - } - - ErrorCode senderStatus = (ErrorCode)buf[off++]; - bool success; - { - std::lock_guard lock(mutex_); - success = Protocol::decodeDone(protocolVersion, buf, off, - oldOffset + Protocol::kMaxDone, - numBlocksSend_, totalSenderBytes_); - } - if (!success) { - LOG(ERROR) << "Unable to decode done cmd"; - threadStats.setErrorCode(PROTOCOL_ERROR); - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; - } - threadStats.setRemoteErrorCode(senderStatus); - - // received a valid command, applying pending checkpoint write update - checkpointIndex = pendingCheckpointIndex; - return WAIT_FOR_FINISH_OR_NEW_CHECKPOINT; -} - -Receiver::ReceiverState Receiver::processSizeCmd(ThreadData &data) { - VLOG(1) << data << " entered PROCESS_SIZE_CMD state "; - auto &threadStats = data.threadStats_; - auto &numRead = data.numRead_; - auto &off = data.off_; - auto &oldOffset = data.oldOffset_; - char *buf = data.getBuf(); - std::lock_guard lock(mutex_); - bool success = Protocol::decodeSize(buf, off, oldOffset + Protocol::kMaxSize, - totalSenderBytes_); - if (!success) { - LOG(ERROR) << "Unable to decode size cmd"; - threadStats.setErrorCode(PROTOCOL_ERROR); - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; - } - VLOG(1) << "Number of bytes to receive " << totalSenderBytes_; - auto msgLen = off - oldOffset; - numRead -= msgLen; - return READ_NEXT_CMD; -} - -Receiver::ReceiverState Receiver::sendFileChunks(ThreadData &data) { - LOG(INFO) << data << " entered SEND_FILE_CHUNKS state "; - char *buf = data.getBuf(); - auto bufferSize = data.bufferSize_; - auto &socket = data.socket_; - auto &threadStats = data.threadStats_; - auto &senderReadTimeout = data.senderReadTimeout_; - auto &isBlockMode = data.isBlockMode_; - int64_t toWrite; - int64_t written; - std::unique_lock lock(mutex_); - while (true) { - switch (sendChunksStatus_) { - case SENT: { - lock.unlock(); - buf[0] = 
Protocol::ACK_CMD; - toWrite = 1; - written = socket.write(buf, toWrite); - if (written != toWrite) { - LOG(ERROR) << "Socket write error " << toWrite << " " << written; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - threadStats.addHeaderBytes(toWrite); - return READ_NEXT_CMD; - } - case IN_PROGRESS: { - lock.unlock(); - buf[0] = Protocol::WAIT_CMD; - toWrite = 1; - written = socket.write(buf, toWrite); - if (written != toWrite) { - LOG(ERROR) << "Socket write error " << toWrite << " " << written; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - threadStats.addHeaderBytes(toWrite); - WDT_CHECK(senderReadTimeout > 0); // must have received settings - int timeoutMillis = senderReadTimeout / kWaitTimeoutFactor; - auto waitingTime = std::chrono::milliseconds(timeoutMillis); - lock.lock(); - conditionFileChunksSent_.wait_for(lock, waitingTime); - continue; - } - case NOT_STARTED: { - // This thread has to send file chunks - sendChunksStatus_ = IN_PROGRESS; - lock.unlock(); - auto guard = folly::makeGuard([&] { - lock.lock(); - sendChunksStatus_ = NOT_STARTED; - conditionFileChunksSent_.notify_one(); - }); - int64_t off = 0; - buf[off++] = Protocol::CHUNKS_CMD; - const int64_t numParsedChunksInfo = fileChunksInfo_.size(); - Protocol::encodeChunksCmd(buf, off, bufferSize, numParsedChunksInfo); - written = socket.write(buf, off); - if (written > 0) { - threadStats.addHeaderBytes(written); - } - if (written != off) { - LOG(ERROR) << "Socket write error " << off << " " << written; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - int64_t numEntriesWritten = 0; - // we try to encode as many chunks as possible in the buffer. If a - // single - // chunk can not fit in the buffer, it is ignored. Format of encoding : - // ... 
- while (numEntriesWritten < numParsedChunksInfo) { - off = sizeof(int32_t); - int64_t numEntriesEncoded = Protocol::encodeFileChunksInfoList( - buf, off, bufferSize, numEntriesWritten, fileChunksInfo_); - int32_t dataSize = folly::Endian::little(off - sizeof(int32_t)); - folly::storeUnaligned(buf, dataSize); - written = socket.write(buf, off); - if (written > 0) { - threadStats.addHeaderBytes(written); - } - if (written != off) { - break; - } - numEntriesWritten += numEntriesEncoded; - } - if (numEntriesWritten != numParsedChunksInfo) { - LOG(ERROR) << "Could not write all the file chunks " - << numParsedChunksInfo << " " << numEntriesWritten; - threadStats.setErrorCode(SOCKET_WRITE_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - // try to read ack - int64_t toRead = 1; - int64_t numRead = socket.read(buf, toRead); - if (numRead != toRead) { - LOG(ERROR) << "Socket read error " << toRead << " " << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - guard.dismiss(); - lock.lock(); - sendChunksStatus_ = SENT; - addTransferLogHeader(isBlockMode, /* sender resuming */ true); - conditionFileChunksSent_.notify_all(); - return READ_NEXT_CMD; - } - } - } -} - void Receiver::addTransferLogHeader(bool isBlockMode, bool isSenderResuming) { const auto &options = WdtOptions::get(); if (!options.enable_download_resumption) { @@ -1309,244 +500,5 @@ void Receiver::fixAndCloseTransferLog(bool transferSuccess) { transferLogManager_.unlink(); } } - -Receiver::ReceiverState Receiver::sendGlobalCheckpoint(ThreadData &data) { - LOG(INFO) << data << " entered SEND_GLOBAL_CHECKPOINTS state "; - char *buf = data.getBuf(); - auto &off = data.off_; - auto &newCheckpoints = data.newCheckpoints_; - auto &socket = data.socket_; - auto &checkpointIndex = data.checkpointIndex_; - auto &pendingCheckpointIndex = data.pendingCheckpointIndex_; - auto &threadStats = data.threadStats_; - auto &numRead = data.numRead_; - auto bufferSize = data.bufferSize_; - int32_t 
protocolVersion = data.threadProtocolVersion_; - - buf[0] = Protocol::ERR_CMD; - off = 1; - // leave space for length - off += sizeof(int16_t); - auto oldOffset = off; - Protocol::encodeCheckpoints(protocolVersion, buf, off, bufferSize, - newCheckpoints); - int16_t length = off - oldOffset; - folly::storeUnaligned(buf + 1, folly::Endian::little(length)); - - auto written = socket.write(buf, off); - if (written != off) { - LOG(ERROR) << "unable to write error checkpoints"; - threadStats.setErrorCode(SOCKET_WRITE_ERROR); - return ACCEPT_WITH_TIMEOUT; - } else { - threadStats.addHeaderBytes(off); - pendingCheckpointIndex = checkpointIndex + newCheckpoints.size(); - numRead = off = 0; - return READ_NEXT_CMD; - } -} - -Receiver::ReceiverState Receiver::sendAbortCmd(ThreadData &data) { - LOG(INFO) << data << " entered SEND_ABORT_CMD state "; - auto &threadStats = data.threadStats_; - char *buf = data.getBuf(); - auto &socket = data.socket_; - int32_t protocolVersion = data.threadProtocolVersion_; - int64_t offset = 0; - buf[offset++] = Protocol::ABORT_CMD; - Protocol::encodeAbort(buf, offset, protocolVersion, - threadStats.getErrorCode(), threadStats.getNumFiles()); - socket.write(buf, offset); - // No need to check if we were successful in sending ABORT - // This thread will simply disconnect and sender thread on the - // other side will timeout - socket.closeCurrentConnection(); - threadStats.addHeaderBytes(offset); - if (threadStats.getErrorCode() == VERSION_MISMATCH) { - // Receiver should try again expecting sender to have changed its version - return ACCEPT_WITH_TIMEOUT; - } - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; -} - -Receiver::ReceiverState Receiver::sendDoneCmd(ThreadData &data) { - VLOG(1) << data << " entered SEND_DONE_CMD state "; - char *buf = data.getBuf(); - auto &socket = data.socket_; - auto &threadStats = data.threadStats_; - auto &doneSendFailure = data.doneSendFailure_; - - buf[0] = Protocol::DONE_CMD; - if (socket.write(buf, 1) != 1) { - 
PLOG(ERROR) << "unable to send DONE " << data.threadIndex_; - doneSendFailure = true; - return ACCEPT_WITH_TIMEOUT; - } - - threadStats.addHeaderBytes(1); - - auto read = socket.read(buf, 1); - if (read != 1 || buf[0] != Protocol::DONE_CMD) { - LOG(ERROR) << data << " did not receive ack for DONE"; - doneSendFailure = true; - return ACCEPT_WITH_TIMEOUT; - } - - read = socket.read(buf, Protocol::kMinBufLength); - if (read != 0) { - LOG(ERROR) << data << " EOF not found where expected"; - doneSendFailure = true; - return ACCEPT_WITH_TIMEOUT; - } - socket.closeCurrentConnection(); - LOG(INFO) << data << " got ack for DONE. Transfer finished"; - return END; -} - -Receiver::ReceiverState Receiver::waitForFinishWithThreadError( - ThreadData &data) { - LOG(INFO) << data << " entered WAIT_FOR_FINISH_WITH_THREAD_ERROR state "; - auto &threadStats = data.threadStats_; - auto &socket = data.socket_; - auto &checkpoint = data.checkpoint_; - // should only be in this state if there is some error - WDT_CHECK(threadStats.getErrorCode() != OK); - - // close the socket, so that sender receives an error during connect - socket.closeAll(); - - std::unique_lock lock(mutex_); - addCheckpoint(checkpoint); - waitingWithErrorThreadCount_++; - - if (areAllThreadsFinished(true)) { - endCurGlobalSession(); - } else { - // wait for session end - while (!hasCurSessionFinished(data)) { - conditionAllFinished_.wait(lock); - } - } - endCurThreadSession(data); - return END; -} - -Receiver::ReceiverState Receiver::waitForFinishOrNewCheckpoint( - ThreadData &data) { - VLOG(1) << data << " entered WAIT_FOR_FINISH_OR_NEW_CHECKPOINT state "; - auto &threadStats = data.threadStats_; - auto &senderReadTimeout = data.senderReadTimeout_; - auto &checkpointIndex = data.checkpointIndex_; - auto &newCheckpoints = data.newCheckpoints_; - char *buf = data.getBuf(); - auto &socket = data.socket_; - // should only be called if there are no errors - WDT_CHECK(threadStats.getErrorCode() == OK); - - std::unique_lock 
lock(mutex_); - // we have to check for checkpoints before checking to see if session ended or - // not. because if some checkpoints have not been sent back to the sender, - // session should not end - newCheckpoints = getNewCheckpoints(checkpointIndex); - if (!newCheckpoints.empty()) { - return SEND_GLOBAL_CHECKPOINTS; - } - - waitingThreadCount_++; - if (areAllThreadsFinished(false)) { - endCurGlobalSession(); - endCurThreadSession(data); - return SEND_DONE_CMD; - } - - // we must send periodic wait cmd to keep the sender thread alive - while (true) { - WDT_CHECK(senderReadTimeout > 0); // must have received settings - int timeoutMillis = senderReadTimeout / kWaitTimeoutFactor; - auto waitingTime = std::chrono::milliseconds(timeoutMillis); - START_PERF_TIMER - conditionAllFinished_.wait_for(lock, waitingTime); - RECORD_PERF_RESULT(PerfStatReport::RECEIVER_WAIT_SLEEP) - - // check if transfer finished or not - if (hasCurSessionFinished(data)) { - endCurThreadSession(data); - return SEND_DONE_CMD; - } - - // check to see if any new checkpoints were added - newCheckpoints = getNewCheckpoints(checkpointIndex); - if (!newCheckpoints.empty()) { - waitingThreadCount_--; - return SEND_GLOBAL_CHECKPOINTS; - } - - // must unlock because socket write could block for long time, as long as - // the write timeout, which is 5sec by default - lock.unlock(); - - // send WAIT cmd to keep sender thread alive - buf[0] = Protocol::WAIT_CMD; - if (socket.write(buf, 1) != 1) { - PLOG(ERROR) << data << " unable to write WAIT "; - threadStats.setErrorCode(SOCKET_WRITE_ERROR); - lock.lock(); - // we again have to check if the session has finished or not. 
while - // writing WAIT cmd, some other thread could have ended the session, so - // going back to ACCEPT_WITH_TIMEOUT state would be wrong - if (!hasCurSessionFinished(data)) { - waitingThreadCount_--; - return ACCEPT_WITH_TIMEOUT; - } - endCurThreadSession(data); - return END; - } - threadStats.addHeaderBytes(1); - lock.lock(); - } -} - -void Receiver::receiveOne(int threadIndex, ServerSocket &socket, - int64_t bufferSize, TransferStats &threadStats) { - INIT_PERF_STAT_REPORT - auto guard = folly::makeGuard([&] { - perfReports_[threadIndex] = *perfStatReport; // copy when done - std::unique_lock lock(mutex_); - numActiveThreads_--; - if (numActiveThreads_ == 0) { - LOG(WARNING) << "Last thread finished. Duration of the transfer " - << durationSeconds(Clock::now() - startTime_); - transferFinished_ = true; - } - }); - ThreadData data(threadIndex, socket, threadStats, protocolVersion_, - bufferSize); - if (!data.getBuf()) { - LOG(ERROR) << "error allocating " << bufferSize; - threadStats.setErrorCode(MEMORY_ALLOCATION_ERROR); - return; - } - ReceiverState state = LISTEN; - while (true) { - ErrorCode abortCode = getCurAbortCode(); - if (abortCode != OK) { - LOG(ERROR) << "Transfer aborted " << socket.getPort() << " " - << errorCodeToStr(abortCode); - threadStats.setErrorCode(ABORT); - incrFailedThreadCountAndCheckForSessionEnd(data); - break; - } - if (state == FAILED) { - return; - } - if (state == END) { - if (isJoinable_) { - return; - } - state = ACCEPT_FIRST_CONNECTION; - } - state = (this->*stateMap_[state])(data); - } -} } } // namespace facebook::wdt diff --git a/Receiver.h b/Receiver.h index a24cc779..a9aa6f7b 100644 --- a/Receiver.h +++ b/Receiver.h @@ -17,7 +17,9 @@ #include "Protocol.h" #include "Writer.h" #include "Throttler.h" +#include "ReceiverThread.h" #include "TransferLogManager.h" +#include "ThreadsController.h" #include #include #include @@ -98,7 +100,9 @@ class Receiver : public WdtBase { */ std::vector getPorts() const; - private: + protected: 
+ friend class ReceiverThread; + /** * @param isFinished Mark transfer active/inactive */ @@ -111,321 +115,18 @@ class Receiver : public WdtBase { */ void traverseDestinationDir(std::vector &fileChunksInfo); - /** - * Wdt receiver has logic to maintain the consistency of the - * transfers through connection errors. All threads are run by the logic - * defined as a state machine. These are the all the states in that - * state machine - */ - enum ReceiverState { - LISTEN, - ACCEPT_FIRST_CONNECTION, - ACCEPT_WITH_TIMEOUT, - SEND_LOCAL_CHECKPOINT, - READ_NEXT_CMD, - PROCESS_FILE_CMD, - PROCESS_SETTINGS_CMD, - PROCESS_DONE_CMD, - PROCESS_SIZE_CMD, - SEND_FILE_CHUNKS, - SEND_GLOBAL_CHECKPOINTS, - SEND_DONE_CMD, - SEND_ABORT_CMD, - WAIT_FOR_FINISH_OR_NEW_CHECKPOINT, - WAIT_FOR_FINISH_WITH_THREAD_ERROR, - FAILED, - END - }; - - /** - * Structure to pass data to the state machine and also to share data between - * different state - */ - struct ThreadData { - /// Index of the thread that this data belongs to - const int threadIndex_; - - /** - * Server socket object that provides functionality such as listen() - * accept, read, write on the socket - */ - ServerSocket &socket_; - - /// Statistics of the transfer for this thread - TransferStats &threadStats_; - - /// protocol version for this thread. This per thread protocol version is - /// kept separately from the global one to avoid locking - int threadProtocolVersion_; - - /// Buffer that receivers reads data into from the network - std::unique_ptr buf_; - /// Maximum size of the buffer - const int64_t bufferSize_; - - /// Marks the number of bytes already read in the buffer - int64_t numRead_{0}; - - /// Following two are markers to mark how much data has been read/parsed - int64_t off_{0}; - int64_t oldOffset_{0}; - - /// number of checkpoints already transferred - int checkpointIndex_{0}; - - /** - * Pending value of checkpoint count. 
Since write call success does not - * guarantee actual transfer, we do not apply checkpoint count update after - * the write. Only after receiving next cmd from sender, we apply the - * update - */ - int pendingCheckpointIndex_{0}; - - /// a counter incremented each time a new session starts for this thread - int64_t transferStartedCount_{0}; - - /// a counter incremented each time a new session ends for this thread - int64_t transferFinishedCount_{0}; - - /// read timeout for sender - int64_t senderReadTimeout_{-1}; - - /// write timeout for sender - int64_t senderWriteTimeout_{-1}; - - /// whether checksum verification is enabled or not - bool enableChecksum_{false}; - - /** - * Whether SEND_DONE_CMD state has already failed for this session or not. - * This has to be separately handled, because session barrier is - * implemented before sending done cmd - */ - bool doneSendFailure_{false}; - - /// Checkpoints that have not been sent back to the sender - std::vector newCheckpoints_; - - /// whether settings has been received and verified for the current - /// connection. This is used to determine round robin order for polling in - /// the server socket - bool curConnectionVerified_{false}; - - /// whether the transfer is in block mode or not - bool isBlockMode_{true}; - - Checkpoint checkpoint_; - - /// Constructor for thread data - ThreadData(int threadIndex, ServerSocket &socket, - TransferStats &threadStats, int protocolVersion, - int64_t bufferSize) - : threadIndex_(threadIndex), - socket_(socket), - threadStats_(threadStats), - threadProtocolVersion_(protocolVersion), - bufferSize_(bufferSize), - checkpoint_(socket.getPort()) { - buf_.reset(new char[bufferSize_]); - } - - /** - * In long running mode, we need to reset thread variables after each - * session. Before starting each session, reset() has to called to do that. 
- */ - void reset() { - numRead_ = off_ = 0; - checkpointIndex_ = pendingCheckpointIndex_ = 0; - doneSendFailure_ = false; - senderReadTimeout_ = senderWriteTimeout_ = -1; - curConnectionVerified_ = false; - threadStats_.reset(); - } - - /// Get the raw pointer to the buffer - char *getBuf() { - return buf_.get(); - } - }; - - /// Overloaded operator for printing thread info - friend std::ostream &operator<<(std::ostream &os, const ThreadData &data); - - typedef ReceiverState (Receiver::*StateFunction)(ThreadData &data); + /// Get the transferred file chunks info + const std::vector &getFileChunksInfo() const; - /** - * Tries to listen/bind to port. If this fails, thread is considered failed. - * Previous states : n/a (start state) - * Next states : ACCEPT_FIRST_CONNECTION(success), - * FAILED(failure) - */ - ReceiverState listen(ThreadData &data); - /** - * Tries to accept first connection of a new session. Periodically checks - * whether a new session has started or not. If a new session has started then - * goes to ACCEPT_WITH_TIMEOUT state. Also does session initialization. In - * joinable mode, tries to accept for a limited number of user specified - * retries. - * Previous states : LISTEN, - * END(if in long running mode) - * Next states : ACCEPT_WITH_TIMEOUT(if a new transfer has started and this - * thread has not received a connection), - * WAIT_FOR_FINISH_WITH_THREAD_ERROR(if did not receive a - * connection in specified number of retries), - * READ_NEXT_CMD(if a connection was received) - */ - ReceiverState acceptFirstConnection(ThreadData &data); - /** - * Tries to accept a connection with timeout. There are 2 kinds of timeout. At - * the beginning of the session, it uses accept window as the timeout. Later - * when sender settings are known it uses max(readTimeOut, writeTimeout)) + - * buffer(500) as the timeout. 
- * Previous states : Almost all states(for any network errors during transfer, - * we transition to this state), - * Next states : READ_NEXT_CMD(if there are no previous errors and accept - * was successful), - * SEND_LOCAL_CHECKPOINT(if there were previous errors and - * accept was successful), - * WAIT_FOR_FINISH_WITH_THREAD_ERROR(if accept failed and - * transfer previously failed during SEND_DONE_CMD state. Thus - * case needs special handling to ensure that we do not mess up - * session variables), - * END(if accept fails otherwise) - */ - ReceiverState acceptWithTimeout(ThreadData &data); - /** - * Sends local checkpoint to the sender. In case of previous error during - * SEND_LOCAL_CHECKPOINT state, we send -1 as the checkpoint. - * Previous states : ACCEPT_WITH_TIMEOUT - * Next states : ACCEPT_WITH_TIMEOUT(if sending fails), - * SEND_DONE_CMD(if send is successful and we have previous - * SEND_DONE_CMD error), - * READ_NEXT_CMD(if send is successful otherwise) - */ - ReceiverState sendLocalCheckpoint(ThreadData &data); - /** - * Reads next cmd and transitions to the state accordingly. - * Previous states : SEND_LOCAL_CHECKPOINT, - * ACCEPT_FIRST_CONNECTION, - * ACCEPT_WITH_TIMEOUT, - * PROCESS_SETTINGS_CMD, - * PROCESS_FILE_CMD, - * SEND_GLOBAL_CHECKPOINTS, - * Next states : PROCESS_FILE_CMD, - * PROCESS_DONE_CMD, - * PROCESS_SETTINGS_CMD, - * PROCESS_SIZE_CMD, - * ACCEPT_WITH_TIMEOUT(in case of read failure), - * WAIT_FOR_FINISH_WITH_THREAD_ERROR(in case of protocol errors) - */ - ReceiverState readNextCmd(ThreadData &data); - /** - * Processes file cmd. Logic of how we write the file to the destination - * directory is defined here. - * Previous states : READ_NEXT_CMD - * Next states : READ_NEXT_CMD(success), - * WAIT_FOR_FINISH_WITH_THREAD_ERROR(protocol error), - * ACCEPT_WITH_TIMEOUT(socket read failure) - */ - ReceiverState processFileCmd(ThreadData &data); - /** - * Processes settings cmd. 
Settings has a connection settings, - * protocol version, transfer id, etc. For more info check Protocol.h - * Previous states : READ_NEXT_CMD, - * Next states : READ_NEXT_CMD(success), - * WAIT_FOR_FINISH_WITH_THREAD_ERROR(protocol error), - * ACCEPT_WITH_TIMEOUT(socket read failure), - * SEND_FILE_CHUNKS(If the sender wants to resume transfer) - */ - ReceiverState processSettingsCmd(ThreadData &data); - /** - * Processes done cmd. Also checks to see if there are any new global - * checkpoints or not - * Previous states : READ_NEXT_CMD, - * Next states : WAIT_FOR_FINISH_OR_ERROR(protocol error), - * WAIT_FOR_FINISH_OR_NEW_CHECKPOINT(success), - * SEND_GLOBAL_CHECKPOINTS(if there are global errors) - */ - ReceiverState processDoneCmd(ThreadData &data); - /** - * Processes size cmd. Sets the value of totalSenderBytes_ - * Previous states : READ_NEXT_CMD, - * Next states : READ_NEXT_CMD(success), - * WAIT_FOR_FINISH_WITH_THREAD_ERROR(protocol error) - */ - ReceiverState processSizeCmd(ThreadData &data); - /** - * Sends file chunks that were received successfully in any previous transfer, - * this is the first step in download resumption. - * Checks to see if they have already been transferred or not. - * If yes, send ACK. If some other thread is sending it, sends wait cmd - * and checks again later. Otherwise, breaks the entire data into bufferSIze_ - * chunks and sends it. - * Previous states: PROCESS_SETTINGS_CMD, - * Next states : ACCEPT_WITH_TIMEOUT(network error), - * READ_NEXT_CMD(success) - */ - ReceiverState sendFileChunks(ThreadData &data); - /** - * Sends global checkpoints to sender - * Previous states : PROCESS_DONE_CMD, - * WAIT_FOR_FINISH_OR_ERROR - * Next states : READ_NEXT_CMD(success), - * ACCEPT_WITH_TIMEOUT(socket write failure) - */ - ReceiverState sendGlobalCheckpoint(ThreadData &data); - /** - * Sends DONE to sender, also tries to read back ack. If anything fails during - * this state, doneSendFailure_ thread variable is set. 
This flag makes the - * state machine behave differently, effectively bypassing all session related - * things. - * Previous states : SEND_LOCAL_CHECKPOINT, - * WAIT_FOR_FINISH_OR_ERROR - * Next states : END(success), - * ACCEPT_WITH_TIMEOUT(failure) - */ - ReceiverState sendDoneCmd(ThreadData &data); - - /** - * Sends ABORT cmd back to the sender - * Previous states : PROCESS_FILE_CMD - * Next states : WAIT_FOR_FINISH_WITH_THREAD_ERROR - */ - ReceiverState sendAbortCmd(ThreadData &data); - - /** - * Waits for transfer to finish or new checkpoints. This state first - * increments waitingThreadCount_. Then, it - * waits till all the threads have finished. It sends periodic WAIT signal to - * prevent sender from timing out. If a new checkpoint is found, we move to - * SEND_GLOBAL_CHECKPOINTS state. - * Previous states : PROCESS_DONE_CMD - * Next states : SEND_DONE_CMD(all threads finished), - * SEND_GLOBAL_CHECKPOINTS(if new checkpoints are found), - * ACCEPT_WITH_TIMEOUT(if socket write fails) - */ - ReceiverState waitForFinishOrNewCheckpoint(ThreadData &data); - - /** - * Waits for transfer to finish. Only called when there is an error for the - * thread. It adds a checkpoint to the global list of checkpoints if a - * connection was received. It increments waitingWithErrorThreadCount_ and - * waits till the session ends. - * Previous states : Almost all states - * Next states : END - */ - ReceiverState waitForFinishWithThreadError(ThreadData &data); + /// Get file creator, used by receiver threads + std::unique_ptr &getFileCreator(); - /// Mapping from receiver states to state functions - static const StateFunction stateMap_[]; + /// Get the ref to transfer log manager + TransferLogManager &getTransferLogManager(); /// Responsible for basic setup and starting threads void start(); - /// This method is the entry point for each thread. 
- void receiveOne(int threadIndex, ServerSocket &s, int64_t bufferSize, - TransferStats &threadStats); - /** * Periodically calculates current transfer report and send it to progress * reporter. This only works in the single transfer mode. @@ -445,45 +146,20 @@ class Receiver : public WdtBase { */ std::vector getNewCheckpoints(int startIndex); - /// Returns true if all threads finished for this session - bool areAllThreadsFinished(bool checkpointAdded); - - /// Ends current global session - void endCurGlobalSession(); - - /** - * Returns if a new session has started and the thread is not aware of it - * A thread must hold lock on mutex_ before calling this - */ - bool hasNewSessionStarted(ThreadData &data); + /// Does the steps needed before a new transfer is started + void startNewGlobalSession(const std::string &peerIp); - /** - * Start new transfer by incrementing transferStartedCount_ - * A thread must hold lock on mutex_ before calling this - */ - void startNewGlobalSession(ThreadData &data); + /// Returns true if at least one thread has accepted connection + bool hasNewTransferStarted() const; - /** - * Returns whether the current session has finished or not. 
- * A thread must hold lock on mutex_ before calling this - */ - bool hasCurSessionFinished(ThreadData &data); - - /// Starts a new session for the thread - void startNewThreadSession(ThreadData &data); - - /// Ends current thread session - void endCurThreadSession(ThreadData &data); - - /// Increments failed thread count, does not wait for transfer to finish - void incrFailedThreadCountAndCheckForSessionEnd(ThreadData &data); + /// Has steps to do when the current transfer is ended + void endCurGlobalSession(); /// adds log header and also a directory invalidation entry if needed void addTransferLogHeader(bool isBlockMode, bool isSenderResuming); /// fix and close transfer log void fixAndCloseTransferLog(bool transferSuccess); - /** * Get transfer report, meant to be called after threads have been finished * This method is not thread safe @@ -495,6 +171,7 @@ class Receiver : public WdtBase { /// The thread that is responsible for calling running the progress tracker std::thread progressTrackerThread_; + /** * Flags that represents if a transfer has finished. Threads on completion * set this flag. This is always accurate even if you don't call finish() @@ -509,7 +186,7 @@ class Receiver : public WdtBase { std::string destDir_; /// Responsible for writing files on the disk - std::unique_ptr fileCreator_; + std::unique_ptr fileCreator_{nullptr}; /** * Unique-id used to verify transfer log. This value must be same for @@ -531,64 +208,14 @@ class Receiver : public WdtBase { * it has to be made sure that these threads are joined at least before * the destruction of this object. 
*/ - std::vector receiverThreads_; - - /** - * start() gives each thread the instance of the serversocket, these - * sockets can be closed and changed completely by the progress tracker - * thread thus ending any hope of receiver threads doing any further - * successful transfer - */ - std::vector threadServerSockets_; - - /** - * Bunch of stats objects given to each thread by the root thread - * so that finish() can summarize the result at the end of joining. - */ - std::vector threadStats_; - - /// Per thread perf report - std::vector perfReports_; + std::vector> receiverThreads_; /// Transfer log manager TransferLogManager transferLogManager_; - /// Enum representing status of file chunks transfer - enum SendChunkStatus { NOT_STARTED, IN_PROGRESS, SENT }; - - /// State of the receiver when sending file chunks in sendFileChunksCmd - SendChunkStatus sendChunksStatus_{NOT_STARTED}; - - /** - * All threads coordinate with each other to send previously received file - * chunks using this condition variable. 
- */ - mutable std::condition_variable conditionFileChunksSent_; - - /// Number of blocks sent by the sender - int64_t numBlocksSend_{-1}; - /// Global list of checkpoints std::vector checkpoints_; - /// Number of threads which failed in the transfer - int failedThreadCount_{0}; - - /// Number of threads which are waiting for finish or new checkpoint - int waitingThreadCount_{0}; - - /// Number of threads which are waiting with an error - int waitingWithErrorThreadCount_{0}; - - /// Counter that is incremented each time a new session starts - int64_t transferStartedCount_{0}; - - /// Counter that is incremented each time a new session ends - int64_t transferFinishedCount_{0}; - - /// Total number of data bytes sender wants to transfer - int64_t totalSenderBytes_{-1}; - /// Start time of the session std::chrono::time_point startTime_; @@ -598,8 +225,8 @@ class Receiver : public WdtBase { /// Mutex to guard all the shared variables mutable std::mutex mutex_; - /// Condition variable to coordinate transfer finish - mutable std::condition_variable conditionAllFinished_; + /// Marks when a new transfer has started + std::atomic hasNewTransferStarted_{false}; /** * Returns true if threads have been joined (done in finish()) @@ -607,14 +234,17 @@ class Receiver : public WdtBase { */ bool areThreadsJoined_{false}; - /// Number of active threads, decremented every time a thread is finished - int32_t numActiveThreads_{0}; - /** * Mutex for the management of this instance, specifically to keep the * instance sane for multi threaded public API calls */ std::mutex instanceManagementMutex_; + + /// Buffer size used by this receiver + int64_t bufferSize_; + + /// Backlog used by the sockets + int backlog_; }; } } // namespace facebook::wdt diff --git a/ReceiverThread.cpp b/ReceiverThread.cpp new file mode 100644 index 00000000..902b72c6 --- /dev/null +++ b/ReceiverThread.cpp @@ -0,0 +1,905 @@ +/** + * Copyright (c) 2014-present, Facebook, Inc. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ +#include "ReceiverThread.h" +#include "FileWriter.h" +#include +#include +#include +#include +#include +#include + +namespace facebook { +namespace wdt { + +const static int kTimeoutBufferMillis = 1000; +const static int kWaitTimeoutFactor = 5; +std::ostream &operator<<(std::ostream &os, + const ReceiverThread &receiverThread) { + os << "Thread[" << receiverThread.threadIndex_ + << ", port: " << receiverThread.socket_.getPort() << "] "; + return os; +} + +int64_t readAtLeast(ServerSocket &s, char *buf, int64_t max, int64_t atLeast, + int64_t len) { + VLOG(4) << "readAtLeast len " << len << " max " << max << " atLeast " + << atLeast << " from " << s.getFd(); + CHECK_GE(len, 0); + CHECK_GT(atLeast, 0); + CHECK_LE(atLeast, max); + int count = 0; + while (len < atLeast) { + // because we want to process data as soon as it arrives, tryFull option for + // read is false + int64_t n = s.read(buf + len, max - len, false); + if (n < 0) { + PLOG(ERROR) << "Read error on " << s.getPort() << " after " << count; + if (len) { + return len; + } else { + return n; + } + } + if (n == 0) { + VLOG(2) << "Eof on " << s.getPort() << " after " << count << " reads " + << "got " << len; + return len; + } + len += n; + count++; + } + VLOG(3) << "Took " << count << " reads to get " << len + << " from fd : " << s.getFd(); + return len; +} + +int64_t readAtMost(ServerSocket &s, char *buf, int64_t max, int64_t atMost) { + const int64_t target = atMost < max ? 
atMost : max; + VLOG(3) << "readAtMost target " << target; + // because we want to process data as soon as it arrives, tryFull option for + // read is false + int64_t n = s.read(buf, target, false); + if (n < 0) { + PLOG(ERROR) << "Read error on " << s.getPort() << " with target " << target; + return n; + } + if (n == 0) { + LOG(WARNING) << "Eof on " << s.getFd(); + return n; + } + VLOG(3) << "readAtMost " << n << " / " << atMost << " from " << s.getFd(); + return n; +} + +const ReceiverThread::StateFunction ReceiverThread::stateMap_[] = { + &ReceiverThread::listen, &ReceiverThread::acceptFirstConnection, + &ReceiverThread::acceptWithTimeout, &ReceiverThread::sendLocalCheckpoint, + &ReceiverThread::readNextCmd, &ReceiverThread::processFileCmd, + &ReceiverThread::processSettingsCmd, &ReceiverThread::processDoneCmd, + &ReceiverThread::processSizeCmd, &ReceiverThread::sendFileChunks, + &ReceiverThread::sendGlobalCheckpoint, &ReceiverThread::sendDoneCmd, + &ReceiverThread::sendAbortCmd, + &ReceiverThread::waitForFinishOrNewCheckpoint, + &ReceiverThread::finishWithError}; + +ReceiverThread::ReceiverThread(Receiver *wdtParent, int threadIndex, + int32_t port, ThreadsController *controller) + : WdtThread(threadIndex, wdtParent->getProtocolVersion(), controller), + wdtParent_(wdtParent), + socket_(port, wdtParent->backlog_, &(wdtParent->abortCheckerCallback_)), + bufferSize_(wdtParent->bufferSize_) { + controller_->registerThread(threadIndex_); + buf_ = new char[bufferSize_]; +} + +/**LISTEN STATE***/ +ReceiverState ReceiverThread::listen() { + VLOG(1) << *this << " entered LISTEN state "; + const auto &options = WdtOptions::get(); + const bool doActualWrites = !options.skip_writes; + int32_t port = socket_.getPort(); + VLOG(1) << "Server Thread for port " << port << " with backlog " + << socket_.getBackLog() << " on " << wdtParent_->getDir() + << " writes = " << doActualWrites; + + for (int retry = 1; retry < options.max_retries; ++retry) { + ErrorCode code = 
socket_.listen(); + if (code == OK) { + break; + } else if (code == CONN_ERROR) { + threadStats_.setErrorCode(code); + return FAILED; + } + LOG(INFO) << "Sleeping after failed attempt " << retry; + /* sleep override */ + usleep(options.sleep_millis * 1000); + } + // one more/last try (stays true if it worked above) + if (socket_.listen() != OK) { + LOG(ERROR) << "Unable to listen/bind despite retries"; + threadStats_.setErrorCode(CONN_ERROR); + return FAILED; + } + return ACCEPT_FIRST_CONNECTION; +} + +/***ACCEPT_FIRST_CONNECTION***/ +ReceiverState ReceiverThread::acceptFirstConnection() { + VLOG(1) << *this << " entered ACCEPT_FIRST_CONNECTION state "; + const auto &options = WdtOptions::get(); + reset(); + socket_.closeCurrentConnection(); + auto timeout = options.accept_timeout_millis; + int acceptAttempts = 0; + while (true) { + // Move to timeout state if some other thread was successful + // in getting a connection + if (wdtParent_->hasNewTransferStarted()) { + return ACCEPT_WITH_TIMEOUT; + } + if (acceptAttempts == options.max_accept_retries) { + LOG(ERROR) << "unable to accept after " << acceptAttempts << " attempts"; + threadStats_.setErrorCode(CONN_ERROR); + return FAILED; + } + if (wdtParent_->getCurAbortCode() != OK) { + LOG(ERROR) << "Thread marked to abort while trying to accept first" + << " connection. Num attempts " << acceptAttempts; + // Even though there is a transition FAILED here + // getCurAbortCode() is going to be checked again in the receiveOne. + // So this is pretty much irrelevant + return FAILED; + } + ErrorCode code = + socket_.acceptNextConnection(timeout, curConnectionVerified_); + if (code == OK) { + break; + } + ++acceptAttempts; + } + // Make the parent start new global session. 
This is executed + // only by the first thread that calls this function + controller_->executeAtStart( + [&]() { wdtParent_->startNewGlobalSession(socket_.getPeerIp()); }); + return READ_NEXT_CMD; +} + +/***ACCEPT_WITH_TIMEOUT STATE***/ +ReceiverState ReceiverThread::acceptWithTimeout() { + LOG(INFO) << *this << " entered ACCEPT_WITH_TIMEOUT state "; + const auto &options = WdtOptions::get(); + socket_.closeCurrentConnection(); + auto timeout = options.accept_window_millis; + if (senderReadTimeout_ > 0) { + // transfer is in progress and we have already got sender settings + timeout = std::max(senderReadTimeout_, senderWriteTimeout_) + + kTimeoutBufferMillis; + } + ErrorCode code = + socket_.acceptNextConnection(timeout, curConnectionVerified_); + curConnectionVerified_ = false; + if (code != OK) { + LOG(ERROR) << "accept() failed with timeout " << timeout; + threadStats_.setErrorCode(code); + if (doneSendFailure_) { + // if SEND_DONE_CMD state had already been reached, we do not need to + // wait for other threads to end + return END; + } + return FINISH_WITH_ERROR; + } + + if (doneSendFailure_) { + // no need to reset any session variables in this case + return SEND_LOCAL_CHECKPOINT; + } + + numRead_ = off_ = 0; + pendingCheckpointIndex_ = checkpointIndex_; + ReceiverState nextState = READ_NEXT_CMD; + if (threadStats_.getErrorCode() != OK) { + nextState = SEND_LOCAL_CHECKPOINT; + } + // reset thread status + threadStats_.setErrorCode(OK); + return nextState; +} + +/***SEND_LOCAL_CHECKPOINT STATE***/ +ReceiverState ReceiverThread::sendLocalCheckpoint() { + LOG(INFO) << *this << " entered SEND_LOCAL_CHECKPOINT state "; + std::vector checkpoints; + if (doneSendFailure_) { + // in case SEND_DONE failed, a special checkpoint(-1) is sent to signal this + // condition + Checkpoint localCheckpoint(socket_.getPort()); + localCheckpoint.numBlocks = -1; + checkpoints.emplace_back(localCheckpoint); + } else { + checkpoints.emplace_back(checkpoint_); + } + + int64_t off = 0; 
+ const int checkpointLen = + Protocol::getMaxLocalCheckpointLength(threadProtocolVersion_); + Protocol::encodeCheckpoints(threadProtocolVersion_, buf_, off, checkpointLen, + checkpoints); + int written = socket_.write(buf_, checkpointLen); + if (written != checkpointLen) { + LOG(ERROR) << "unable to write local checkpoint. write mismatch " + << checkpointLen << " " << written; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + return ACCEPT_WITH_TIMEOUT; + } + threadStats_.addHeaderBytes(checkpointLen); + if (doneSendFailure_) { + return SEND_DONE_CMD; + } + return READ_NEXT_CMD; +} + +/***READ_NEXT_CMD***/ +ReceiverState ReceiverThread::readNextCmd() { + VLOG(1) << *this << " entered READ_NEXT_CMD state "; + oldOffset_ = off_; + numRead_ = readAtLeast(socket_, buf_ + off_, bufferSize_ - off_, + Protocol::kMinBufLength, numRead_); + if (numRead_ < Protocol::kMinBufLength) { + LOG(ERROR) << "socket read failure " << Protocol::kMinBufLength << " " + << numRead_; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return ACCEPT_WITH_TIMEOUT; + } + Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf_[off_++]; + if (cmd == Protocol::DONE_CMD) { + return PROCESS_DONE_CMD; + } + if (cmd == Protocol::FILE_CMD) { + return PROCESS_FILE_CMD; + } + if (cmd == Protocol::SETTINGS_CMD) { + return PROCESS_SETTINGS_CMD; + } + if (cmd == Protocol::SIZE_CMD) { + return PROCESS_SIZE_CMD; + } + LOG(ERROR) << "received an unknown cmd"; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return FINISH_WITH_ERROR; +} + +/***PROCESS_SETTINGS_CMD***/ +ReceiverState ReceiverThread::processSettingsCmd() { + VLOG(1) << *this << " entered PROCESS_SETTINGS_CMD state "; + Settings settings; + int senderProtocolVersion; + + bool success = Protocol::decodeVersion( + buf_, off_, oldOffset_ + Protocol::kMaxVersion, senderProtocolVersion); + if (!success) { + LOG(ERROR) << "Unable to decode version " << threadIndex_; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return FINISH_WITH_ERROR; + } + if 
(senderProtocolVersion != threadProtocolVersion_) { + LOG(ERROR) << "Receiver and sender protocol version mismatch " + << senderProtocolVersion << " " << threadProtocolVersion_; + int negotiatedProtocol = Protocol::negotiateProtocol( + senderProtocolVersion, threadProtocolVersion_); + if (negotiatedProtocol == 0) { + LOG(WARNING) << "Can not support sender with version " + << senderProtocolVersion << ", aborting!"; + threadStats_.setErrorCode(VERSION_INCOMPATIBLE); + return SEND_ABORT_CMD; + } else { + LOG_IF(INFO, threadProtocolVersion_ != negotiatedProtocol) + << "Changing receiver protocol version to " << negotiatedProtocol; + threadProtocolVersion_ = negotiatedProtocol; + if (negotiatedProtocol != senderProtocolVersion) { + threadStats_.setErrorCode(VERSION_MISMATCH); + return SEND_ABORT_CMD; + } + } + } + + success = Protocol::decodeSettings( + threadProtocolVersion_, buf_, off_, + oldOffset_ + Protocol::kMaxVersion + Protocol::kMaxSettings, settings); + if (!success) { + LOG(ERROR) << *this << "Unable to decode settings cmd "; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return FINISH_WITH_ERROR; + } + auto senderId = settings.transferId; + auto transferId = wdtParent_->getTransferId(); + if (transferId != senderId) { + LOG(ERROR) << "Receiver and sender id mismatch " << senderId << " " + << transferId; + threadStats_.setErrorCode(ID_MISMATCH); + return SEND_ABORT_CMD; + } + senderReadTimeout_ = settings.readTimeoutMillis; + senderWriteTimeout_ = settings.writeTimeoutMillis; + enableChecksum_ = settings.enableChecksum; + isBlockMode_ = !settings.blockModeDisabled; + curConnectionVerified_ = true; + if (settings.sendFileChunks) { + // We only move to SEND_FILE_CHUNKS state, if download resumption is enabled + // in the sender side + numRead_ = off_ = 0; + return SEND_FILE_CHUNKS; + } + auto msgLen = off_ - oldOffset_; + numRead_ -= msgLen; + return READ_NEXT_CMD; +} + +/***PROCESS_FILE_CMD***/ +ReceiverState ReceiverThread::processFileCmd() { + VLOG(1) << 
*this << " entered PROCESS_FILE_CMD state "; + const auto &options = WdtOptions::get(); + // following block needs to be executed for the first file cmd. There is no + // harm in executing it more than once. number of blocks equal to 0 is a good + // approximation for first file cmd. Did not want to introduce another boolean + if (options.enable_download_resumption && threadStats_.getNumBlocks() == 0) { + auto sendChunksFunnel = controller_->getFunnel(SEND_FILE_CHUNKS_FUNNEL); + auto state = sendChunksFunnel->getStatus(); + if (state == FUNNEL_START) { + // sender is not in resumption mode + wdtParent_->addTransferLogHeader(isBlockMode_, + /* sender not resuming */ false); + sendChunksFunnel->notifySuccess(); + } + } + checkpoint_.resetLastBlockDetails(); + BlockDetails blockDetails; + auto guard = folly::makeGuard([&] { + if (threadStats_.getErrorCode() != OK) { + threadStats_.incrFailedAttempts(); + } + }); + + ErrorCode transferStatus = (ErrorCode)buf_[off_++]; + if (transferStatus != OK) { + // TODO: use this status information to implement fail fast mode + VLOG(1) << "sender entered into error state " + << errorCodeToStr(transferStatus); + } + int16_t headerLen = folly::loadUnaligned(buf_ + off_); + headerLen = folly::Endian::little(headerLen); + VLOG(2) << "Processing FILE_CMD, header len " << headerLen; + + if (headerLen > numRead_) { + int64_t end = oldOffset_ + numRead_; + numRead_ = readAtLeast(socket_, buf_ + end, bufferSize_ - end, headerLen, + numRead_); + } + if (numRead_ < headerLen) { + LOG(ERROR) << "Unable to read full header " << headerLen << " " << numRead_; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return ACCEPT_WITH_TIMEOUT; + } + off_ += sizeof(int16_t); + bool success = Protocol::decodeHeader(threadProtocolVersion_, buf_, off_, + numRead_ + oldOffset_, blockDetails); + int64_t headerBytes = off_ - oldOffset_; + // transferred header length must match decoded header length + WDT_CHECK_EQ(headerLen, headerBytes) << " " << 
blockDetails.fileName << " " + << blockDetails.seqId << " " + << threadProtocolVersion_; + threadStats_.addHeaderBytes(headerBytes); + if (!success) { + LOG(ERROR) << "Error decoding at" + << " ooff:" << oldOffset_ << " off_: " << off_ + << " numRead_: " << numRead_; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return FINISH_WITH_ERROR; + } + + // received a well formed file cmd, apply the pending checkpoint update + checkpointIndex_ = pendingCheckpointIndex_; + VLOG(1) << "Read id:" << blockDetails.fileName + << " size:" << blockDetails.dataSize << " ooff:" << oldOffset_ + << " off_: " << off_ << " numRead_: " << numRead_; + auto &fileCreator = wdtParent_->getFileCreator(); + FileWriter writer(threadIndex_, &blockDetails, fileCreator.get()); + auto writtenGuard = folly::makeGuard([&] { + if (threadProtocolVersion_ >= Protocol::CHECKPOINT_OFFSET_VERSION) { + // considering partially written block contents as valid, this bypasses + // checksum verification + // TODO: Make sure checksum verification work with checkpoint offsets + checkpoint_.setLastBlockDetails(blockDetails.seqId, blockDetails.offset, + writer.getTotalWritten()); + threadStats_.addEffectiveBytes(headerBytes, writer.getTotalWritten()); + } + }); + if (writer.open() != OK) { + threadStats_.setErrorCode(FILE_WRITE_ERROR); + return SEND_ABORT_CMD; + } + int32_t checksum = 0; + int64_t remainingData = numRead_ + oldOffset_ - off_; + int64_t toWrite = remainingData; + WDT_CHECK(remainingData >= 0); + if (remainingData >= blockDetails.dataSize) { + toWrite = blockDetails.dataSize; + } + threadStats_.addDataBytes(toWrite); + if (enableChecksum_) { + checksum = folly::crc32c((const uint8_t *)(buf_ + off_), toWrite, checksum); + } + auto throttler = wdtParent_->getThrottler(); + if (throttler) { + // We might be reading more than we require for this file but + // throttling should make sense for any additional bytes received + // on the network + throttler->limit(toWrite + headerBytes); + } + ErrorCode code 
= writer.write(buf_ + off_, toWrite); + if (code != OK) { + threadStats_.setErrorCode(code); + return SEND_ABORT_CMD; + } + off_ += toWrite; + remainingData -= toWrite; + // also means no leftOver so it's ok we use buf_ from start + while (writer.getTotalWritten() < blockDetails.dataSize) { + if (wdtParent_->getCurAbortCode() != OK) { + LOG(ERROR) << "Thread marked for abort while processing a file." + << " port : " << socket_.getPort(); + return FAILED; + } + int64_t nres = readAtMost(socket_, buf_, bufferSize_, + blockDetails.dataSize - writer.getTotalWritten()); + if (nres <= 0) { + break; + } + if (throttler) { + // We only know how much we have read after we are done calling + // readAtMost. Call throttler with the bytes read off the wire. + throttler->limit(nres); + } + threadStats_.addDataBytes(nres); + if (enableChecksum_) { + checksum = folly::crc32c((const uint8_t *)buf_, nres, checksum); + } + code = writer.write(buf_, nres); + if (code != OK) { + threadStats_.setErrorCode(code); + return SEND_ABORT_CMD; + } + } + if (writer.getTotalWritten() != blockDetails.dataSize) { + // This can only happen if there are transmission errors + // Write errors to disk are already taken care of above + LOG(ERROR) << "could not read entire content for " << blockDetails.fileName + << " port " << socket_.getPort(); + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return ACCEPT_WITH_TIMEOUT; + } + writtenGuard.dismiss(); + VLOG(2) << "completed " << blockDetails.fileName << " off: " << off_ + << " numRead: " << numRead_; + // Transfer of the file is complete here, mark the bytes effective + WDT_CHECK(remainingData >= 0) << "Negative remainingData " << remainingData; + if (remainingData > 0) { + // if we need to read more anyway, let's move the data + numRead_ = remainingData; + if ((remainingData < Protocol::kMaxHeader) && (off_ > (bufferSize_ / 2))) { + // rare so inefficient is ok + VLOG(3) << "copying extra " << remainingData << " leftover bytes @ " + << off_; + 
memmove(/* dst */ buf_, + /* from */ buf_ + off_, + /* how much */ remainingData); + off_ = 0; + } else { + // otherwise just continue from the offset + VLOG(3) << "Using remaining extra " << remainingData + << " leftover bytes starting @ " << off_; + } + } else { + numRead_ = off_ = 0; + } + if (enableChecksum_) { + // have to read footer cmd + oldOffset_ = off_; + numRead_ = readAtLeast(socket_, buf_ + off_, bufferSize_ - off_, + Protocol::kMinBufLength, numRead_); + if (numRead_ < Protocol::kMinBufLength) { + LOG(ERROR) << "socket read failure " << Protocol::kMinBufLength << " " + << numRead_; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return ACCEPT_WITH_TIMEOUT; + } + Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf_[off_++]; + if (cmd != Protocol::FOOTER_CMD) { + LOG(ERROR) << "Expecting footer cmd, but received " << cmd; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return FINISH_WITH_ERROR; + } + int32_t receivedChecksum; + bool success = Protocol::decodeFooter( + buf_, off_, oldOffset_ + Protocol::kMaxFooter, receivedChecksum); + if (!success) { + LOG(ERROR) << "Unable to decode footer cmd"; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return FINISH_WITH_ERROR; + } + if (checksum != receivedChecksum) { + LOG(ERROR) << "Checksum mismatch " << checksum << " " << receivedChecksum + << " port " << socket_.getPort() << " file " + << blockDetails.fileName; + threadStats_.setErrorCode(CHECKSUM_MISMATCH); + return ACCEPT_WITH_TIMEOUT; + } + int64_t msgLen = off_ - oldOffset_; + numRead_ -= msgLen; + } + auto &transferLogManager = wdtParent_->getTransferLogManager(); + if (options.isLogBasedResumption()) { + transferLogManager.addBlockWriteEntry( + blockDetails.seqId, blockDetails.offset, blockDetails.dataSize); + } + threadStats_.addEffectiveBytes(headerBytes, blockDetails.dataSize); + threadStats_.incrNumBlocks(); + checkpoint_.incrNumBlocks(); + return READ_NEXT_CMD; +} + +ReceiverState ReceiverThread::processDoneCmd() { + VLOG(1) << *this << " 
entered PROCESS_DONE_CMD state "; + if (numRead_ != Protocol::kMinBufLength) { + LOG(ERROR) << "Unexpected state for done command" + << " off_: " << off_ << " numRead_: " << numRead_; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return FINISH_WITH_ERROR; + } + + ErrorCode senderStatus = (ErrorCode)buf_[off_++]; + int64_t numBlocksSend = -1; + int64_t totalSenderBytes = -1; + bool success = Protocol::decodeDone(threadProtocolVersion_, buf_, off_, + oldOffset_ + Protocol::kMaxDone, + numBlocksSend, totalSenderBytes); + if (!success) { + LOG(ERROR) << "Unable to decode done cmd"; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return FINISH_WITH_ERROR; + } + threadStats_.setNumBlocksSend(numBlocksSend); + threadStats_.setTotalSenderBytes(totalSenderBytes); + threadStats_.setRemoteErrorCode(senderStatus); + + // received a valid command, applying pending checkpoint write update + checkpointIndex_ = pendingCheckpointIndex_; + return WAIT_FOR_FINISH_OR_NEW_CHECKPOINT; +} + +ReceiverState ReceiverThread::processSizeCmd() { + VLOG(1) << *this << " entered PROCESS_SIZE_CMD state "; + int64_t totalSenderBytes; + bool success = Protocol::decodeSize( + buf_, off_, oldOffset_ + Protocol::kMaxSize, totalSenderBytes); + if (!success) { + LOG(ERROR) << "Unable to decode size cmd"; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return FINISH_WITH_ERROR; + } + VLOG(1) << "Number of bytes to receive " << totalSenderBytes; + threadStats_.setTotalSenderBytes(totalSenderBytes); + auto msgLen = off_ - oldOffset_; + numRead_ -= msgLen; + return READ_NEXT_CMD; +} + +ReceiverState ReceiverThread::sendFileChunks() { + LOG(INFO) << *this << " entered SEND_FILE_CHUNKS state "; + WDT_CHECK(senderReadTimeout_ > 0); // must have received settings + int waitingTimeMillis = senderReadTimeout_ / kWaitTimeoutFactor; + auto execFunnel = controller_->getFunnel(SEND_FILE_CHUNKS_FUNNEL); + while (true) { + auto status = execFunnel->getStatus(); + switch (status) { + case FUNNEL_END: { + buf_[0] = 
Protocol::ACK_CMD; + int toWrite = 1; + int written = socket_.write(buf_, toWrite); + if (written != toWrite) { + LOG(ERROR) << *this << " socket write error " << toWrite << " " + << written; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + return ACCEPT_WITH_TIMEOUT; + } + threadStats_.addHeaderBytes(toWrite); + return READ_NEXT_CMD; + } + case FUNNEL_PROGRESS: { + buf_[0] = Protocol::WAIT_CMD; + int toWrite = 1; + int written = socket_.write(buf_, toWrite); + if (written != toWrite) { + LOG(ERROR) << *this << " socket write error " << toWrite << " " + << written; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + return ACCEPT_WITH_TIMEOUT; + } + threadStats_.addHeaderBytes(toWrite); + execFunnel->wait(waitingTimeMillis); + break; + } + case FUNNEL_START: { + int64_t off = 0; + buf_[off++] = Protocol::CHUNKS_CMD; + const auto &fileChunksInfo = wdtParent_->getFileChunksInfo(); + const int64_t numParsedChunksInfo = fileChunksInfo.size(); + Protocol::encodeChunksCmd(buf_, off, bufferSize_, numParsedChunksInfo); + int written = socket_.write(buf_, off); + if (written > 0) { + threadStats_.addHeaderBytes(written); + } + if (written != off) { + LOG(ERROR) << "Socket write error " << off << " " << written; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + execFunnel->notifyFail(); + return ACCEPT_WITH_TIMEOUT; + } + int64_t numEntriesWritten = 0; + // we try to encode as many chunks as possible in the buffer. If a + // single + // chunk can not fit in the buffer, it is ignored. Format of encoding : + // ... 
+ while (numEntriesWritten < numParsedChunksInfo) { + off = sizeof(int32_t); + int64_t numEntriesEncoded = Protocol::encodeFileChunksInfoList( + buf_, off, bufferSize_, numEntriesWritten, fileChunksInfo); + int32_t dataSize = folly::Endian::little(off - sizeof(int32_t)); + folly::storeUnaligned(buf_, dataSize); + written = socket_.write(buf_, off); + if (written > 0) { + threadStats_.addHeaderBytes(written); + } + if (written != off) { + break; + } + numEntriesWritten += numEntriesEncoded; + } + if (numEntriesWritten != numParsedChunksInfo) { + LOG(ERROR) << "Could not write all the file chunks " + << numParsedChunksInfo << " " << numEntriesWritten; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + execFunnel->notifyFail(); + return ACCEPT_WITH_TIMEOUT; + } + // try to read ack + int64_t toRead = 1; + int64_t numRead = socket_.read(buf_, toRead); + if (numRead != toRead) { + LOG(ERROR) << "Socket read error " << toRead << " " << numRead; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + execFunnel->notifyFail(); + return ACCEPT_WITH_TIMEOUT; + } + wdtParent_->addTransferLogHeader(isBlockMode_, + /* sender resuming */ true); + execFunnel->notifySuccess(); + return READ_NEXT_CMD; + } + } + } +} + +ReceiverState ReceiverThread::sendGlobalCheckpoint() { + LOG(INFO) << *this << " entered SEND_GLOBAL_CHECKPOINTS state"; + buf_[0] = Protocol::ERR_CMD; + off_ = 1; + // leave space for length + off_ += sizeof(int16_t); + auto oldOffset = off_; + Protocol::encodeCheckpoints(threadProtocolVersion_, buf_, off_, bufferSize_, + newCheckpoints_); + int16_t length = off_ - oldOffset; + folly::storeUnaligned(buf_ + 1, folly::Endian::little(length)); + + auto written = socket_.write(buf_, off_); + if (written != off_) { + LOG(ERROR) << "unable to write error checkpoints"; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + return ACCEPT_WITH_TIMEOUT; + } else { + threadStats_.addHeaderBytes(off_); + pendingCheckpointIndex_ = checkpointIndex_ + newCheckpoints_.size(); + numRead_ = 
off_ = 0; + return READ_NEXT_CMD; + } +} + +ReceiverState ReceiverThread::sendAbortCmd() { + LOG(INFO) << *this << " entered SEND_ABORT_CMD state "; + int64_t offset = 0; + buf_[offset++] = Protocol::ABORT_CMD; + Protocol::encodeAbort(buf_, offset, threadProtocolVersion_, + threadStats_.getErrorCode(), + threadStats_.getNumFiles()); + socket_.write(buf_, offset); + // No need to check if we were successful in sending ABORT + // This thread will simply disconnect and sender thread on the + // other side will timeout + socket_.closeCurrentConnection(); + threadStats_.addHeaderBytes(offset); + if (threadStats_.getErrorCode() == VERSION_MISMATCH) { + // Receiver should try again expecting sender to have changed its version + return ACCEPT_WITH_TIMEOUT; + } + return FINISH_WITH_ERROR; +} + +ReceiverState ReceiverThread::sendDoneCmd() { + VLOG(1) << *this << " entered SEND_DONE_CMD state "; + buf_[0] = Protocol::DONE_CMD; + if (socket_.write(buf_, 1) != 1) { + PLOG(ERROR) << "unable to send DONE " << threadIndex_; + doneSendFailure_ = true; + return ACCEPT_WITH_TIMEOUT; + } + + threadStats_.addHeaderBytes(1); + + auto read = socket_.read(buf_, 1); + if (read != 1 || buf_[0] != Protocol::DONE_CMD) { + LOG(ERROR) << *this << " did not receive ack for DONE"; + doneSendFailure_ = true; + return ACCEPT_WITH_TIMEOUT; + } + + read = socket_.read(buf_, Protocol::kMinBufLength); + if (read != 0) { + LOG(ERROR) << *this << " EOF not found where expected"; + doneSendFailure_ = true; + return ACCEPT_WITH_TIMEOUT; + } + socket_.closeCurrentConnection(); + LOG(INFO) << *this << " got ack for DONE. 
Transfer finished"; + return END; +} + +ReceiverState ReceiverThread::finishWithError() { + LOG(INFO) << *this << " entered FINISH_WITH_ERROR state "; + // should only be in this state if there is some error + WDT_CHECK(threadStats_.getErrorCode() != OK); + + // close the socket, so that sender receives an error during connect + socket_.closeAll(); + auto cv = controller_->getCondition(WAIT_FOR_FINISH_OR_CHECKPOINT_CV); + auto guard = cv->acquire(); + wdtParent_->addCheckpoint(checkpoint_); + controller_->markState(threadIndex_, FINISHED); + // guard deletion notifies one thread + return END; +} + +ReceiverState ReceiverThread::checkForFinishOrNewCheckpoints() { + auto checkpoints = wdtParent_->getNewCheckpoints(checkpointIndex_); + if (!checkpoints.empty()) { + newCheckpoints_ = std::move(checkpoints); + controller_->markState(threadIndex_, RUNNING); + return SEND_GLOBAL_CHECKPOINTS; + } + bool existActiveThreads = controller_->hasThreads(threadIndex_, RUNNING); + if (!existActiveThreads) { + controller_->markState(threadIndex_, FINISHED); + return SEND_DONE_CMD; + } + return WAIT_FOR_FINISH_OR_NEW_CHECKPOINT; +} + +ReceiverState ReceiverThread::waitForFinishOrNewCheckpoint() { + LOG(INFO) << *this << " entered WAIT_FOR_FINISH_OR_NEW_CHECKPOINT state "; + // should only be called if the are no errors + WDT_CHECK(threadStats_.getErrorCode() == OK); + auto cv = controller_->getCondition(WAIT_FOR_FINISH_OR_CHECKPOINT_CV); + int timeoutMillis = senderReadTimeout_ / kWaitTimeoutFactor; + controller_->markState(threadIndex_, WAITING); + while (true) { + WDT_CHECK(senderReadTimeout_ > 0); // must have received settings + { + auto guard = cv->acquire(); + auto state = checkForFinishOrNewCheckpoints(); + if (state != WAIT_FOR_FINISH_OR_NEW_CHECKPOINT) { + // guard automatically notfies one + return state; + } + START_PERF_TIMER + guard.wait(timeoutMillis); + RECORD_PERF_RESULT(PerfStatReport::RECEIVER_WAIT_SLEEP) + state = checkForFinishOrNewCheckpoints(); + if (state != 
WAIT_FOR_FINISH_OR_NEW_CHECKPOINT) { + guard.notifyOne(); + return state; + } + } + // send WAIT cmd to keep sender thread alive + buf_[0] = Protocol::WAIT_CMD; + if (socket_.write(buf_, 1) != 1) { + PLOG(ERROR) << *this << " unable to write WAIT "; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + controller_->markState(threadIndex_, RUNNING); + return ACCEPT_WITH_TIMEOUT; + } + threadStats_.addHeaderBytes(1); + } +} + +void ReceiverThread::start() { + INIT_PERF_STAT_REPORT + auto guard = folly::makeGuard([&] { + perfReport_ = *perfStatReport; + LOG(INFO) << *this << threadStats_; + controller_->deRegisterThread(threadIndex_); + controller_->executeAtEnd([&]() { wdtParent_->endCurGlobalSession(); }); + }); + if (!buf_) { + LOG(ERROR) << "error allocating " << bufferSize_; + threadStats_.setErrorCode(MEMORY_ALLOCATION_ERROR); + return; + } + ReceiverState state = LISTEN; + while (true) { + ErrorCode abortCode = wdtParent_->getCurAbortCode(); + if (abortCode != OK) { + LOG(ERROR) << "Transfer aborted " << socket_.getPort() << " " + << errorCodeToStr(abortCode); + threadStats_.setErrorCode(ABORT); + break; + } + if (state == FAILED) { + return; + } + if (state == END) { + return; + } + state = (this->*stateMap_[state])(); + } +} + +int32_t ReceiverThread::getPort() const { + return socket_.getPort(); +} + +ErrorCode ReceiverThread::init() { + int max_retries = WdtOptions::get().max_retries; + for (int retries = 0; retries < max_retries; retries++) { + if (socket_.listen() == OK) { + break; + } + } + if (socket_.listen() != OK) { + LOG(ERROR) << *this << "Couldn't listen on port " << socket_.getPort(); + return ERROR; + } + checkpoint_.port = socket_.getPort(); + LOG(INFO) << "Listening on port " << socket_.getPort(); + return OK; +} + +void ReceiverThread::reset() { + numRead_ = off_ = 0; + checkpointIndex_ = pendingCheckpointIndex_ = 0; + doneSendFailure_ = false; + senderReadTimeout_ = senderWriteTimeout_ = -1; + curConnectionVerified_ = false; + 
threadStats_.reset(); +} + +ReceiverThread::~ReceiverThread() { + delete[] buf_; +} +} +} diff --git a/ReceiverThread.h b/ReceiverThread.h new file mode 100644 index 00000000..9fbc8cc0 --- /dev/null +++ b/ReceiverThread.h @@ -0,0 +1,341 @@ +/** + * Copyright (c) 2014-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ +#pragma once +#include "WdtBase.h" +#include "WdtThread.h" +#include "Receiver.h" +#include "ServerSocket.h" +namespace facebook { +namespace wdt { +class Receiver; +/** + * Wdt receiver has logic to maintain the consistency of the + * transfers through connection errors. All threads are run by the logic + * defined as a state machine. These are the all the states in that + * state machine + */ +enum ReceiverState { + LISTEN, + ACCEPT_FIRST_CONNECTION, + ACCEPT_WITH_TIMEOUT, + SEND_LOCAL_CHECKPOINT, + READ_NEXT_CMD, + PROCESS_FILE_CMD, + PROCESS_SETTINGS_CMD, + PROCESS_DONE_CMD, + PROCESS_SIZE_CMD, + SEND_FILE_CHUNKS, + SEND_GLOBAL_CHECKPOINTS, + SEND_DONE_CMD, + SEND_ABORT_CMD, + WAIT_FOR_FINISH_OR_NEW_CHECKPOINT, + FINISH_WITH_ERROR, + FAILED, + END +}; + +/** + * This class represents a receiver thread. It contains + * all the logic for a thread to bind on a port and + * receive data from the wdt sender. 
All the receiver threads + * share modules like threads controller, throttler etc + */ +class ReceiverThread : public WdtThread { + public: + /// Identifiers for the funnels that this thread will use + enum RECEIVER_FUNNELS { SEND_FILE_CHUNKS_FUNNEL, NUM_FUNNELS }; + + /// Identifiers for the condition variable wrappers used in the thread + enum RECEIVER_CONDITIONS { WAIT_FOR_FINISH_OR_CHECKPOINT_CV, NUM_CONDITIONS }; + + /// Identifiers for the barriers used in the thread + enum RECEIVER_BARRIERS { NUM_BARRIERS }; + /** + * Constructor for receiver thread. + * @param wdtParent Pointer back to the parent receiver for meta + * information + * @param threadIndex Every thread is identified by unique index + * @param port Port this thread will listen on + * @param controller Thread controller for all the instances of the + * receiver threads. All the receiver thread objects + * need to share the same instance of the controller + */ + ReceiverThread(Receiver *wdtParent, int threadIndex, int port, + ThreadsController *controller); + + /// Initializes the receiver thread before starting + ErrorCode init() override; + + /** + * In long running mode, we need to reset thread variables after each + * session. Before starting each session, reset() has to called to do that. + */ + void reset() override; + + /// Destructor of Receiver thread + ~ReceiverThread(); + + /// Get the port this receiver thread is listening on + int32_t getPort() const override; + + private: + /// Overloaded operator for printing thread info + friend std::ostream &operator<<(std::ostream &os, + const ReceiverThread &receiverThread); + typedef ReceiverState (ReceiverThread::*StateFunction)(); + + /// Parent shared among all the threads for meta information + Receiver *wdtParent_; + + /** + * Tries to listen/bind to port. If this fails, thread is considered failed. 
+ * Previous states : n/a (start state) + * Next states : ACCEPT_FIRST_CONNECTION(success), + * FAILED(failure) + */ + ReceiverState listen(); + + /** + * Tries to accept first connection of a new session. Periodically checks + * whether a new session has started or not. If a new session has started then + * goes to ACCEPT_WITH_TIMEOUT state. Also does session initialization. In + * joinable mode, tries to accept for a limited number of user specified + * retries. + * Previous states : LISTEN, + * END(if in long running mode) + * Next states : ACCEPT_WITH_TIMEOUT(if a new transfer has started and this + * thread has not received a connection), + * FINISH_WITH_ERROR(if did not receive a + * connection in specified number of retries), + * READ_NEXT_CMD(if a connection was received) + */ + ReceiverState acceptFirstConnection(); + + /** + * Tries to accept a connection with timeout. There are 2 kinds of timeout. At + * the beginning of the session, it uses accept window as the timeout. Later + * when sender settings are known it uses max(readTimeOut, writeTimeout)) + + * buffer(500) as the timeout. + * Previous states : Almost all states(for any network errors during transfer, + * we transition to this state), + * Next states : READ_NEXT_CMD(if there are no previous errors and accept + * was successful), + * SEND_LOCAL_CHECKPOINT(if there were previous errors and + * accept was successful), + * FINISH_WITH_ERROR(if accept failed and + * transfer previously failed during SEND_DONE_CMD state. Thus + * case needs special handling to ensure that we do not mess up + * session variables), + * END(if accept fails otherwise) + */ + ReceiverState acceptWithTimeout(); + /** + * Sends local checkpoint to the sender. In case of previous error during + * SEND_LOCAL_CHECKPOINT state, we send -1 as the checkpoint. 
+ * Previous states : ACCEPT_WITH_TIMEOUT + * Next states : ACCEPT_WITH_TIMEOUT(if sending fails), + * SEND_DONE_CMD(if send is successful and we have previous + * SEND_DONE_CMD error), + * READ_NEXT_CMD(if send is successful otherwise) + */ + ReceiverState sendLocalCheckpoint(); + /** + * Reads next cmd and transitions to the state accordingly. + * Previous states : SEND_LOCAL_CHECKPOINT, + * ACCEPT_FIRST_CONNECTION, + * ACCEPT_WITH_TIMEOUT, + * PROCESS_SETTINGS_CMD, + * PROCESS_FILE_CMD, + * SEND_GLOBAL_CHECKPOINTS, + * Next states : PROCESS_FILE_CMD, + * PROCESS_DONE_CMD, + * PROCESS_SETTINGS_CMD, + * PROCESS_SIZE_CMD, + * ACCEPT_WITH_TIMEOUT(in case of read failure), + * FINISH_WITH_ERROR(in case of protocol errors) + */ + ReceiverState readNextCmd(); + /** + * Processes file cmd. Logic of how we write the file to the destination + * directory is defined here. + * Previous states : READ_NEXT_CMD + * Next states : READ_NEXT_CMD(success), + * FINISH_WITH_ERROR(protocol error), + * ACCEPT_WITH_TIMEOUT(socket read failure) + */ + ReceiverState processFileCmd(); + /** + * Processes settings cmd. Settings has a connection settings, + * protocol version, transfer id, etc. For more info check Protocol.h + * Previous states : READ_NEXT_CMD, + * Next states : READ_NEXT_CMD(success), + * FINISH_WITH_ERROR(protocol error), + * ACCEPT_WITH_TIMEOUT(socket read failure), + * SEND_FILE_CHUNKS(If the sender wants to resume transfer) + */ + ReceiverState processSettingsCmd(); + /** + * Processes done cmd. Also checks to see if there are any new global + * checkpoints or not + * Previous states : READ_NEXT_CMD, + * Next states : FINISH_WITH_ERROR(protocol error), + * WAIT_FOR_FINISH_OR_NEW_CHECKPOINT(success), + * SEND_GLOBAL_CHECKPOINTS(if there are global errors) + */ + ReceiverState processDoneCmd(); + /** + * Processes size cmd. 
Sets the value of totalSenderBytes_ + * Previous states : READ_NEXT_CMD, + * Next states : READ_NEXT_CMD(success), + * FINISH_WITH_ERROR(protocol error) + */ + ReceiverState processSizeCmd(); + /** + * Sends file chunks that were received successfully in any previous transfer, + * this is the first step in download resumption. + * Checks to see if they have already been transferred or not. + * If yes, send ACK. If some other thread is sending it, sends wait cmd + * and checks again later. Otherwise, breaks the entire data into bufferSIze_ + * chunks and sends it. + * Previous states: PROCESS_SETTINGS_CMD, + * Next states : ACCEPT_WITH_TIMEOUT(network error), + * READ_NEXT_CMD(success) + */ + ReceiverState sendFileChunks(); + /** + * Sends global checkpoints to sender + * Previous states : PROCESS_DONE_CMD, + * FINISH_WITH_ERROR + * Next states : READ_NEXT_CMD(success), + * ACCEPT_WITH_TIMEOUT(socket write failure) + */ + ReceiverState sendGlobalCheckpoint(); + /** + * Sends DONE to sender, also tries to read back ack. If anything fails during + * this state, doneSendFailure_ thread variable is set. This flag makes the + * state machine behave differently, effectively bypassing all session related + * things. + * Previous states : SEND_LOCAL_CHECKPOINT, + * FINISH_WITH_ERROR + * Next states : END(success), + * ACCEPT_WITH_TIMEOUT(failure) + */ + ReceiverState sendDoneCmd(); + + /** + * Sends ABORT cmd back to the sender + * Previous states : PROCESS_FILE_CMD + * Next states : FINISH_WITH_ERROR + */ + ReceiverState sendAbortCmd(); + + /** + * Internal implementation of waitForFinishOrNewCheckpoint + * Returns : + * SEND_GLOBAL_CHECKPOINTS if there are checkpoints + * SEND_DONE_CMD if there are no checkpoints and + * there are no active threads + * WAIT_FOR_FINISH_OR_NEW_CHECKPOINT in all other cases + */ + ReceiverState checkForFinishOrNewCheckpoints(); + + /** + * Waits for transfer to finish or new checkpoints. This state first + * increments waitingThreadCount_. 
Then, it + * waits till all the threads have finished. It sends periodic WAIT signal to + * prevent sender from timing out. If a new checkpoint is found, we move to + * SEND_GLOBAL_CHECKPOINTS state. + * Previous states : PROCESS_DONE_CMD + * Next states : SEND_DONE_CMD(all threads finished), + * SEND_GLOBAL_CHECKPOINTS(if new checkpoints are found), + * ACCEPT_WITH_TIMEOUT(if socket write fails) + */ + ReceiverState waitForFinishOrNewCheckpoint(); + + /** + * Waits for transfer to finish. Only called when there is an error for the + * thread. It adds a checkpoint to the global list of checkpoints if a + * connection was received. It increments waitingWithErrorThreadCount_ and + * waits till the session ends. + * Previous states : Almost all states + * Next states : END + */ + ReceiverState finishWithError(); + + /// Mapping from receiver states to state functions + static const StateFunction stateMap_[]; + + /// Main entry point for the thread, starts the state machine + void start() override; + + /** + * Server socket object that provides functionality such as listen() + * accept, read, write on the socket + */ + ServerSocket socket_; + + /// Buffer that receives reads data into from the network + char *buf_{nullptr}; + + /// Size of the buffer + const int64_t bufferSize_; + + /// Marks the number of bytes already read in the buffer + int64_t numRead_{0}; + + /// Following two are markers to mark how much data has been read/parsed + int64_t off_{0}; + int64_t oldOffset_{0}; + + /// Number of checkpoints already transferred + int checkpointIndex_{0}; + + /// Checkpoints saved for this thread + std::vector checkpoints_; + + /** + * Pending value of checkpoint count. since write call success does not + * gurantee actual transfer, we do not apply checkpoint count update after + * the write. 
Only after receiving next cmd from sender, we apply the + * update + */ + int pendingCheckpointIndex_{0}; + + /// read timeout for sender + int64_t senderReadTimeout_{-1}; + + /// write timeout for sender + int64_t senderWriteTimeout_{-1}; + + /// whether the transfer is in block mode or not + bool isBlockMode_{true}; + + /// Checkpoint local to the thread, updated regularly + Checkpoint checkpoint_; + + /// whether checksum verification is enabled or not + bool enableChecksum_{false}; + + /// whether settings have been received and verified for the current + /// connection. This is used to determine round robin order for polling in + /// the server socket + bool curConnectionVerified_{false}; + + /** + * Whether SEND_DONE_CMD state has already failed for this session or not. + * This has to be separately handled, because session barrier is + * implemented before sending done cmd + */ + bool doneSendFailure_{false}; + + /// Checkpoints that have not been sent back to the sender + std::vector newCheckpoints_; +}; +} +} diff --git a/Reporting.cpp b/Reporting.cpp index f39bc5a0..585bdb1a 100644 --- a/Reporting.cpp +++ b/Reporting.cpp @@ -30,6 +30,24 @@ TransferStats& TransferStats::operator+=(const TransferStats& stats) { numFiles_ += stats.numFiles_; numBlocks_ += stats.numBlocks_; failedAttempts_ += stats.failedAttempts_; + if (numBlocksSend_ == -1) { + numBlocksSend_ = stats.numBlocksSend_; + } else if (stats.numBlocksSend_ != -1 && + numBlocksSend_ != stats.numBlocksSend_) { + LOG_IF(ERROR, errCode_ == OK) << "Mismatch in the numBlocksSend " + << numBlocksSend_ << " " + << stats.numBlocksSend_; + errCode_ = ERROR; + } + if (totalSenderBytes_ == -1) { + totalSenderBytes_ = stats.totalSenderBytes_; + } else if (stats.totalSenderBytes_ != -1 && + totalSenderBytes_ != stats.totalSenderBytes_) { + LOG_IF(ERROR, errCode_ == OK) << "Mismatch in the total sender bytes " + << totalSenderBytes_ << " " + << stats.totalSenderBytes_; + errCode_ = ERROR; + } if (stats.errCode_ 
!= OK) { if (errCode_ == OK) { // First error. Setting this as the error code @@ -123,6 +141,17 @@ TransferReport::TransferReport( summary_.setNumFiles(numTransferredFiles); } +TransferReport::TransferReport(TransferStats&& globalStats, double totalTime, + int64_t totalFileSize) + : TransferReport(std::move(globalStats)) { + totalTime_ = totalTime; + totalFileSize_ = totalFileSize; +} + +TransferReport::TransferReport(TransferStats&& globalStats) { + summary_ = std::move(globalStats); +} + TransferReport::TransferReport(const std::vector& threadStats, double totalTime, int64_t totalFileSize) : totalTime_(totalTime), totalFileSize_(totalFileSize) { diff --git a/Reporting.h b/Reporting.h index dfc2abf5..39da302e 100644 --- a/Reporting.h +++ b/Reporting.h @@ -52,7 +52,7 @@ std::ostream &operator<<(std::ostream &os, const std::vector &v) { std::copy(v.begin(), v.end(), std::ostream_iterator(os, " ")); return os; } - +// TODO rename to ThreadResult /// class representing statistics related to file transfer class TransferStats { private: @@ -75,6 +75,12 @@ class TransferStats { /// number of failed transfers int64_t failedAttempts_ = 0; + /// Total number of blocks sent by sender + int64_t numBlocksSend_{-1}; + + /// Total number of bytes sent by sender + int64_t totalSenderBytes_{-1}; + /// status of the transfer ErrorCode errCode_ = OK; @@ -114,6 +120,36 @@ class TransferStats { errCode_ = remoteErrCode_ = OK; } + /// Validates the global transfer stats. Only call this method + /// on the accumulated stats. 
Can only be called for receiver + void validate() { + folly::RWSpinLock::ReadHolder lock(mutex_.get()); + if (numBlocksSend_ == -1) { + LOG(ERROR) << "Negative number of blocks sent by the sender"; + errCode_ = ERROR; + } else if (totalSenderBytes_ != -1 && + totalSenderBytes_ != effectiveDataBytes_) { + // did not receive all the bytes + LOG(ERROR) << "Number of bytes sent and received do not match " + << totalSenderBytes_ << " " << effectiveDataBytes_; + errCode_ = ERROR; + } else { + errCode_ = OK; + } + } + + /// @return the number of blocks sent by sender + int64_t getNumBlocksSend() const { + folly::RWSpinLock::ReadHolder lock(mutex_.get()); + return numBlocksSend_; + } + + /// @return the total sender bytes + int64_t getTotalSenderBytes() const { + folly::RWSpinLock::ReadHolder lock(mutex_.get()); + return totalSenderBytes_; + } + /// @return number of header bytes transferred int64_t getHeaderBytes() const { folly::RWSpinLock::ReadHolder lock(mutex_.get()); @@ -225,6 +261,18 @@ class TransferStats { headerBytes_ += count; } + /// @param set num blocks send + void setNumBlocksSend(int64_t numBlocksSend) { + folly::RWSpinLock::WriteHolder lock(mutex_.get()); + numBlocksSend_ = numBlocksSend; + } + + /// @param set total sender bytes + void setTotalSenderBytes(int64_t totalSenderBytes) { + folly::RWSpinLock::WriteHolder lock(mutex_.get()); + totalSenderBytes_ = totalSenderBytes; + } + /// one more file transfer failed void incrFailedAttempts() { folly::RWSpinLock::WriteHolder lock(mutex_.get()); @@ -311,6 +359,9 @@ class TransferReport { TransferReport(const std::vector &threadStats, double totalTime, int64_t totalFileSize); + TransferReport(TransferStats &&stats, double totalTime, + int64_t totalFileSize); + explicit TransferReport(TransferStats &&stats); /// constructor used by receiver, does move the thread stats explicit TransferReport(std::vector &threadStats); /// @return summary of the report diff --git a/Sender.cpp b/Sender.cpp index 8eb2e538..ffb356ae 
100644 --- a/Sender.cpp +++ b/Sender.cpp @@ -9,6 +9,7 @@ #include "Sender.h" #include "ClientSocket.h" +#include "SenderThread.h" #include "Throttler.h" #include "SocketUtils.h" @@ -24,209 +25,26 @@ namespace facebook { namespace wdt { - -ThreadTransferHistory::ThreadTransferHistory(DirectorySourceQueue &queue, - TransferStats &threadStats) - : queue_(queue), threadStats_(threadStats) { -} - -std::string ThreadTransferHistory::getSourceId(int64_t index) { - folly::SpinLockGuard guard(lock_); - std::string sourceId; - const int64_t historySize = history_.size(); - if (index >= 0 && index < historySize) { - sourceId = history_[index]->getIdentifier(); - } else { - LOG(WARNING) << "Trying to read out of bounds data " << index << " " - << history_.size(); - } - return sourceId; -} - -bool ThreadTransferHistory::addSource(std::unique_ptr &source) { - folly::SpinLockGuard guard(lock_); - if (globalCheckpoint_) { - // already received an error for this thread - VLOG(1) << "adding source after global checkpoint is received. 
Returning " - "the source to the queue"; - markSourceAsFailed(source, lastCheckpoint_.get()); - lastCheckpoint_.reset(); - queue_.returnToQueue(source); - return false; - } - history_.emplace_back(std::move(source)); - return true; -} - -ErrorCode ThreadTransferHistory::setCheckpointAndReturnToQueue( - const Checkpoint &checkpoint, bool globalCheckpoint) { - folly::SpinLockGuard guard(lock_); - const int64_t historySize = history_.size(); - int64_t numReceivedSources = checkpoint.numBlocks; - int64_t lastBlockReceivedBytes = checkpoint.lastBlockReceivedBytes; - if (numReceivedSources > historySize) { - LOG(ERROR) - << "checkpoint is greater than total number of sources transfered " - << history_.size() << " " << numReceivedSources; - return INVALID_CHECKPOINT; - } - ErrorCode errCode = validateCheckpoint(checkpoint, globalCheckpoint); - if (errCode == INVALID_CHECKPOINT) { - return INVALID_CHECKPOINT; - } - globalCheckpoint_ |= globalCheckpoint; - lastCheckpoint_ = folly::make_unique(checkpoint); - int64_t numFailedSources = historySize - numReceivedSources; - if (numFailedSources == 0 && lastBlockReceivedBytes > 0) { - if (!globalCheckpoint) { - // no block to apply checkpoint offset. This can happen if we receive same - // local checkpoint without adding anything to the history - LOG(WARNING) << "Local checkpoint has received bytes for last block, but " - "there are no unacked blocks in the history. Ignoring."; - } - } - numAcknowledged_ = numReceivedSources; - std::vector> sourcesToReturn; - for (int64_t i = 0; i < numFailedSources; i++) { - std::unique_ptr source = std::move(history_.back()); - history_.pop_back(); - const Checkpoint *checkpointPtr = - (i == numFailedSources - 1 ? 
&checkpoint : nullptr); - markSourceAsFailed(source, checkpointPtr); - sourcesToReturn.emplace_back(std::move(source)); - } - queue_.returnToQueue(sourcesToReturn); - LOG(INFO) << numFailedSources - << " number of sources returned to queue, checkpoint: " - << checkpoint; - return errCode; -} - -std::vector ThreadTransferHistory::popAckedSourceStats() { - const int64_t historySize = history_.size(); - WDT_CHECK(numAcknowledged_ == historySize); - // no locking needed, as this should be called after transfer has finished - std::vector sourceStats; - while (!history_.empty()) { - sourceStats.emplace_back(std::move(history_.back()->getTransferStats())); - history_.pop_back(); - } - return sourceStats; -} - -void ThreadTransferHistory::markAllAcknowledged() { - folly::SpinLockGuard guard(lock_); - numAcknowledged_ = history_.size(); -} - -void ThreadTransferHistory::returnUnackedSourcesToQueue() { - Checkpoint checkpoint; - checkpoint.numBlocks = numAcknowledged_; - setCheckpointAndReturnToQueue(checkpoint, false); -} - -ErrorCode ThreadTransferHistory::validateCheckpoint( - const Checkpoint &checkpoint, bool globalCheckpoint) { - if (lastCheckpoint_ == nullptr) { - return OK; - } - if (checkpoint.numBlocks < lastCheckpoint_->numBlocks) { - LOG(ERROR) << "Current checkpoint must be higher than previous checkpoint, " - "Last checkpoint: " << *lastCheckpoint_ - << ", Current checkpoint: " << checkpoint; - return INVALID_CHECKPOINT; - } - if (checkpoint.numBlocks > lastCheckpoint_->numBlocks) { - return OK; - } - bool noProgress = false; - // numBlocks same - if (checkpoint.lastBlockSeqId == lastCheckpoint_->lastBlockSeqId && - checkpoint.lastBlockOffset == lastCheckpoint_->lastBlockOffset) { - // same block - if (checkpoint.lastBlockReceivedBytes != - lastCheckpoint_->lastBlockReceivedBytes) { - LOG(ERROR) << "Current checkpoint has different received bytes, but all " - "other fields are same, Last checkpoint " - << *lastCheckpoint_ << ", Current checkpoint: " << 
checkpoint; - return INVALID_CHECKPOINT; - } - noProgress = true; - } else { - // different block - WDT_CHECK(checkpoint.lastBlockReceivedBytes >= 0); - if (checkpoint.lastBlockReceivedBytes == 0) { - noProgress = true; - } - } - if (noProgress && !globalCheckpoint) { - // we can get same global checkpoint multiple times, so no need to check for - // progress - LOG(WARNING) << "No progress since last checkpoint, Last checkpoint: " - << *lastCheckpoint_ << ", Current checkpoint: " << checkpoint; - return NO_PROGRESS; +void Sender::endCurTransfer() { + endTime_ = Clock::now(); + LOG(INFO) << "Last thread finished " << durationSeconds(endTime_ - startTime_) + << " for transfer id " << transferId_; + transferFinished_ = true; + if (throttler_) { + throttler_->deRegisterTransfer(); } - return OK; } -void ThreadTransferHistory::markSourceAsFailed( - std::unique_ptr &source, const Checkpoint *checkpoint) { - auto metadata = source->getMetaData(); - bool validCheckpoint = false; - if (checkpoint != nullptr) { - if (checkpoint->hasSeqId) { - if ((checkpoint->lastBlockSeqId == metadata.seqId) && - (checkpoint->lastBlockOffset == source->getOffset())) { - validCheckpoint = true; - } else { - LOG(WARNING) - << "Checkpoint block does not match history block. Checkpoint: " - << checkpoint->lastBlockSeqId << ", " << checkpoint->lastBlockOffset - << " History: " << metadata.seqId << ", " << source->getOffset(); - } - } else { - // Receiver at lower version! - // checkpoint does not have seq-id. We have to blindly trust - // lastBlockReceivedBytes. If we do not, transfer will fail because of - // number of bytes mismatch. Even if an error happens because of this, - // Receiver will fail. - validCheckpoint = true; - } - } - int64_t receivedBytes = - (validCheckpoint ? 
checkpoint->lastBlockReceivedBytes : 0); - TransferStats &sourceStats = source->getTransferStats(); - if (sourceStats.getErrorCode() != OK) { - // already marked as failed - sourceStats.addEffectiveBytes(0, receivedBytes); - threadStats_.addEffectiveBytes(0, receivedBytes); - } else { - auto dataBytes = source->getSize(); - auto headerBytes = sourceStats.getEffectiveHeaderBytes(); - int64_t wastedBytes = dataBytes - receivedBytes; - sourceStats.subtractEffectiveBytes(headerBytes, wastedBytes); - sourceStats.decrNumBlocks(); - sourceStats.setErrorCode(SOCKET_WRITE_ERROR); - sourceStats.incrFailedAttempts(); - - threadStats_.subtractEffectiveBytes(headerBytes, wastedBytes); - threadStats_.decrNumBlocks(); - threadStats_.incrFailedAttempts(); +void Sender::startNewTransfer() { + if (throttler_) { + throttler_->registerTransfer(); } - source->advanceOffset(receivedBytes); + LOG(INFO) << "Starting a new transfer " << transferId_ << " to " << destHost_; } -const Sender::StateFunction Sender::stateMap_[] = { - &Sender::connect, &Sender::readLocalCheckPoint, &Sender::sendSettings, - &Sender::sendBlocks, &Sender::sendDoneCmd, &Sender::sendSizeCmd, - &Sender::checkForAbort, &Sender::readFileChunks, &Sender::readReceiverCmd, - &Sender::processDoneCmd, &Sender::processWaitCmd, &Sender::processErrCmd, - &Sender::processAbortCmd, &Sender::processVersionMismatch}; - Sender::Sender(const std::string &destHost, const std::string &srcDir) - : queueAbortChecker_(this) { + : queueAbortChecker_(this), destHost_(destHost) { LOG(INFO) << "WDT Sender " << Protocol::getFullVersion(); - destHost_ = destHost; srcDir_ = srcDir; transferFinished_ = true; const auto &options = WdtOptions::get(); @@ -262,6 +80,8 @@ Sender::Sender(const std::string &destHost, const std::string &srcDir, : Sender(destHost, srcDir) { ports_ = ports; dirQueue_->setFileInfo(srcFileInfo); + transferHistoryController_ = + folly::make_unique(*dirQueue_); } WdtTransferRequest Sender::init() { @@ -317,6 +137,43 @@ const 
std::string &Sender::getSrcDir() const { return srcDir_; } +ProtoNegotiationStatus Sender::getNegotiationStatus() { + return protoNegotiationStatus_; +} + +std::vector Sender::getNegotiatedProtocols() const { + std::vector ret; + for (const auto &senderThread : senderThreads_) { + ret.push_back(senderThread->getNegotiatedProtocol()); + } + return ret; +} + +void Sender::setProtoNegotiationStatus(ProtoNegotiationStatus status) { + protoNegotiationStatus_ = status; +} + +bool Sender::isSendFileChunks() const { + return (downloadResumptionEnabled_ && + protocolVersion_ >= Protocol::DOWNLOAD_RESUMPTION_VERSION); +} + +bool Sender::isFileChunksReceived() const { + std::lock_guard lock(mutex_); + return fileChunksReceived_; +} + +void Sender::setFileChunksInfo( + std::vector &fileChunksInfoList) { + std::lock_guard lock(mutex_); + if (fileChunksReceived_) { + LOG(WARNING) << "File chunks list received multiple times"; + return; + } + dirQueue_->setPreviouslyReceivedChunks(fileChunksInfoList); + fileChunksReceived_ = true; +} + const std::string &Sender::getDestination() const { return destHost_; } @@ -324,8 +181,9 @@ const std::string &Sender::getDestination() const { std::unique_ptr Sender::getTransferReport() { int64_t totalFileSize = dirQueue_->getTotalSize(); double totalTime = durationSeconds(Clock::now() - startTime_); + auto globalStats = getGlobalTransferStats(); std::unique_ptr transferReport = - folly::make_unique(globalThreadStats_, totalTime, + folly::make_unique(std::move(globalStats), totalTime, totalFileSize); return transferReport; } @@ -339,6 +197,26 @@ Clock::time_point Sender::getEndTime() { return endTime_; } +TransferStats Sender::getGlobalTransferStats() const { + TransferStats globalStats; + for (const auto &thread : senderThreads_) { + globalStats += thread->getTransferStats(); + } + return globalStats; +} + +ErrorCode Sender::verifyVersionMismatchStats() const { + for (const auto &senderThread : senderThreads_) { + const auto &threadStats = 
senderThread->getTransferStats(); + if (threadStats.getErrorCode() == OK) { + LOG(ERROR) << "Found a thread that completed transfer successfully " + << "despite version mismatch. " << senderThread->getPort(); + return ERROR; + } + } + return OK; +} + std::unique_ptr Sender::finish() { std::unique_lock instanceLock(instanceManagementMutex_); VLOG(1) << "Sender::finish()"; @@ -351,9 +229,8 @@ std::unique_ptr Sender::finish() { const bool twoPhases = options.two_phases; bool progressReportEnabled = progressReporter_ && progressReportIntervalMillis_ > 0; - const int64_t numPorts = ports_.size(); - for (int64_t i = 0; i < numPorts; i++) { - senderThreads_[i].join(); + for (auto &senderThread : senderThreads_) { + senderThread->finish(); } if (!twoPhases) { dirThread_.join(); @@ -366,9 +243,13 @@ std::unique_ptr Sender::finish() { } progressReporterThread_.join(); } - + std::vector threadStats; + for (auto &senderThread : senderThreads_) { + threadStats.push_back(std::move(senderThread->moveStats())); + } bool allSourcesAcked = false; - for (auto &stats : globalThreadStats_) { + for (auto &senderThread : senderThreads_) { + auto &stats = senderThread->getTransferStats(); if (stats.getErrorCode() == OK) { // at least one thread finished correctly // that means all transferred sources are acked @@ -378,7 +259,9 @@ std::unique_ptr Sender::finish() { } std::vector transferredSourceStats; - for (auto &transferHistory : transferHistories_) { + for (auto port : ports_) { + auto &transferHistory = + transferHistoryController_->getTransferHistory(port); if (allSourcesAcked) { transferHistory.markAllAcknowledged(); } else { @@ -391,18 +274,16 @@ std::unique_ptr Sender::finish() { std::make_move_iterator(stats.end())); } } - if (WdtOptions::get().full_reporting) { validateTransferStats(transferredSourceStats, - dirQueue_->getFailedSourceStats(), - globalThreadStats_); + dirQueue_->getFailedSourceStats()); } int64_t totalFileSize = dirQueue_->getTotalSize(); double totalTime = 
durationSeconds(endTime_ - startTime_); std::unique_ptr transferReport = folly::make_unique( transferredSourceStats, dirQueue_->getFailedSourceStats(), - globalThreadStats_, dirQueue_->getFailedDirectories(), totalTime, + threadStats, dirQueue_->getFailedDirectories(), totalTime, totalFileSize, dirQueue_->getCount()); if (progressReportEnabled) { @@ -410,8 +291,8 @@ std::unique_ptr Sender::finish() { } if (options.enable_perf_stat_collection) { PerfStatReport report; - for (auto &perfReport : perfReports_) { - report += perfReport; + for (auto &senderThread : senderThreads_) { + report += senderThread->getPerfReport(); } LOG(INFO) << report; } @@ -461,23 +342,14 @@ ErrorCode Sender::start() { } else { configureThrottler(); } - - // WARNING: Do not MERGE the following two loops. ThreadTransferHistory keeps - // a reference of TransferStats. And, any emplace operation on a vector - // invalidates all its references - const int64_t numPorts = ports_.size(); - for (int64_t i = 0; i < numPorts; i++) { - globalThreadStats_.emplace_back(true); - } - for (int64_t i = 0; i < numPorts; i++) { - transferHistories_.emplace_back(*dirQueue_, globalThreadStats_[i]); - } - perfReports_.resize(numPorts); - negotiatedProtocolVersions_.resize(numPorts, 0); - numActiveThreads_ = numPorts; - for (int64_t i = 0; i < numPorts; i++) { - globalThreadStats_[i].setId(folly::to(i)); - senderThreads_.emplace_back(&Sender::sendOne, this, i); + threadsController_ = new ThreadsController(ports_.size()); + threadsController_->setNumBarriers(SenderThread::NUM_BARRIERS); + threadsController_->setNumFunnels(SenderThread::NUM_FUNNELS); + threadsController_->setNumConditions(SenderThread::NUM_CONDITIONS); + senderThreads_ = threadsController_->makeThreads( + this, ports_.size(), ports_); + for (auto &senderThread : senderThreads_) { + senderThread->startThread(); } if (progressReportEnabled) { progressReporter_->start(); @@ -489,8 +361,7 @@ ErrorCode Sender::start() { void 
Sender::validateTransferStats( const std::vector &transferredSourceStats, - const std::vector &failedSourceStats, - const std::vector &threadStats) { + const std::vector &failedSourceStats) { int64_t sourceFailedAttempts = 0; int64_t sourceDataBytes = 0; int64_t sourceEffectiveDataBytes = 0; @@ -513,7 +384,8 @@ void Sender::validateTransferStats( sourceEffectiveDataBytes += stat.getEffectiveDataBytes(); sourceNumBlocks += stat.getNumBlocks(); } - for (const auto &stat : threadStats) { + for (const auto &senderThread : senderThreads_) { + const auto &stat = senderThread->getTransferStats(); threadFailedAttempts += stat.getFailedAttempts(); threadDataBytes += stat.getDataBytes(); threadEffectiveDataBytes += stat.getEffectiveDataBytes(); @@ -530,19 +402,18 @@ void Sender::setSocketCreator(const SocketCreator socketCreator) { socketCreator_ = socketCreator; } -std::unique_ptr Sender::connectToReceiver(const int port, - ErrorCode &errCode) { +std::unique_ptr Sender::connectToReceiver( + const int port, IAbortChecker const *abortChecker, ErrorCode &errCode) { auto startTime = Clock::now(); const auto &options = WdtOptions::get(); int connectAttempts = 0; std::unique_ptr socket; + std::string portStr = folly::to(port); if (!socketCreator_) { // socket creator not set, creating ClientSocket - socket = folly::make_unique( - destHost_, folly::to(port), &abortCheckerCallback_); + socket = folly::make_unique(destHost_, portStr, abortChecker); } else { - socket = socketCreator_(destHost_, folly::to(port), - &abortCheckerCallback_); + socket = socketCreator_(destHost_, portStr, abortChecker); } double retryInterval = options.sleep_millis; int maxRetries = options.max_retries; @@ -582,678 +453,6 @@ std::unique_ptr Sender::connectToReceiver(const int port, return socket; } -Sender::SenderState Sender::connect(ThreadData &data) { - VLOG(1) << "entered CONNECT state " << data.threadIndex_; - int port = ports_[data.threadIndex_]; - TransferStats &threadStats = data.threadStats_; - 
auto &socket = data.socket_; - auto &numReconnectWithoutProgress = data.numReconnectWithoutProgress_; - auto &options = WdtOptions::get(); - - if (socket) { - socket->close(); - } - if (numReconnectWithoutProgress >= options.max_transfer_retries) { - LOG(ERROR) << "Sender thread reconnected " << numReconnectWithoutProgress - << " times without making any progress, giving up. port: " - << socket->getPort(); - threadStats.setErrorCode(NO_PROGRESS); - return END; - } - - ErrorCode code; - socket = connectToReceiver(port, code); - if (code == ABORT) { - threadStats.setErrorCode(ABORT); - if (getCurAbortCode() == VERSION_MISMATCH) { - return PROCESS_VERSION_MISMATCH; - } - return END; - } - if (code != OK) { - threadStats.setErrorCode(code); - return END; - } - // clearing the totalSizeSent_ flag. This way if anything breaks, we resend - // the total size. - data.totalSizeSent_ = false; - auto nextState = - threadStats.getErrorCode() == OK ? SEND_SETTINGS : READ_LOCAL_CHECKPOINT; - // clear the error code, as this is a new transfer - threadStats.setErrorCode(OK); - return nextState; -} - -Sender::SenderState Sender::readLocalCheckPoint(ThreadData &data) { - LOG(INFO) << "entered READ_LOCAL_CHECKPOINT state " << data.threadIndex_; - int port = ports_[data.threadIndex_]; - TransferStats &threadStats = data.threadStats_; - ThreadTransferHistory &transferHistory = data.getTransferHistory(); - auto &numReconnectWithoutProgress = data.numReconnectWithoutProgress_; - - std::vector checkpoints; - int64_t decodeOffset = 0; - char *buf = data.buf_; - int checkpointLen = Protocol::getMaxLocalCheckpointLength(protocolVersion_); - int64_t numRead = data.socket_->read(buf, checkpointLen); - if (numRead != checkpointLen) { - LOG(ERROR) << "read mismatch during reading local checkpoint " - << checkpointLen << " " << numRead << " port " << port; - threadStats.setErrorCode(SOCKET_READ_ERROR); - numReconnectWithoutProgress++; - return CONNECT; - } - if 
(!Protocol::decodeCheckpoints(protocolVersion_, buf, decodeOffset, - checkpointLen, checkpoints)) { - LOG(ERROR) << "checkpoint decode failure " - << folly::humanify(std::string(buf, checkpointLen)); - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; - } - if (checkpoints.size() != 1 || checkpoints[0].port != port) { - LOG(ERROR) << "illegal local checkpoint " - << folly::humanify(std::string(buf, checkpointLen)); - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; - } - const Checkpoint &checkpoint = checkpoints[0]; - auto numBlocks = checkpoint.numBlocks; - VLOG(1) << "received local checkpoint, port " << port << " num-blocks " - << numBlocks << " seq-id " << checkpoint.lastBlockSeqId << " offset " - << checkpoint.lastBlockOffset << " received-bytes " - << checkpoint.lastBlockReceivedBytes; - - if (numBlocks == -1) { - // Receiver failed while sending DONE cmd - return READ_RECEIVER_CMD; - } - - ErrorCode errCode = - transferHistory.setCheckpointAndReturnToQueue(checkpoint, false); - if (errCode == INVALID_CHECKPOINT) { - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; - } - if (errCode == NO_PROGRESS) { - numReconnectWithoutProgress++; - } else { - numReconnectWithoutProgress = 0; - } - return SEND_SETTINGS; -} - -Sender::SenderState Sender::sendSettings(ThreadData &data) { - VLOG(1) << "entered SEND_SETTINGS state " << data.threadIndex_; - - TransferStats &threadStats = data.threadStats_; - char *buf = data.buf_; - auto &socket = data.socket_; - auto &options = WdtOptions::get(); - int64_t readTimeoutMillis = options.read_timeout_millis; - int64_t writeTimeoutMillis = options.write_timeout_millis; - int64_t off = 0; - buf[off++] = Protocol::SETTINGS_CMD; - bool sendFileChunks; - { - std::lock_guard lock(mutex_); - sendFileChunks = - (downloadResumptionEnabled_ && - protocolVersion_ >= Protocol::DOWNLOAD_RESUMPTION_VERSION); - } - Settings settings; - settings.readTimeoutMillis = readTimeoutMillis; - settings.writeTimeoutMillis = 
writeTimeoutMillis; - settings.transferId = transferId_; - settings.enableChecksum = options.enable_checksum; - settings.sendFileChunks = sendFileChunks; - settings.blockModeDisabled = (options.block_size_mbytes <= 0); - Protocol::encodeSettings(protocolVersion_, buf, off, Protocol::kMaxSettings, - settings); - int64_t toWrite = sendFileChunks ? Protocol::kMinBufLength : off; - int64_t written = socket->write(buf, toWrite); - if (written != toWrite) { - LOG(ERROR) << "Socket write failure " << written << " " << toWrite; - threadStats.setErrorCode(SOCKET_WRITE_ERROR); - return CONNECT; - } - threadStats.addHeaderBytes(toWrite); - return sendFileChunks ? READ_FILE_CHUNKS : SEND_BLOCKS; -} - -Sender::SenderState Sender::sendBlocks(ThreadData &data) { - VLOG(1) << "entered SEND_BLOCKS state " << data.threadIndex_; - TransferStats &threadStats = data.threadStats_; - ThreadTransferHistory &transferHistory = data.getTransferHistory(); - auto &totalSizeSent = data.totalSizeSent_; - - if (protocolVersion_ >= Protocol::RECEIVER_PROGRESS_REPORT_VERSION && - !totalSizeSent && dirQueue_->fileDiscoveryFinished()) { - return SEND_SIZE_CMD; - } - - ErrorCode transferStatus; - std::unique_ptr source = dirQueue_->getNextSource(transferStatus); - if (!source) { - return SEND_DONE_CMD; - } - WDT_CHECK(!source->hasError()); - TransferStats transferStats = - sendOneByteSource(data.socket_, source, transferStatus); - threadStats += transferStats; - source->addTransferStats(transferStats); - source->close(); - if (!transferHistory.addSource(source)) { - // global checkpoint received for this thread. 
no point in - // continuing - LOG(ERROR) << "global checkpoint received, no point in continuing"; - threadStats.setErrorCode(CONN_ERROR); - return END; - } - - if (transferStats.getErrorCode() != OK) { - return CHECK_FOR_ABORT; - } - return SEND_BLOCKS; -} - -Sender::SenderState Sender::sendSizeCmd(ThreadData &data) { - VLOG(1) << "entered SEND_SIZE_CMD state " << data.threadIndex_; - TransferStats &threadStats = data.threadStats_; - char *buf = data.buf_; - auto &socket = data.socket_; - auto &totalSizeSent = data.totalSizeSent_; - int64_t off = 0; - buf[off++] = Protocol::SIZE_CMD; - - Protocol::encodeSize(buf, off, Protocol::kMaxSize, dirQueue_->getTotalSize()); - int64_t written = socket->write(buf, off); - if (written != off) { - LOG(ERROR) << "Socket write error " << off << " " << written; - threadStats.setErrorCode(SOCKET_WRITE_ERROR); - return CHECK_FOR_ABORT; - } - threadStats.addHeaderBytes(off); - totalSizeSent = true; - return SEND_BLOCKS; -} - -Sender::SenderState Sender::sendDoneCmd(ThreadData &data) { - VLOG(1) << "entered SEND_DONE_CMD state " << data.threadIndex_; - TransferStats &threadStats = data.threadStats_; - char *buf = data.buf_; - auto &socket = data.socket_; - int64_t off = 0; - buf[off++] = Protocol::DONE_CMD; - - auto pair = dirQueue_->getNumBlocksAndStatus(); - int64_t numBlocksDiscovered = pair.first; - ErrorCode transferStatus = pair.second; - buf[off++] = transferStatus; - - Protocol::encodeDone(protocolVersion_, buf, off, Protocol::kMaxDone, - numBlocksDiscovered, dirQueue_->getTotalSize()); - - int toWrite = Protocol::kMinBufLength; - int64_t written = socket->write(buf, toWrite); - if (written != toWrite) { - LOG(ERROR) << "Socket write failure " << written << " " << toWrite; - threadStats.setErrorCode(SOCKET_WRITE_ERROR); - return CHECK_FOR_ABORT; - } - threadStats.addHeaderBytes(toWrite); - VLOG(1) << "Wrote done cmd on " << socket->getFd() << " waiting for reply..."; - return READ_RECEIVER_CMD; -} - -Sender::SenderState 
Sender::checkForAbort(ThreadData &data) { - LOG(INFO) << "entered CHECK_FOR_ABORT state " << data.threadIndex_; - char *buf = data.buf_; - auto &threadStats = data.threadStats_; - auto &socket = data.socket_; - - auto numRead = socket->read(buf, 1); - if (numRead != 1) { - VLOG(1) << "No abort cmd found"; - return CONNECT; - } - Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf[0]; - if (cmd != Protocol::ABORT_CMD) { - VLOG(1) << "Unexpected result found while reading for abort " << buf[0]; - return CONNECT; - } - threadStats.addHeaderBytes(1); - return PROCESS_ABORT_CMD; -} - -Sender::SenderState Sender::readFileChunks(ThreadData &data) { - LOG(INFO) << "entered READ_FILE_CHUNKS state " << data.threadIndex_; - char *buf = data.buf_; - auto &socket = data.socket_; - auto &threadStats = data.threadStats_; - int64_t numRead = socket->read(buf, 1); - if (numRead != 1) { - LOG(ERROR) << "Socket read error 1 " << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return CHECK_FOR_ABORT; - } - threadStats.addHeaderBytes(numRead); - Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf[0]; - if (cmd == Protocol::ABORT_CMD) { - return PROCESS_ABORT_CMD; - } - if (cmd == Protocol::WAIT_CMD) { - return READ_FILE_CHUNKS; - } - if (cmd == Protocol::ACK_CMD) { - { - std::lock_guard lock(mutex_); - if (!fileChunksReceived_) { - LOG(ERROR) << "Sender has not yet received file chunks, but receiver " - "thinks it has already sent it"; - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; - } - } - return SEND_BLOCKS; - } - if (cmd != Protocol::CHUNKS_CMD) { - LOG(ERROR) << "Unexpected cmd " << cmd; - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; - } - int64_t toRead = Protocol::kChunksCmdLen; - numRead = socket->read(buf, toRead); - if (numRead != toRead) { - LOG(ERROR) << "Socket read error " << toRead << " " << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return CHECK_FOR_ABORT; - } - threadStats.addHeaderBytes(numRead); - int64_t off = 0; - 
int64_t bufSize, numFiles; - Protocol::decodeChunksCmd(buf, off, bufSize, numFiles); - LOG(INFO) << "File chunk list has " << numFiles - << " entries and is broken in buffers of length " << bufSize; - std::unique_ptr chunkBuffer(new char[bufSize]); - std::vector fileChunksInfoList; - while (true) { - int64_t numFileChunks = fileChunksInfoList.size(); - if (numFileChunks > numFiles) { - // We should never be able to read more file chunks than mentioned in the - // chunks cmd. Chunks cmd has buffer size used to transfer chunks and also - // number of chunks. This chunks are read and parsed and added to - // fileChunksInfoList. Number of chunks we decode should match with the - // number mentioned in the Chunks cmd. - LOG(ERROR) << "Number of file chunks received is more than the number " - "mentioned in CHUNKS_CMD " << numFileChunks << " " - << numFiles; - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; - } - if (numFileChunks == numFiles) { - break; - } - toRead = sizeof(int32_t); - numRead = socket->read(buf, toRead); - if (numRead != toRead) { - LOG(ERROR) << "Socket read error " << toRead << " " << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return CHECK_FOR_ABORT; - } - toRead = folly::loadUnaligned(buf); - toRead = folly::Endian::little(toRead); - numRead = socket->read(chunkBuffer.get(), toRead); - if (numRead != toRead) { - LOG(ERROR) << "Socket read error " << toRead << " " << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return CHECK_FOR_ABORT; - } - threadStats.addHeaderBytes(numRead); - off = 0; - // decode function below adds decoded file chunks to fileChunksInfoList - bool success = Protocol::decodeFileChunksInfoList( - chunkBuffer.get(), off, toRead, fileChunksInfoList); - if (!success) { - LOG(ERROR) << "Unable to decode file chunks list"; - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; - } - } - { - std::lock_guard lock(mutex_); - if (fileChunksReceived_) { - LOG(WARNING) << "File chunks list received 
multiple times"; - } else { - dirQueue_->setPreviouslyReceivedChunks(fileChunksInfoList); - fileChunksReceived_ = true; - } - } - // send ack for file chunks list - buf[0] = Protocol::ACK_CMD; - int64_t toWrite = 1; - int64_t written = socket->write(buf, toWrite); - if (toWrite != written) { - LOG(ERROR) << "Socket write error " << toWrite << " " << written; - threadStats.setErrorCode(SOCKET_WRITE_ERROR); - return CHECK_FOR_ABORT; - } - threadStats.addHeaderBytes(written); - return SEND_BLOCKS; -} - -Sender::SenderState Sender::readReceiverCmd(ThreadData &data) { - VLOG(1) << "entered READ_RECEIVER_CMD state " << data.threadIndex_; - int port = ports_[data.threadIndex_]; - TransferStats &threadStats = data.threadStats_; - char *buf = data.buf_; - int64_t numRead = data.socket_->read(buf, 1); - if (numRead != 1) { - LOG(ERROR) << "READ unexpected " << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return CONNECT; - } - Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf[0]; - if (cmd == Protocol::ERR_CMD) { - return PROCESS_ERR_CMD; - } - if (cmd == Protocol::WAIT_CMD) { - return PROCESS_WAIT_CMD; - } - if (cmd == Protocol::DONE_CMD) { - return PROCESS_DONE_CMD; - } - if (cmd == Protocol::ABORT_CMD) { - return PROCESS_ABORT_CMD; - } - if (cmd == Protocol::LOCAL_CHECKPOINT_CMD) { - int checkpointLen = Protocol::getMaxLocalCheckpointLength(protocolVersion_); - int64_t toRead = checkpointLen - 1; - numRead = data.socket_->read(buf + 1, toRead); - if (numRead != toRead) { - LOG(ERROR) << "Could not read possible local checkpoint " << toRead << " " - << numRead << " " << port; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return CONNECT; - } - int64_t offset = 0; - std::vector checkpoints; - if (Protocol::decodeCheckpoints(protocolVersion_, buf, offset, - checkpointLen, checkpoints)) { - if (checkpoints.size() == 1 && checkpoints[0].port == port && - checkpoints[0].numBlocks == 0 && - checkpoints[0].lastBlockReceivedBytes == 0) { - // In a spurious local 
checkpoint, number of blocks and offset must both - // be zero - // Ignore the checkpoint - LOG(WARNING) - << "Received valid but unexpected local checkpoint, ignoring " - << port; - return READ_RECEIVER_CMD; - } - } - LOG(ERROR) << "Failed to verify spurious local checkpoint, port " << port; - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; - } - LOG(ERROR) << "Read unexpected receiver cmd " << cmd << " port " << port; - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; -} - -Sender::SenderState Sender::processDoneCmd(ThreadData &data) { - VLOG(1) << "entered PROCESS_DONE_CMD state " << data.threadIndex_; - int port = ports_[data.threadIndex_]; - char *buf = data.buf_; - auto &socket = data.socket_; - ThreadTransferHistory &transferHistory = data.getTransferHistory(); - transferHistory.markAllAcknowledged(); - - // send ack for DONE - buf[0] = Protocol::DONE_CMD; - socket->write(buf, 1); - - socket->shutdown(); - auto numRead = socket->read(buf, Protocol::kMinBufLength); - if (numRead != 0) { - LOG(WARNING) << "EOF not found when expected"; - return END; - } - VLOG(1) << "done with transfer, port " << port; - return END; -} - -Sender::SenderState Sender::processWaitCmd(ThreadData &data) { - LOG(INFO) << "entered PROCESS_WAIT_CMD state " << data.threadIndex_; - int port = ports_[data.threadIndex_]; - ThreadTransferHistory &transferHistory = data.getTransferHistory(); - VLOG(1) << "received WAIT_CMD, port " << port; - transferHistory.markAllAcknowledged(); - return READ_RECEIVER_CMD; -} - -Sender::SenderState Sender::processErrCmd(ThreadData &data) { - int port = ports_[data.threadIndex_]; - LOG(INFO) << "entered PROCESS_ERR_CMD state " << data.threadIndex_ << " port " - << port; - ThreadTransferHistory &transferHistory = data.getTransferHistory(); - TransferStats &threadStats = data.threadStats_; - auto &transferHistories = data.transferHistories_; - auto &socket = data.socket_; - char *buf = data.buf_; - - int64_t toRead = sizeof(int16_t); - int64_t 
numRead = socket->read(buf, toRead); - if (numRead != toRead) { - LOG(ERROR) << "read unexpected " << toRead << " " << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return CONNECT; - } - - int16_t checkpointsLen = folly::loadUnaligned(buf); - checkpointsLen = folly::Endian::little(checkpointsLen); - char checkpointBuf[checkpointsLen]; - numRead = socket->read(checkpointBuf, checkpointsLen); - if (numRead != checkpointsLen) { - LOG(ERROR) << "read unexpected " << checkpointsLen << " " << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return CONNECT; - } - - std::vector checkpoints; - int64_t decodeOffset = 0; - if (!Protocol::decodeCheckpoints(protocolVersion_, checkpointBuf, - decodeOffset, checkpointsLen, checkpoints)) { - LOG(ERROR) << "checkpoint decode failure " - << folly::humanify(std::string(checkpointBuf, checkpointsLen)); - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; - } - transferHistory.markAllAcknowledged(); - for (auto &checkpoint : checkpoints) { - auto errPort = checkpoint.port; - auto errPoint = checkpoint.numBlocks; - auto it = std::find(ports_.begin(), ports_.end(), errPort); - if (it == ports_.end()) { - LOG(ERROR) << "Invalid checkpoint " << errPoint - << ". 
No sender thread running on port " << errPort; - continue; - } - auto errThread = it - ports_.begin(); - VLOG(1) << "received global checkpoint " << errThread << " -> " << errPoint - << ", " << checkpoint.lastBlockSeqId << ", " - << checkpoint.lastBlockOffset << ", " - << checkpoint.lastBlockReceivedBytes; - transferHistories[errThread].setCheckpointAndReturnToQueue(checkpoint, - true); - } - return SEND_BLOCKS; -} - -Sender::SenderState Sender::processAbortCmd(ThreadData &data) { - LOG(INFO) << "entered PROCESS_ABORT_CMD state " << data.threadIndex_; - char *buf = data.buf_; - auto &threadStats = data.threadStats_; - auto &socket = data.socket_; - ThreadTransferHistory &transferHistory = data.getTransferHistory(); - - threadStats.setErrorCode(ABORT); - int toRead = Protocol::kAbortLength; - auto numRead = socket->read(buf, toRead); - if (numRead != toRead) { - // can not read checkpoint, but still must exit because of ABORT - LOG(ERROR) << "Error while trying to read ABORT cmd " << numRead << " " - << toRead; - return END; - } - int64_t offset = 0; - int32_t negotiatedProtocol; - ErrorCode remoteError; - int64_t checkpoint; - Protocol::decodeAbort(buf, offset, negotiatedProtocol, remoteError, - checkpoint); - threadStats.setRemoteErrorCode(remoteError); - std::string failedFileName = transferHistory.getSourceId(checkpoint); - LOG(WARNING) << "Received abort on " << data.threadIndex_ - << " remote protocol version " << negotiatedProtocol - << " remote error code " << errorCodeToStr(remoteError) - << " file " << failedFileName << " checkpoint " << checkpoint; - abort(remoteError); - if (remoteError == VERSION_MISMATCH) { - if (Protocol::negotiateProtocol(negotiatedProtocol, protocolVersion_) == - negotiatedProtocol) { - // sender can support this negotiated version - negotiatedProtocolVersions_[data.threadIndex_] = negotiatedProtocol; - return PROCESS_VERSION_MISMATCH; - } else { - LOG(ERROR) << "Sender can not support receiver version " - << negotiatedProtocol; - 
threadStats.setRemoteErrorCode(VERSION_INCOMPATIBLE); - } - } - return END; -} - -Sender::SenderState Sender::processVersionMismatch(ThreadData &data) { - LOG(INFO) << "entered PROCESS_VERSION_MISMATCH state " << data.threadIndex_; - auto &threadStats = data.threadStats_; - - WDT_CHECK(threadStats.getErrorCode() == ABORT); - std::unique_lock lock(mutex_); - if (protoNegotiationStatus_ != V_MISMATCH_WAIT) { - LOG(WARNING) << "Protocol version already negotiated, but transfer still " - "aborted due to version mismatch, port " - << ports_[data.threadIndex_]; - return END; - } - numWaitingWithAbort_++; - while (protoNegotiationStatus_ == V_MISMATCH_WAIT && - numWaitingWithAbort_ != numActiveThreads_) { - WDT_CHECK(numWaitingWithAbort_ < numActiveThreads_); - conditionAllAborted_.wait(lock); - } - numWaitingWithAbort_--; - if (protoNegotiationStatus_ == V_MISMATCH_FAILED) { - return END; - } - if (protoNegotiationStatus_ == V_MISMATCH_RESOLVED) { - threadStats.setRemoteErrorCode(OK); - return CONNECT; - } - protoNegotiationStatus_ = V_MISMATCH_FAILED; - - for (const auto &stat : globalThreadStats_) { - if (stat.getErrorCode() == OK) { - // Some other thread finished successfully, should never happen in case of - // version mismatch - return END; - } - } - - const int numHistories = transferHistories_.size(); - for (int i = 0; i < numHistories; i++) { - auto &history = transferHistories_[i]; - if (history.getNumAcked() > 0) { - LOG(ERROR) << "Even though the transfer aborted due to VERSION_MISMATCH, " - "some blocks got acked by the receiver, port " << ports_[i] - << " numAcked " << history.getNumAcked(); - return END; - } - history.returnUnackedSourcesToQueue(); - } - - int negotiatedProtocol = 0; - for (int protocolVersion : negotiatedProtocolVersions_) { - if (protocolVersion > 0) { - if (negotiatedProtocol > 0 && negotiatedProtocol != protocolVersion) { - LOG(ERROR) << "Different threads negotiated different protocols " - << negotiatedProtocol << " " << 
protocolVersion; - return END; - } - negotiatedProtocol = protocolVersion; - } - } - WDT_CHECK_GT(negotiatedProtocol, 0); - - LOG_IF(INFO, negotiatedProtocol != protocolVersion_) - << "Changing protocol version to " << negotiatedProtocol - << ", previous version " << protocolVersion_; - protocolVersion_ = negotiatedProtocol; - threadStats.setRemoteErrorCode(OK); - protoNegotiationStatus_ = V_MISMATCH_RESOLVED; - clearAbort(); - conditionAllAborted_.notify_all(); - return CONNECT; -} - -void Sender::sendOne(int threadIndex) { - INIT_PERF_STAT_REPORT - std::vector &transferHistories = transferHistories_; - TransferStats &threadStats = globalThreadStats_[threadIndex]; - Clock::time_point startTime = Clock::now(); - int port = ports_[threadIndex]; - auto completionGuard = folly::makeGuard([&] { - std::unique_lock lock(mutex_); - numActiveThreads_--; - if (numActiveThreads_ == 0) { - LOG(INFO) << "Last thread finished " - << durationSeconds(Clock::now() - startTime_); - endTime_ = Clock::now(); - transferFinished_ = true; - } - if (throttler_) { - throttler_->deRegisterTransfer(); - } - conditionAllAborted_.notify_one(); - }); - if (throttler_) { - throttler_->registerTransfer(); - } - ThreadData threadData(threadIndex, threadStats, transferHistories); - SenderState state = CONNECT; - while (state != END) { - ErrorCode abortCode = getCurAbortCode(); - if (abortCode != OK) { - LOG(ERROR) << "Transfer aborted " << port << " " - << errorCodeToStr(abortCode); - threadStats.setErrorCode(ABORT); - if (abortCode == VERSION_MISMATCH) { - state = PROCESS_VERSION_MISMATCH; - } else { - break; - } - } - state = (this->*stateMap_[state])(threadData); - } - - double totalTime = durationSeconds(Clock::now() - startTime); - LOG(INFO) << "Port " << port << " done. 
" << threadStats - << " Total throughput = " - << threadStats.getEffectiveTotalBytes() / totalTime / kMbToB - << " Mbytes/sec"; - perfReports_[threadIndex] = *perfStatReport; - return; -} - TransferStats Sender::sendOneByteSource( const std::unique_ptr &socket, const std::unique_ptr &source, ErrorCode transferStatus) { @@ -1267,7 +466,6 @@ TransferStats Sender::sendOneByteSource( off += sizeof(int16_t); const int64_t expectedSize = source->getSize(); int64_t actualSize = 0; - const SourceMetaData &metadata = source->getMetaData(); BlockDetails blockDetails; blockDetails.fileName = metadata.relPath; @@ -1277,7 +475,6 @@ TransferStats Sender::sendOneByteSource( blockDetails.dataSize = expectedSize; blockDetails.allocationStatus = metadata.allocationStatus; blockDetails.prevSeqId = metadata.prevSeqId; - Protocol::encodeHeader(protocolVersion_, headerBuf, off, Protocol::kMaxHeader, blockDetails); int16_t littleEndianOff = folly::Endian::little((int16_t)off); diff --git a/Sender.h b/Sender.h index 73cb9c57..7e4536c7 100644 --- a/Sender.h +++ b/Sender.h @@ -11,9 +11,7 @@ #include "WdtBase.h" #include "ClientSocket.h" #include "WdtOptions.h" - -#include - +#include "ThreadsController.h" #include #include #include @@ -22,87 +20,13 @@ namespace facebook { namespace wdt { - -class DirectorySourceQueue; - -/// transfer history of a sender thread -class ThreadTransferHistory { - public: - /** - * @param queue directory queue - * @param threadStats stat object of the thread - */ - ThreadTransferHistory(DirectorySourceQueue &queue, - TransferStats &threadStats); - - /** - * @param index of the source - * @return if index is in bounds, returns the identifier for the - * source, else returns empty string - */ - std::string getSourceId(int64_t index); - - /** - * Adds the source to the history. If global checkpoint has already been - * received, then the source is returned to the queue. 
- * - * @param source source to be added to history - * @return true if added to history, false if not added due to a - * global checkpoint - */ - bool addSource(std::unique_ptr &source); - - /** - * Sets checkpoint. Also, returns unacked sources to queue - * - * @param checkpoint checkpoint received - * @param globalCheckpoint global or local checkpoint - * @return number of sources returned to queue - */ - ErrorCode setCheckpointAndReturnToQueue(const Checkpoint &checkpoint, - bool globalCheckpoint); - - /** - * @return stats for acked sources, must be called after all the - * unacked sources are returned to the queue - */ - std::vector popAckedSourceStats(); - - /// marks all the sources as acked - void markAllAcknowledged(); - - /// returns all unacked sources to the queue - void returnUnackedSourcesToQueue(); - - /** - * @return number of sources acked by the receiver - */ - int64_t getNumAcked() const { - return numAcknowledged_; - } - - private: - ErrorCode validateCheckpoint(const Checkpoint &checkpoint, - bool globalCheckpoint); - - void markSourceAsFailed(std::unique_ptr &source, - const Checkpoint *checkpoint); - - /// reference to global queue - DirectorySourceQueue &queue_; - /// reference to thread stats - TransferStats &threadStats_; - /// history of the thread - std::vector> history_; - /// whether a global error checkpoint has been received or not - bool globalCheckpoint_{false}; - /// number of sources acked by the receiver thread - int64_t numAcknowledged_{0}; - /// last received checkpoint - std::unique_ptr lastCheckpoint_{nullptr}; - folly::SpinLock lock_; +class SenderThread; +class TransferHistoryController; +enum ProtoNegotiationStatus { + V_MISMATCH_WAIT, // waiting for version mismatch to be processed + V_MISMATCH_RESOLVED, // version mismatch processed and was successful + V_MISMATCH_FAILED, // version mismatch processed and it failed }; - /** * The sender for the transfer. 
One instance of sender should only be * responsible for one transfer. For a second transfer you should make @@ -226,6 +150,34 @@ class Sender : public WdtBase { void setSocketCreator(const SocketCreator socketCreator); private: + friend class SenderThread; + + /// Get the sum of all the thread transfer stats + TransferStats getGlobalTransferStats() const; + + /// Verifies that if there is version mismatch then no thread stats + /// should be OK + ErrorCode verifyVersionMismatchStats() const; + + /// General utility used by sender threads to connect to receiver + std::unique_ptr connectToReceiver( + const int port, IAbortChecker const *abortChecker, ErrorCode &errCode); + + /// Method responsible for sending one source to the destination + virtual TransferStats sendOneByteSource( + const std::unique_ptr &socket, + const std::unique_ptr &source, ErrorCode transferStatus); + + /// Returns true if file chunks need to be read + bool isSendFileChunks() const; + + /// Returns true if file chunks have been received by a thread + bool isFileChunksReceived() const; + + /// Sender thread calls this method to set the file chunks info received + /// from the receiver + void setFileChunksInfo(std::vector &fileChunksInfoList); + /// Abort checker passed to DirectoryQueue.
If all the network threads finish, /// directory discovery thread is also aborted class QueueAbortChecker : public IAbortChecker { @@ -241,191 +193,9 @@ class Sender : public WdtBase { Sender *sender_; }; + /// Abort checker shared with the directory queue QueueAbortChecker queueAbortChecker_; - /// state machine states - enum SenderState { - CONNECT, - READ_LOCAL_CHECKPOINT, - SEND_SETTINGS, - SEND_BLOCKS, - SEND_DONE_CMD, - SEND_SIZE_CMD, - CHECK_FOR_ABORT, - READ_FILE_CHUNKS, - READ_RECEIVER_CMD, - PROCESS_DONE_CMD, - PROCESS_WAIT_CMD, - PROCESS_ERR_CMD, - PROCESS_ABORT_CMD, - PROCESS_VERSION_MISMATCH, - END - }; - - /// structure to share data among different states - struct ThreadData { - const int threadIndex_; - TransferStats &threadStats_; - std::vector &transferHistories_; - std::unique_ptr socket_; - char buf_[Protocol::kMinBufLength]; - /// whether total file size has been sent to the receiver - bool totalSizeSent_{false}; - /// number of consecutive reconnects without any progress - int numReconnectWithoutProgress_{0}; - ThreadData(int threadIndex, TransferStats &threadStats, - std::vector &transferHistories) - : threadIndex_(threadIndex), - threadStats_(threadStats), - transferHistories_(transferHistories) { - } - - ThreadTransferHistory &getTransferHistory() { - return transferHistories_[threadIndex_]; - } - }; - - typedef SenderState (Sender::*StateFunction)(ThreadData &data); - - /** - * tries to connect to the receiver - * Previous states : Almost all states(in case of network errors, all states - * move to this state) - * Next states : SEND_SETTINGS(if there is no previous error) - * READ_LOCAL_CHECKPOINT(if there is previous error) - * END(failed) - */ - SenderState connect(ThreadData &data); - /** - * tries to read local checkpoint and return unacked sources to queue. If the - * checkpoint value is -1, then we know previous attempt to send DONE had - * failed. So, we move to READ_RECEIVER_CMD state. 
- * Previous states : CONNECT - * Next states : CONNECT(read failure), - * END(protocol error or global checkpoint found), - * READ_RECEIVER_CMD(if checkpoint is -1), - * SEND_SETTINGS(success) - */ - SenderState readLocalCheckPoint(ThreadData &data); - /** - * sends sender settings to the receiver - * Previous states : READ_LOCAL_CHECKPOINT, - * CONNECT - * Next states : SEND_BLOCKS(success), - * CONNECT(failure) - */ - SenderState sendSettings(ThreadData &data); - /** - * sends blocks to receiver till the queue is not empty. After transferring a - * block, we add it to the history. While adding to history, if it is found - * that global checkpoint has been received for this thread, we move to END - * state. - * Previous states : SEND_SETTINGS, - * PROCESS_ERR_CMD - * Next states : SEND_BLOCKS(success), - * END(global checkpoint received), - * CHECK_FOR_ABORT(socket write failure), - * SEND_DONE_CMD(no more blocks left to transfer) - */ - SenderState sendBlocks(ThreadData &data); - /** - * sends DONE cmd to the receiver - * Previous states : SEND_BLOCKS - * Next states : CONNECT(failure), - * READ_RECEIVER_CMD(success) - */ - SenderState sendDoneCmd(ThreadData &data); - /** - * sends size cmd to the receiver - * Previous states : SEND_BLOCKS - * Next states : CHECK_FOR_ABORT(failure), - * SEND_BLOCKS(success) - */ - SenderState sendSizeCmd(ThreadData &data); - /** - * checks to see if the receiver has sent ABORT or not - * Previous states : SEND_BLOCKS, - * SEND_DONE_CMD - * Next states : CONNECT(no ABORT cmd), - * END(protocol error), - * PROCESS_ABORT_CMD(read ABORT cmd) - */ - SenderState checkForAbort(ThreadData &data); - /** - * reads previously transferred file chunks list. If it receives an ACK cmd, - * then it moves on. If wait cmd is received, it waits. Otherwise reads the - * file chunks and when done starts directory queue thread. 
- * Previous states : SEND_SETTINGS, - * Next states: READ_FILE_CHUNKS(if wait cmd is received), - * CHECK_FOR_ABORT(network error), - * END(protocol error), - * SEND_BLOCKS(success) - * - */ - SenderState readFileChunks(ThreadData &data); - /** - * reads receiver cmd - * Previous states : SEND_DONE_CMD - * Next states : PROCESS_DONE_CMD, - * PROCESS_WAIT_CMD, - * PROCESS_ERR_CMD, - * END(protocol error), - * CONNECT(failure) - */ - SenderState readReceiverCmd(ThreadData &data); - /** - * handles DONE cmd - * Previous states : READ_RECEIVER_CMD - * Next states : END - */ - SenderState processDoneCmd(ThreadData &data); - /** - * handles WAIT cmd - * Previous states : READ_RECEIVER_CMD - * Next states : READ_RECEIVER_CMD - */ - SenderState processWaitCmd(ThreadData &data); - /** - * reads list of global checkpoints and returns unacked sources to queue. - * Previous states : READ_RECEIVER_CMD - * Next states : CONNECT(socket read failure) - * END(checkpoint list decode failure), - * SEND_BLOCKS(success) - */ - SenderState processErrCmd(ThreadData &data); - /** - * processes ABORT cmd - * Previous states : CHECK_FOR_ABORT, - * READ_RECEIVER_CMD - * Next states : END - */ - SenderState processAbortCmd(ThreadData &data); - - /** - * waits for all active threads to be aborted, checks to see if the abort was - * due to version mismatch. Also performs various sanity checks. 
- * Previous states : Almost all threads, abort flags is checked between every - * state transition - * Next states : CONNECT(Abort was due to version mismatch), - * END(if abort was not due to version mismatch or some sanity - * check failed) - */ - SenderState processVersionMismatch(ThreadData &data); - - /// mapping from sender states to state functions - static const StateFunction stateMap_[]; - - /// Method responsible for sending one source to the destination - virtual TransferStats sendOneByteSource( - const std::unique_ptr &socket, - const std::unique_ptr &source, ErrorCode transferStatus); - - /// Every sender thread executes this method to send the data - void sendOne(int threadIndex); - - std::unique_ptr connectToReceiver(const int port, - ErrorCode &errCode); - /** * Internal API that triggers the directory thread, sets up the sender * threads and starts the transfer. Returns after the sender threads @@ -437,12 +207,10 @@ class Sender : public WdtBase { * @param transferredSourceStats Stats for the successfully transmitted * sources * @param failedSourceStats Stats for the failed sources - * @param threadStats Stats calculated by each sender thread */ void validateTransferStats( const std::vector &transferredSourceStats, - const std::vector &failedSourceStats, - const std::vector &threadStats); + const std::vector &failedSourceStats); /** * Responsible for doing a periodic check. 
@@ -452,16 +220,14 @@ class Sender : public WdtBase { */ void reportProgress(); + /// Address of the destination host where the files are sent + const std::string destHost_; /// Pointer to DirectorySourceQueue which reads the srcDir and the files std::unique_ptr dirQueue_; - /// List of ports where the receiver threads are running on the destination - std::vector ports_; /// Number of active threads, decremented every time a thread is finished int32_t numActiveThreads_{0}; /// The directory from where the files are read std::string srcDir_; - /// Address of the destination host where the files are sent - std::string destHost_; /// The interval at which the progress reporter should check for progress int progressReportIntervalMillis_; /// Socket creator used to optionally create different kinds of client socket @@ -473,26 +239,25 @@ class Sender : public WdtBase { /// Thread that is running the discovery of files using the dirQueue_ std::thread dirThread_; /// Threads which are responsible for transfer of the sources - std::vector senderThreads_; + std::vector> senderThreads_; /// Thread responsible for doing the progress checks. 
Uses reportProgress() std::thread progressReporterThread_; - /// Vector of per thread stats, this same instance is used in reporting - std::vector globalThreadStats_; - /// per thread perf report - std::vector perfReports_; - /// per thread negotiated protocol versions - std::vector negotiatedProtocolVersions_; - /// number of threads waiting in PROCESS_VERSION_MISMATCH state - int numWaitingWithAbort_{0}; - /// Condition variable used to co-ordinate threads waiting in - /// PROCESS_VERSION_MISMATCH state - std::condition_variable conditionAllAborted_; - - enum ProtoNegotiationStatus { - V_MISMATCH_WAIT, // waiting for version mismatch to be processed - V_MISMATCH_RESOLVED, // version mismatch processed and was successful - V_MISMATCH_FAILED, // version mismatch processed and it failed - }; + + /// Returns the protocol negotiation status of the parent sender + ProtoNegotiationStatus getNegotiationStatus(); + + /// Set the protocol negotiation status, called by sender thread + void setProtoNegotiationStatus(ProtoNegotiationStatus status); + + /// Things to do before ending the current transfer + void endCurTransfer(); + + /// Initializing the new transfer + void startNewTransfer(); + + /// Returns vector of negotiated protocols set by sender threads + std::vector getNegotiatedProtocols() const; + /// Protocol negotiation status, used to co-ordinate processing of version /// mismatch. Threads aborted due to version mismatch waits for all threads to /// abort and reach PROCESS_VERSION_MISMATCH state. 
Last thread processes @@ -503,20 +268,21 @@ class Sender : public WdtBase { std::condition_variable conditionFinished_; /// Mutex which is shared between the parent thread, sender thread and /// progress reporter thread - std::mutex mutex_; + mutable std::mutex mutex_; /// Set to false when the transfer begins and then turned on when it ends bool transferFinished_; /// Time at which the transfer was started std::chrono::time_point startTime_; /// Time at which the transfer finished std::chrono::time_point endTime_; - /// Per thread transfer history - std::vector transferHistories_; /// Has finished been called and threads joined bool areThreadsJoined_{true}; /// Mutex for the management of this instance, specifically to keep the /// instance sane for multi threaded public API calls std::mutex instanceManagementMutex_; + + /// Transfer history controller for the sender threads + std::unique_ptr transferHistoryController_; }; } } // namespace facebook::wdt diff --git a/SenderThread.cpp b/SenderThread.cpp new file mode 100644 index 00000000..0fe2d3b4 --- /dev/null +++ b/SenderThread.cpp @@ -0,0 +1,633 @@ +/** + * Copyright (c) 2014-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. 
+ */ +#include "SenderThread.h" +#include "Sender.h" +#include "ClientSocket.h" +#include +#include +#include +#include +#include +#include +#include + +namespace facebook { +namespace wdt { +std::ostream &operator<<(std::ostream &os, const SenderThread &senderThread) { + os << "Thread[" << senderThread.threadIndex_ + << ", port: " << senderThread.port_ << "] "; + return os; +} + +const SenderThread::StateFunction SenderThread::stateMap_[] = { + &SenderThread::connect, &SenderThread::readLocalCheckPoint, + &SenderThread::sendSettings, &SenderThread::sendBlocks, + &SenderThread::sendDoneCmd, &SenderThread::sendSizeCmd, + &SenderThread::checkForAbort, &SenderThread::readFileChunks, + &SenderThread::readReceiverCmd, &SenderThread::processDoneCmd, + &SenderThread::processWaitCmd, &SenderThread::processErrCmd, + &SenderThread::processAbortCmd, &SenderThread::processVersionMismatch}; + +SenderState SenderThread::connect() { + VLOG(1) << *this << " entered CONNECT state"; + if (socket_) { + socket_->close(); + } + const auto &options = WdtOptions::get(); + if (numReconnectWithoutProgress_ >= options.max_transfer_retries) { + LOG(ERROR) << "Sender thread reconnected " << numReconnectWithoutProgress_ + << " times without making any progress, giving up. 
port: " + << socket_->getPort(); + threadStats_.setErrorCode(NO_PROGRESS); + return END; + } + ErrorCode code; + socket_ = + wdtParent_->connectToReceiver(port_, socketAbortChecker_.get(), code); + if (code == ABORT) { + threadStats_.setErrorCode(ABORT); + if (wdtParent_->getCurAbortCode() == VERSION_MISMATCH) { + return PROCESS_VERSION_MISMATCH; + } + return END; + } + if (code != OK) { + threadStats_.setErrorCode(code); + return END; + } + auto nextState = SEND_SETTINGS; + if (threadStats_.getErrorCode() != OK) { + nextState = READ_LOCAL_CHECKPOINT; + } + // resetting the status of thread + reset(); + return nextState; +} + +SenderState SenderThread::readLocalCheckPoint() { + LOG(INFO) << *this << " entered READ_LOCAL_CHECKPOINT state"; + ThreadTransferHistory &transferHistory = getTransferHistory(); + std::vector checkpoints; + int64_t decodeOffset = 0; + int checkpointLen = + Protocol::getMaxLocalCheckpointLength(threadProtocolVersion_); + int64_t numRead = socket_->read(buf_, checkpointLen); + if (numRead != checkpointLen) { + LOG(ERROR) << "read mismatch during reading local checkpoint " + << checkpointLen << " " << numRead << " port " << port_; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + numReconnectWithoutProgress_++; + return CONNECT; + } + if (!Protocol::decodeCheckpoints(threadProtocolVersion_, buf_, decodeOffset, + checkpointLen, checkpoints)) { + LOG(ERROR) << "checkpoint decode failure " + << folly::humanify(std::string(buf_, checkpointLen)); + threadStats_.setErrorCode(PROTOCOL_ERROR); + return END; + } + if (checkpoints.size() != 1 || checkpoints[0].port != port_) { + LOG(ERROR) << "illegal local checkpoint " + << folly::humanify(std::string(buf_, checkpointLen)); + threadStats_.setErrorCode(PROTOCOL_ERROR); + return END; + } + const Checkpoint &checkpoint = checkpoints[0]; + auto numBlocks = checkpoint.numBlocks; + VLOG(1) << "received local checkpoint " << checkpoint; + + if (numBlocks == -1) { + // Receiver failed while sending DONE cmd + 
return READ_RECEIVER_CMD; + } + + ErrorCode errCode = transferHistory.setLocalCheckpoint(checkpoint); + if (errCode == INVALID_CHECKPOINT) { + threadStats_.setErrorCode(PROTOCOL_ERROR); + return END; + } + if (errCode == NO_PROGRESS) { + ++numReconnectWithoutProgress_; + } else { + numReconnectWithoutProgress_ = 0; + } + return SEND_SETTINGS; +} + +SenderState SenderThread::sendSettings() { + VLOG(1) << *this << " entered SEND_SETTINGS state"; + auto &options = WdtOptions::get(); + int64_t readTimeoutMillis = options.read_timeout_millis; + int64_t writeTimeoutMillis = options.write_timeout_millis; + int64_t off = 0; + buf_[off++] = Protocol::SETTINGS_CMD; + bool sendFileChunks = wdtParent_->isSendFileChunks(); + Settings settings; + settings.readTimeoutMillis = readTimeoutMillis; + settings.writeTimeoutMillis = writeTimeoutMillis; + settings.transferId = wdtParent_->getTransferId(); + settings.enableChecksum = options.enable_checksum; + settings.sendFileChunks = sendFileChunks; + settings.blockModeDisabled = (options.block_size_mbytes <= 0); + Protocol::encodeSettings(threadProtocolVersion_, buf_, off, + Protocol::kMaxSettings, settings); + int64_t toWrite = sendFileChunks ? Protocol::kMinBufLength : off; + int64_t written = socket_->write(buf_, toWrite); + if (written != toWrite) { + LOG(ERROR) << "Socket write failure " << written << " " << toWrite; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + return CONNECT; + } + threadStats_.addHeaderBytes(toWrite); + return sendFileChunks ? 
READ_FILE_CHUNKS : SEND_BLOCKS; +} + +SenderState SenderThread::sendBlocks() { + VLOG(1) << *this << " entered SEND_BLOCKS state"; + ThreadTransferHistory &transferHistory = getTransferHistory(); + if (threadProtocolVersion_ >= Protocol::RECEIVER_PROGRESS_REPORT_VERSION && + !totalSizeSent_ && dirQueue_->fileDiscoveryFinished()) { + return SEND_SIZE_CMD; + } + ErrorCode transferStatus; + std::unique_ptr source = dirQueue_->getNextSource(transferStatus); + if (!source) { + return SEND_DONE_CMD; + } + WDT_CHECK(!source->hasError()); + TransferStats transferStats = + wdtParent_->sendOneByteSource(socket_, source, transferStatus); + threadStats_ += transferStats; + source->addTransferStats(transferStats); + source->close(); + if (!transferHistory.addSource(source)) { + // global checkpoint received for this thread. no point in + // continuing + LOG(ERROR) << *this << " global checkpoint received. Stopping"; + threadStats_.setErrorCode(CONN_ERROR); + return END; + } + if (transferStats.getErrorCode() != OK) { + return CHECK_FOR_ABORT; + } + return SEND_BLOCKS; +} + +SenderState SenderThread::sendSizeCmd() { + VLOG(1) << *this << " entered SEND_SIZE_CMD state"; + int64_t off = 0; + buf_[off++] = Protocol::SIZE_CMD; + + Protocol::encodeSize(buf_, off, Protocol::kMaxSize, + dirQueue_->getTotalSize()); + int64_t written = socket_->write(buf_, off); + if (written != off) { + LOG(ERROR) << "Socket write error " << off << " " << written; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + return CHECK_FOR_ABORT; + } + threadStats_.addHeaderBytes(off); + totalSizeSent_ = true; + return SEND_BLOCKS; +} + +SenderState SenderThread::sendDoneCmd() { + VLOG(1) << *this << " entered SEND_DONE_CMD state"; + int64_t off = 0; + buf_[off++] = Protocol::DONE_CMD; + auto pair = dirQueue_->getNumBlocksAndStatus(); + int64_t numBlocksDiscovered = pair.first; + ErrorCode transferStatus = pair.second; + buf_[off++] = transferStatus; + Protocol::encodeDone(threadProtocolVersion_, buf_, off, 
Protocol::kMaxDone, + numBlocksDiscovered, dirQueue_->getTotalSize()); + int toWrite = Protocol::kMinBufLength; + int64_t written = socket_->write(buf_, toWrite); + if (written != toWrite) { + LOG(ERROR) << "Socket write failure " << written << " " << toWrite; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + return CHECK_FOR_ABORT; + } + threadStats_.addHeaderBytes(toWrite); + VLOG(1) << "Wrote done cmd on " << socket_->getFd() + << " waiting for reply..."; + return READ_RECEIVER_CMD; +} + +SenderState SenderThread::checkForAbort() { + LOG(INFO) << *this << " entered CHECK_FOR_ABORT state"; + auto numRead = socket_->read(buf_, 1); + if (numRead != 1) { + VLOG(1) << "No abort cmd found"; + return CONNECT; + } + Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf_[0]; + if (cmd != Protocol::ABORT_CMD) { + VLOG(1) << "Unexpected result found while reading for abort " << buf_[0]; + return CONNECT; + } + threadStats_.addHeaderBytes(1); + return PROCESS_ABORT_CMD; +} + +SenderState SenderThread::readFileChunks() { + LOG(INFO) << *this << " entered READ_FILE_CHUNKS state "; + int64_t numRead = socket_->read(buf_, 1); + if (numRead != 1) { + LOG(ERROR) << "Socket read error 1 " << numRead; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return CHECK_FOR_ABORT; + } + threadStats_.addHeaderBytes(numRead); + Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf_[0]; + if (cmd == Protocol::ABORT_CMD) { + return PROCESS_ABORT_CMD; + } + if (cmd == Protocol::WAIT_CMD) { + return READ_FILE_CHUNKS; + } + if (cmd == Protocol::ACK_CMD) { + if (!wdtParent_->isFileChunksReceived()) { + LOG(ERROR) << "Sender has not yet received file chunks, but receiver " + << "thinks it has already sent it"; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return END; + } + return SEND_BLOCKS; + } + if (cmd != Protocol::CHUNKS_CMD) { + LOG(ERROR) << "Unexpected cmd " << cmd; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return END; + } + int64_t toRead = Protocol::kChunksCmdLen; + numRead = 
socket_->read(buf_, toRead); + if (numRead != toRead) { + LOG(ERROR) << "Socket read error " << toRead << " " << numRead; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return CHECK_FOR_ABORT; + } + threadStats_.addHeaderBytes(numRead); + int64_t off = 0; + int64_t bufSize, numFiles; + Protocol::decodeChunksCmd(buf_, off, bufSize, numFiles); + LOG(INFO) << "File chunk list has " << numFiles + << " entries and is broken in buffers of length " << bufSize; + std::unique_ptr chunkBuffer(new char[bufSize]); + std::vector fileChunksInfoList; + while (true) { + int64_t numFileChunks = fileChunksInfoList.size(); + if (numFileChunks > numFiles) { + // We should never be able to read more file chunks than mentioned in the + // chunks cmd. Chunks cmd has buffer size used to transfer chunks and also + // number of chunks. This chunks are read and parsed and added to + // fileChunksInfoList. Number of chunks we decode should match with the + // number mentioned in the Chunks cmd. + LOG(ERROR) << "Number of file chunks received is more than the number " + "mentioned in CHUNKS_CMD " << numFileChunks << " " + << numFiles; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return END; + } + if (numFileChunks == numFiles) { + break; + } + toRead = sizeof(int32_t); + numRead = socket_->read(buf_, toRead); + if (numRead != toRead) { + LOG(ERROR) << "Socket read error " << toRead << " " << numRead; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return CHECK_FOR_ABORT; + } + toRead = folly::loadUnaligned(buf_); + toRead = folly::Endian::little(toRead); + numRead = socket_->read(chunkBuffer.get(), toRead); + if (numRead != toRead) { + LOG(ERROR) << "Socket read error " << toRead << " " << numRead; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return CHECK_FOR_ABORT; + } + threadStats_.addHeaderBytes(numRead); + off = 0; + // decode function below adds decoded file chunks to fileChunksInfoList + bool success = Protocol::decodeFileChunksInfoList( + chunkBuffer.get(), off, toRead, 
fileChunksInfoList); + if (!success) { + LOG(ERROR) << "Unable to decode file chunks list"; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return END; + } + } + wdtParent_->setFileChunksInfo(fileChunksInfoList); + // send ack for file chunks list + buf_[0] = Protocol::ACK_CMD; + int64_t toWrite = 1; + int64_t written = socket_->write(buf_, toWrite); + if (toWrite != written) { + LOG(ERROR) << "Socket write error " << toWrite << " " << written; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + return CHECK_FOR_ABORT; + } + threadStats_.addHeaderBytes(written); + return SEND_BLOCKS; +} + +SenderState SenderThread::readReceiverCmd() { + VLOG(1) << *this << " entered READ_RECEIVER_CMD state"; + int64_t numRead = socket_->read(buf_, 1); + if (numRead != 1) { + LOG(ERROR) << "READ unexpected " << numRead; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return CONNECT; + } + Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf_[0]; + if (cmd == Protocol::ERR_CMD) { + return PROCESS_ERR_CMD; + } + if (cmd == Protocol::WAIT_CMD) { + return PROCESS_WAIT_CMD; + } + if (cmd == Protocol::DONE_CMD) { + return PROCESS_DONE_CMD; + } + if (cmd == Protocol::ABORT_CMD) { + return PROCESS_ABORT_CMD; + } + if (cmd == Protocol::LOCAL_CHECKPOINT_CMD) { + int checkpointLen = + Protocol::getMaxLocalCheckpointLength(threadProtocolVersion_); + int64_t toRead = checkpointLen - 1; + numRead = socket_->read(buf_ + 1, toRead); + if (numRead != toRead) { + LOG(ERROR) << "Could not read possible local checkpoint " << toRead << " " + << numRead << " " << port_; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return CONNECT; + } + int64_t offset = 0; + std::vector checkpoints; + if (Protocol::decodeCheckpoints(threadProtocolVersion_, buf_, offset, + checkpointLen, checkpoints)) { + if (checkpoints.size() == 1 && checkpoints[0].port == port_ && + checkpoints[0].numBlocks == 0 && + checkpoints[0].lastBlockReceivedBytes == 0) { + // In a spurious local checkpoint, number of blocks and offset must 
both + // be zero + // Ignore the checkpoint + LOG(WARNING) + << "Received valid but unexpected local checkpoint, ignoring " + << port_ << " checkpoint " << checkpoints[0]; + return READ_RECEIVER_CMD; + } + } + LOG(ERROR) << "Failed to verify spurious local checkpoint, port " << port_; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return END; + } + LOG(ERROR) << "Read unexpected receiver cmd " << cmd << " port " << port_; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return END; +} + +SenderState SenderThread::processDoneCmd() { + VLOG(1) << *this << " entered PROCESS_DONE_CMD state"; + ThreadTransferHistory &transferHistory = getTransferHistory(); + transferHistory.markAllAcknowledged(); + + // send ack for DONE + buf_[0] = Protocol::DONE_CMD; + socket_->write(buf_, 1); + + socket_->shutdown(); + auto numRead = socket_->read(buf_, Protocol::kMinBufLength); + if (numRead != 0) { + LOG(WARNING) << "EOF not found when expected"; + return END; + } + VLOG(1) << "done with transfer, port " << port_; + return END; +} + +SenderState SenderThread::processWaitCmd() { + LOG(INFO) << *this << " entered PROCESS_WAIT_CMD state "; + ; + ThreadTransferHistory &transferHistory = getTransferHistory(); + VLOG(1) << "received WAIT_CMD, port " << port_; + transferHistory.markAllAcknowledged(); + return READ_RECEIVER_CMD; +} + +SenderState SenderThread::processErrCmd() { + LOG(INFO) << *this << " entered PROCESS_ERR_CMD state"; + ThreadTransferHistory &transferHistory = getTransferHistory(); + int64_t toRead = sizeof(int16_t); + int64_t numRead = socket_->read(buf_, toRead); + if (numRead != toRead) { + LOG(ERROR) << "read unexpected " << toRead << " " << numRead; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return CONNECT; + } + + int16_t checkpointsLen = folly::loadUnaligned(buf_); + checkpointsLen = folly::Endian::little(checkpointsLen); + char checkpointBuf[checkpointsLen]; + numRead = socket_->read(checkpointBuf, checkpointsLen); + if (numRead != checkpointsLen) { + 
LOG(ERROR) << "read unexpected " << checkpointsLen << " " << numRead; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return CONNECT; + } + + std::vector checkpoints; + int64_t decodeOffset = 0; + if (!Protocol::decodeCheckpoints(threadProtocolVersion_, checkpointBuf, + decodeOffset, checkpointsLen, checkpoints)) { + LOG(ERROR) << "checkpoint decode failure " + << folly::humanify(std::string(checkpointBuf, checkpointsLen)); + threadStats_.setErrorCode(PROTOCOL_ERROR); + return END; + } + transferHistory.markAllAcknowledged(); + for (auto &checkpoint : checkpoints) { + LOG(INFO) << *this << " Received global checkpoint " << checkpoint; + transferHistoryController_->handleGlobalCheckpoint(checkpoint); + } + return SEND_BLOCKS; +} + +SenderState SenderThread::processAbortCmd() { + LOG(INFO) << *this << " entered PROCESS_ABORT_CMD state "; + ThreadTransferHistory &transferHistory = getTransferHistory(); + threadStats_.setErrorCode(ABORT); + int toRead = Protocol::kAbortLength; + auto numRead = socket_->read(buf_, toRead); + if (numRead != toRead) { + // can not read checkpoint, but still must exit because of ABORT + LOG(ERROR) << "Error while trying to read ABORT cmd " << numRead << " " + << toRead; + return END; + } + int64_t offset = 0; + int32_t negotiatedProtocol; + ErrorCode remoteError; + int64_t checkpoint; + Protocol::decodeAbort(buf_, offset, negotiatedProtocol, remoteError, + checkpoint); + threadStats_.setRemoteErrorCode(remoteError); + std::string failedFileName = transferHistory.getSourceId(checkpoint); + LOG(WARNING) << *this << "Received abort on " + << " remote protocol version " << negotiatedProtocol + << " remote error code " << errorCodeToStr(remoteError) + << " file " << failedFileName << " checkpoint " << checkpoint; + wdtParent_->abort(remoteError); + if (remoteError == VERSION_MISMATCH) { + if (Protocol::negotiateProtocol( + negotiatedProtocol, threadProtocolVersion_) == negotiatedProtocol) { + // sender can support this negotiated version + 
negotiatedProtocol_ = negotiatedProtocol; + return PROCESS_VERSION_MISMATCH; + } else { + LOG(ERROR) << "Sender can not support receiver version " + << negotiatedProtocol; + threadStats_.setRemoteErrorCode(VERSION_INCOMPATIBLE); + } + } + return END; +} + +SenderState SenderThread::processVersionMismatch() { + LOG(INFO) << *this << " entered PROCESS_VERSION_MISMATCH state "; + WDT_CHECK(threadStats_.getErrorCode() == ABORT); + auto negotiationStatus = wdtParent_->getNegotiationStatus(); + WDT_CHECK_NE(negotiationStatus, V_MISMATCH_FAILED) + << "Thread should have ended in case of version mismatch"; + if (negotiationStatus == V_MISMATCH_RESOLVED) { + LOG(WARNING) << *this << " Protocol version already negotiated, but " + "transfer still aborted due to version mismatch"; + return END; + } + WDT_CHECK_EQ(negotiationStatus, V_MISMATCH_WAIT); + // Need a barrier here to make sure all the negotiated protocol versions + // have been collected + auto barrier = controller_->getBarrier(VERSION_MISMATCH_BARRIER); + barrier->execute(); + VLOG(1) << *this << " cleared the protocol version barrier"; + auto execFunnel = controller_->getFunnel(VERSION_MISMATCH_FUNNEL); + while (true) { + auto status = execFunnel->getStatus(); + switch (status) { + case FUNNEL_START: { + LOG(INFO) << *this << " started the funnel for version mismatch"; + wdtParent_->setProtoNegotiationStatus(V_MISMATCH_FAILED); + if (wdtParent_->verifyVersionMismatchStats() != OK) { + execFunnel->notifySuccess(); + return END; + } + if (transferHistoryController_->handleVersionMismatch() != OK) { + execFunnel->notifySuccess(); + return END; + } + int negotiatedProtocol = 0; + for (int threadProtocolVersion_ : + wdtParent_->getNegotiatedProtocols()) { + if (threadProtocolVersion_ > 0) { + if (negotiatedProtocol > 0 && + negotiatedProtocol != threadProtocolVersion_) { + LOG(ERROR) << "Different threads negotiated different protocols " + << negotiatedProtocol << " " << threadProtocolVersion_; + 
execFunnel->notifySuccess(); + return END; + } + negotiatedProtocol = threadProtocolVersion_; + } + } + WDT_CHECK_GT(negotiatedProtocol, 0); + LOG_IF(INFO, negotiatedProtocol != threadProtocolVersion_) + << *this << "Changing protocol version to " << negotiatedProtocol + << ", previous version " << threadProtocolVersion_; + wdtParent_->setProtocolVersion(negotiatedProtocol); + threadProtocolVersion_ = wdtParent_->getProtocolVersion(); + threadStats_.setRemoteErrorCode(OK); + wdtParent_->setProtoNegotiationStatus(V_MISMATCH_RESOLVED); + wdtParent_->clearAbort(); + execFunnel->notifySuccess(); + return CONNECT; + } + case FUNNEL_PROGRESS: { + execFunnel->wait(); + break; + } + case FUNNEL_END: { + negotiationStatus = wdtParent_->getNegotiationStatus(); + WDT_CHECK_NE(negotiationStatus, V_MISMATCH_WAIT); + if (negotiationStatus == V_MISMATCH_FAILED) { + return END; + } + if (negotiationStatus == V_MISMATCH_RESOLVED) { + threadProtocolVersion_ = wdtParent_->getProtocolVersion(); + threadStats_.setRemoteErrorCode(OK); + return CONNECT; + } + } + } + } +} + +void SenderThread::start() { + INIT_PERF_STAT_REPORT + Clock::time_point startTime = Clock::now(); + auto completionGuard = folly::makeGuard([&] { + ThreadTransferHistory &transferHistory = getTransferHistory(); + transferHistory.markNotInUse(); + controller_->deRegisterThread(threadIndex_); + controller_->executeAtEnd([&]() { wdtParent_->endCurTransfer(); }); + }); + controller_->executeAtStart([&]() { wdtParent_->startNewTransfer(); }); + SenderState state = CONNECT; + while (state != END) { + ErrorCode abortCode = wdtParent_->getCurAbortCode(); + if (abortCode != OK) { + LOG(ERROR) << *this << "Transfer aborted " << errorCodeToStr(abortCode); + threadStats_.setErrorCode(ABORT); + if (abortCode == VERSION_MISMATCH) { + state = PROCESS_VERSION_MISMATCH; + } else { + break; + } + } + state = (this->*stateMap_[state])(); + } + + double totalTime = durationSeconds(Clock::now() - startTime); + LOG(INFO) << "Port " << 
port_ << " done. " << threadStats_ + << " Total throughput = " + << threadStats_.getEffectiveTotalBytes() / totalTime / kMbToB + << " Mbytes/sec"; + perfReport_ = *perfStatReport; + return; +} + +int SenderThread::getPort() const { + return port_; +} + +int SenderThread::getNegotiatedProtocol() const { + return negotiatedProtocol_; +} + +ErrorCode SenderThread::init() { + return OK; +} + +void SenderThread::reset() { + totalSizeSent_ = false; + threadStats_.setErrorCode(OK); +} +} +} diff --git a/SenderThread.h b/SenderThread.h new file mode 100644 index 00000000..f75684f5 --- /dev/null +++ b/SenderThread.h @@ -0,0 +1,278 @@ +/** + * Copyright (c) 2014-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ +#pragma once +#include +#include "WdtThread.h" +#include "Sender.h" +#include "ClientSocket.h" +#include "ThreadTransferHistory.h" +namespace facebook { +namespace wdt { +class DirectorySourceQueue; +/// state machine states +enum SenderState { + CONNECT, + READ_LOCAL_CHECKPOINT, + SEND_SETTINGS, + SEND_BLOCKS, + SEND_DONE_CMD, + SEND_SIZE_CMD, + CHECK_FOR_ABORT, + READ_FILE_CHUNKS, + READ_RECEIVER_CMD, + PROCESS_DONE_CMD, + PROCESS_WAIT_CMD, + PROCESS_ERR_CMD, + PROCESS_ABORT_CMD, + PROCESS_VERSION_MISMATCH, + END +}; + +/** + * This class represents one sender thread. It contains all the + * functionalities that a thread would need to send data over + * a connection to the receiver. 
+ * All the sender threads share a bunch of modules like directory queue, + * throttler, threads controller etc + */ +class SenderThread : public WdtThread { + public: + /// Identifiers for the barriers used in the thread + enum SENDER_BARRIERS { VERSION_MISMATCH_BARRIER, NUM_BARRIERS }; + + /// Identifiers for the funnels used in the thread + enum SENDER_FUNNELS { VERSION_MISMATCH_FUNNEL, NUM_FUNNELS }; + + /// Identifier for the condition wrappers used in the thread + enum SENDER_CONDITIONS { NUM_CONDITIONS }; + + /// abort checker passed to client sockets. This checks both global sender + /// abort and whether global checkpoint has been received or not + class SocketAbortChecker : public WdtBase::AbortChecker { + public: + explicit SocketAbortChecker(WdtBase *wdtBase, + ThreadTransferHistory &transferHistory) + : AbortChecker(wdtBase), transferHistory_(transferHistory) { + } + + bool shouldAbort() const { + return AbortChecker::shouldAbort() || + transferHistory_.isGlobalCheckpointReceived(); + } + + private: + ThreadTransferHistory &transferHistory_; + }; + + /// Constructor for the sender thread + SenderThread(Sender *sender, int threadIndex, int32_t port, + ThreadsController *threadsController) + : WdtThread(threadIndex, sender->getProtocolVersion(), threadsController), + wdtParent_(sender), + port_(port), + dirQueue_(sender->dirQueue_.get()), + transferHistoryController_(sender->transferHistoryController_.get()) { + controller_->registerThread(threadIndex_); + transferHistoryController_->addThreadHistory(port_, threadStats_); + socketAbortChecker_ = + folly::make_unique(sender, getTransferHistory()); + threadStats_.setId(folly::to(threadIndex_)); + } + + typedef SenderState (SenderThread::*StateFunction)(); + + /// Returns the negotiated protocol + int getNegotiatedProtocol() const override; + + /// Steps to do before calling start + ErrorCode init() override; + + /// Reset the sender thread + void reset() override; + + /// Get the port sender thread is 
connecting to + int getPort() const override; + + /// Destructor of the sender thread + ~SenderThread() { + } + + private: + /// Overloaded operator for printing thread info + friend std::ostream &operator<<(std::ostream &os, + const SenderThread &senderThread); + + /// Parent shared among all the threads for meta information + Sender *wdtParent_; + + /// Special abort checker for the client socket + std::unique_ptr socketAbortChecker_{nullptr}; + + /// The main entry point of the thread + void start() override; + + /// Get the local transfer history + ThreadTransferHistory &getTransferHistory() { + return transferHistoryController_->getTransferHistory(port_); + } + + /** + * tries to connect to the receiver + * Previous states : Almost all states(in case of network errors, all states + * move to this state) + * Next states : SEND_SETTINGS(if there is no previous error) + * READ_LOCAL_CHECKPOINT(if there is previous error) + * END(failed) + */ + SenderState connect(); + /** + * tries to read local checkpoint and return unacked sources to queue. If the + * checkpoint value is -1, then we know previous attempt to send DONE had + * failed. So, we move to READ_RECEIVER_CMD state. + * Previous states : CONNECT + * Next states : CONNECT(read failure), + * END(protocol error or global checkpoint found), + * READ_RECEIVER_CMD(if checkpoint is -1), + * SEND_SETTINGS(success) + */ + SenderState readLocalCheckPoint(); + /** + * sends sender settings to the receiver + * Previous states : READ_LOCAL_CHECKPOINT, + * CONNECT + * Next states : SEND_BLOCKS(success), + * CONNECT(failure) + */ + SenderState sendSettings(); + /** + * sends blocks to receiver till the queue is not empty. After transferring a + * block, we add it to the history. While adding to history, if it is found + * that global checkpoint has been received for this thread, we move to END + * state. 
+ * Previous states : SEND_SETTINGS, + * PROCESS_ERR_CMD + * Next states : SEND_BLOCKS(success), + * END(global checkpoint received), + * CHECK_FOR_ABORT(socket write failure), + * SEND_DONE_CMD(no more blocks left to transfer) + */ + SenderState sendBlocks(); + /** + * sends DONE cmd to the receiver + * Previous states : SEND_BLOCKS + * Next states : CONNECT(failure), + * READ_RECEIVER_CMD(success) + */ + SenderState sendDoneCmd(); + /** + * sends size cmd to the receiver + * Previous states : SEND_BLOCKS + * Next states : CHECK_FOR_ABORT(failure), + * SEND_BLOCKS(success) + */ + SenderState sendSizeCmd(); + /** + * checks to see if the receiver has sent ABORT or not + * Previous states : SEND_BLOCKS, + * SEND_DONE_CMD + * Next states : CONNECT(no ABORT cmd), + * END(protocol error), + * PROCESS_ABORT_CMD(read ABORT cmd) + */ + SenderState checkForAbort(); + /** + * reads previously transferred file chunks list. If it receives an ACK cmd, + * then it moves on. If wait cmd is received, it waits. Otherwise reads the + * file chunks and when done starts directory queue thread. + * Previous states : SEND_SETTINGS, + * Next states: READ_FILE_CHUNKS(if wait cmd is received), + * CHECK_FOR_ABORT(network error), + * END(protocol error), + * SEND_BLOCKS(success) + * + */ + SenderState readFileChunks(); + /** + * reads receiver cmd + * Previous states : SEND_DONE_CMD + * Next states : PROCESS_DONE_CMD, + * PROCESS_WAIT_CMD, + * PROCESS_ERR_CMD, + * END(protocol error), + * CONNECT(failure) + */ + SenderState readReceiverCmd(); + /** + * handles DONE cmd + * Previous states : READ_RECEIVER_CMD + * Next states : END + */ + SenderState processDoneCmd(); + /** + * handles WAIT cmd + * Previous states : READ_RECEIVER_CMD + * Next states : READ_RECEIVER_CMD + */ + SenderState processWaitCmd(); + /** + * reads list of global checkpoints and returns unacked sources to queue. 
+ * Previous states : READ_RECEIVER_CMD + * Next states : CONNECT(socket read failure), + * END(checkpoint list decode failure), + * SEND_BLOCKS(success) + */ + SenderState processErrCmd(); + /** + * processes ABORT cmd + * Previous states : CHECK_FOR_ABORT, + * READ_RECEIVER_CMD + * Next states : END + */ + SenderState processAbortCmd(); + + /** + * waits for all active threads to be aborted, checks to see if the abort was + * due to version mismatch. Also performs various sanity checks. + * Previous states : Almost all states, abort flag is checked between every + * state transition + * Next states : CONNECT(Abort was due to version mismatch), + * END(if abort was not due to version mismatch or some sanity + * check failed) + */ + SenderState processVersionMismatch(); + + /// mapping from sender states to state functions + static const StateFunction stateMap_[]; + + /// Port number of this sender thread + const int32_t port_; + + /// Negotiated protocol of the sender thread + int negotiatedProtocol_{-1}; + + /// Pointer to client socket to maintain connection to the receiver + std::unique_ptr socket_; + + /// Buffer used by the sender thread to read/write data + char buf_[Protocol::kMinBufLength]; + + /// whether total file size has been sent to the receiver + bool totalSizeSent_{false}; + + /// number of consecutive reconnects without any progress + int numReconnectWithoutProgress_{0}; + + /// Pointer to the directory queue of parent sender + DirectorySourceQueue *dirQueue_; + + /// Thread history controller shared across all threads + TransferHistoryController *transferHistoryController_; +}; +} +} diff --git a/ThreadTransferHistory.cpp b/ThreadTransferHistory.cpp new file mode 100644 index 00000000..a69f7fd4 --- /dev/null +++ b/ThreadTransferHistory.cpp @@ -0,0 +1,282 @@ +/** + * Copyright (c) 2014-present, Facebook, Inc. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ +#include "ThreadTransferHistory.h" +#include "Sender.h" +namespace facebook { +namespace wdt { +ThreadTransferHistory::ThreadTransferHistory(DirectorySourceQueue &queue, + TransferStats &threadStats, + int32_t port) + : queue_(queue), threadStats_(threadStats), port_(port) { + VLOG(1) << "Making thread history for port " << port_; +} + +std::string ThreadTransferHistory::getSourceId(int64_t index) { + std::lock_guard lock(mutex_); + std::string sourceId; + const int64_t historySize = history_.size(); + if (index >= 0 && index < historySize) { + sourceId = history_[index]->getIdentifier(); + } else { + LOG(WARNING) << "Trying to read out of bounds data " << index << " " + << history_.size(); + } + return sourceId; +} + +bool ThreadTransferHistory::addSource(std::unique_ptr &source) { + std::lock_guard lock(mutex_); + if (globalCheckpoint_) { + // already received an error for this thread + VLOG(1) << "adding source after global checkpoint is received. 
returning " + "the source to the queue"; + markSourceAsFailed(source, lastCheckpoint_.get()); + lastCheckpoint_.reset(); + queue_.returnToQueue(source); + return false; + } + history_.emplace_back(std::move(source)); + return true; +} + +ErrorCode ThreadTransferHistory::setLocalCheckpoint( + const Checkpoint &checkpoint) { + std::lock_guard lock(mutex_); + return setCheckpointAndReturnToQueue(checkpoint, false); +} + +ErrorCode ThreadTransferHistory::setGlobalCheckpoint( + const Checkpoint &checkpoint) { + std::unique_lock lock(mutex_); + ErrorCode status = setCheckpointAndReturnToQueue(checkpoint, true); + while (inUse_) { + // have to wait, error thread signalled through globalCheckpoint_ flag + LOG(INFO) << "Transfer history still in use, waiting, checkpoint " + << checkpoint; + conditionInUse_.wait(lock); + } + return status; +} +ErrorCode ThreadTransferHistory::setCheckpointAndReturnToQueue( + const Checkpoint &checkpoint, bool globalCheckpoint) { + const int64_t historySize = history_.size(); + int64_t numReceivedSources = checkpoint.numBlocks; + int64_t lastBlockReceivedBytes = checkpoint.lastBlockReceivedBytes; + if (numReceivedSources > historySize) { + LOG(ERROR) + << "checkpoint is greater than total number of sources transferred " + << history_.size() << " " << numReceivedSources; + return INVALID_CHECKPOINT; + } + ErrorCode errCode = validateCheckpoint(checkpoint, globalCheckpoint); + if (errCode == INVALID_CHECKPOINT) { + return INVALID_CHECKPOINT; + } + globalCheckpoint_ |= globalCheckpoint; + lastCheckpoint_ = folly::make_unique(checkpoint); + int64_t numFailedSources = historySize - numReceivedSources; + if (numFailedSources == 0 && lastBlockReceivedBytes > 0) { + if (!globalCheckpoint) { + // no block to apply checkpoint offset. This can happen if we receive same + // local checkpoint without adding anything to the history + LOG(WARNING) << "Local checkpoint has received bytes for last block, but " + "there are no unacked blocks in the history. 
Ignoring."; + } + } + numAcknowledged_ = numReceivedSources; + std::vector> sourcesToReturn; + for (int64_t i = 0; i < numFailedSources; i++) { + std::unique_ptr source = std::move(history_.back()); + history_.pop_back(); + const Checkpoint *checkpointPtr = + (i == numFailedSources - 1 ? &checkpoint : nullptr); + markSourceAsFailed(source, checkpointPtr); + sourcesToReturn.emplace_back(std::move(source)); + } + queue_.returnToQueue(sourcesToReturn); + LOG(INFO) << numFailedSources + << " number of sources returned to queue, checkpoint: " + << checkpoint; + return errCode; +} + +std::vector ThreadTransferHistory::popAckedSourceStats() { + std::unique_lock lock(mutex_); + const int64_t historySize = history_.size(); + WDT_CHECK(numAcknowledged_ == historySize); + // no locking needed, as this should be called after transfer has finished + std::vector sourceStats; + while (!history_.empty()) { + sourceStats.emplace_back(std::move(history_.back()->getTransferStats())); + history_.pop_back(); + } + return sourceStats; +} + +void ThreadTransferHistory::markAllAcknowledged() { + std::unique_lock lock(mutex_); + numAcknowledged_ = history_.size(); +} + +void ThreadTransferHistory::returnUnackedSourcesToQueue() { + std::unique_lock lock(mutex_); + Checkpoint checkpoint; + checkpoint.numBlocks = numAcknowledged_; + setCheckpointAndReturnToQueue(checkpoint, false); +} + +ErrorCode ThreadTransferHistory::validateCheckpoint( + const Checkpoint &checkpoint, bool globalCheckpoint) { + if (lastCheckpoint_ == nullptr) { + return OK; + } + if (checkpoint.numBlocks < lastCheckpoint_->numBlocks) { + LOG(ERROR) << "Current checkpoint must be higher than previous checkpoint, " + "Last checkpoint: " << *lastCheckpoint_ + << ", Current checkpoint: " << checkpoint; + return INVALID_CHECKPOINT; + } + if (checkpoint.numBlocks > lastCheckpoint_->numBlocks) { + return OK; + } + bool noProgress = false; + // numBlocks same + if (checkpoint.lastBlockSeqId == lastCheckpoint_->lastBlockSeqId && + 
checkpoint.lastBlockOffset == lastCheckpoint_->lastBlockOffset) { + // same block + if (checkpoint.lastBlockReceivedBytes != + lastCheckpoint_->lastBlockReceivedBytes) { + LOG(ERROR) << "Current checkpoint has different received bytes, but all " + "other fields are same, Last checkpoint " + << *lastCheckpoint_ << ", Current checkpoint: " << checkpoint; + return INVALID_CHECKPOINT; + } + noProgress = true; + } else { + // different block + WDT_CHECK(checkpoint.lastBlockReceivedBytes >= 0); + if (checkpoint.lastBlockReceivedBytes == 0) { + noProgress = true; + } + } + if (noProgress && !globalCheckpoint) { + // we can get same global checkpoint multiple times, so no need to check for + // progress + LOG(WARNING) << "No progress since last checkpoint, Last checkpoint: " + << *lastCheckpoint_ << ", Current checkpoint: " << checkpoint; + return NO_PROGRESS; + } + return OK; +} + +void ThreadTransferHistory::markSourceAsFailed( + std::unique_ptr &source, const Checkpoint *checkpoint) { + auto metadata = source->getMetaData(); + bool validCheckpoint = false; + if (checkpoint != nullptr) { + if (checkpoint->hasSeqId) { + if ((checkpoint->lastBlockSeqId == metadata.seqId) && + (checkpoint->lastBlockOffset == source->getOffset())) { + validCheckpoint = true; + } else { + LOG(WARNING) + << "Checkpoint block does not match history block. Checkpoint: " + << checkpoint->lastBlockSeqId << ", " << checkpoint->lastBlockOffset + << " History: " << metadata.seqId << ", " << source->getOffset(); + } + } else { + // Receiver at lower version! + // checkpoint does not have seq-id. We have to blindly trust + // lastBlockReceivedBytes. If we do not, transfer will fail because of + // number of bytes mismatch. Even if an error happens because of this, + // Receiver will fail. + validCheckpoint = true; + } + } + int64_t receivedBytes = + (validCheckpoint ? 
checkpoint->lastBlockReceivedBytes : 0); + TransferStats &sourceStats = source->getTransferStats(); + if (sourceStats.getErrorCode() != OK) { + // already marked as failed + sourceStats.addEffectiveBytes(0, receivedBytes); + threadStats_.addEffectiveBytes(0, receivedBytes); + } else { + auto dataBytes = source->getSize(); + auto headerBytes = sourceStats.getEffectiveHeaderBytes(); + int64_t wastedBytes = dataBytes - receivedBytes; + sourceStats.subtractEffectiveBytes(headerBytes, wastedBytes); + sourceStats.decrNumBlocks(); + sourceStats.setErrorCode(SOCKET_WRITE_ERROR); + sourceStats.incrFailedAttempts(); + + threadStats_.subtractEffectiveBytes(headerBytes, wastedBytes); + threadStats_.decrNumBlocks(); + threadStats_.incrFailedAttempts(); + } + source->advanceOffset(receivedBytes); +} + +bool ThreadTransferHistory::isGlobalCheckpointReceived() { + std::lock_guard lock(mutex_); + return globalCheckpoint_; +} + +void ThreadTransferHistory::markNotInUse() { + std::lock_guard lock(mutex_); + inUse_ = false; + conditionInUse_.notify_all(); +} + +TransferHistoryController::TransferHistoryController( + DirectorySourceQueue &dirQueue) + : dirQueue_(dirQueue) { +} + +ThreadTransferHistory &TransferHistoryController::getTransferHistory( + int32_t port) { + auto it = threadHistoriesMap_.find(port); + WDT_CHECK(it != threadHistoriesMap_.end()) << "port not found" << port; + return *(it->second.get()); +} + +void TransferHistoryController::addThreadHistory(int32_t port, + TransferStats &threadStats) { + VLOG(1) << "Adding the history for " << port; + threadHistoriesMap_.emplace(port, folly::make_unique( + dirQueue_, threadStats, port)); +} + +ErrorCode TransferHistoryController::handleVersionMismatch() { + for (auto &historyPair : threadHistoriesMap_) { + auto &history = historyPair.second; + if (history->getNumAcked() > 0) { + LOG(ERROR) << "Even though the transfer aborted due to VERSION_MISMATCH, " + "some blocks got acked by the receiver, port " + << historyPair.first << " 
numAcked " << history->getNumAcked(); + return ERROR; + } + history->returnUnackedSourcesToQueue(); + } + return OK; +} + +void TransferHistoryController::handleGlobalCheckpoint( + const Checkpoint &checkpoint) { + auto errPort = checkpoint.port; + auto it = threadHistoriesMap_.find(errPort); + if (it == threadHistoriesMap_.end()) { + LOG(ERROR) << "Invalid checkpoint " << checkpoint + << ". No sender thread running on port " << errPort; + return; + } + VLOG(1) << "received global checkpoint " << checkpoint; + it->second->setGlobalCheckpoint(checkpoint); +} +} +} diff --git a/ThreadTransferHistory.h b/ThreadTransferHistory.h new file mode 100644 index 00000000..dfc26559 --- /dev/null +++ b/ThreadTransferHistory.h @@ -0,0 +1,168 @@ +/** + * Copyright (c) 2014-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ +#pragma once +#include "DirectorySourceQueue.h" +#include "Reporting.h" +#include "Protocol.h" +#include +#include +namespace facebook { +namespace wdt { +/// Transfer history of a sender thread +class ThreadTransferHistory { + public: + /** + * @param queue directory queue + * @param threadStats stat object of the thread + */ + ThreadTransferHistory(DirectorySourceQueue &queue, TransferStats &threadStats, + int32_t port); + + /** + * @param index of the source + * @return if index is in bounds, returns the identifier for the + * source, else returns empty string + */ + std::string getSourceId(int64_t index); + + /** + * Adds the source to the history. If global checkpoint has already been + * received, then the source is returned to the queue. 
+ * + * @param source source to be added to history + * @return true if added to history, false if not added due to a + * global checkpoint + */ + bool addSource(std::unique_ptr &source); + + /** + * Sets local checkpoint. Also, returns unacked sources to queue + * + * @param checkpoint local checkpoint received + * @return status of the operation + */ + ErrorCode setLocalCheckpoint(const Checkpoint &checkpoint); + + /** + * @return stats for acked sources, must be called after all the + * unacked sources are returned to the queue + */ + std::vector popAckedSourceStats(); + + /// marks all the sources as acked + void markAllAcknowledged(); + + /** + * returns all unacked sources to the queue + */ + void returnUnackedSourcesToQueue(); + + /** + * @return number of sources acked by the receiver + */ + int64_t getNumAcked() const { + return numAcknowledged_; + } + + /// @return whether global checkpoint has been received or not + bool isGlobalCheckpointReceived(); + + /// Clears the inUse_ flag and notifies other waiting threads + void markNotInUse(); + + /// Copy constructor deleted + ThreadTransferHistory(const ThreadTransferHistory &that) = delete; + + /// Copy assignment operator deleted + ThreadTransferHistory &operator=(const ThreadTransferHistory &that) = delete; + + private: + friend class TransferHistoryController; + /// Validates a checkpoint and returns the status + ErrorCode validateCheckpoint(const Checkpoint &checkpoint, + bool globalCheckpoint); + /** + * Sets global checkpoint. If the history is still in use, waits for the + * error thread to add current block to the history. + * + * @param checkpoint checkpoint received + * + * @return status of the operation + */ + ErrorCode setGlobalCheckpoint(const Checkpoint &checkpoint); + + void markSourceAsFailed(std::unique_ptr &source, + const Checkpoint *checkpoint); + /** + * Sets checkpoint. 
Also, returns unacked sources to queue + * + * @param checkpoint checkpoint received + * @param globalCheckpoint global or local checkpoint + * @return status of sources returned to queue + */ + ErrorCode setCheckpointAndReturnToQueue(const Checkpoint &checkpoint, + bool globalCheckpoint); + + /// reference to global queue + DirectorySourceQueue &queue_; + /// reference to thread stats + TransferStats &threadStats_; + /// history of the thread + std::vector> history_; + /// whether a global error checkpoint has been received or not + bool globalCheckpoint_{false}; + /// number of sources acked by the receiver thread + int64_t numAcknowledged_{0}; + /// last received checkpoint + std::unique_ptr lastCheckpoint_{nullptr}; + /// Port associated with the history + int32_t port_; + /// whether the owner thread is still using this + bool inUse_{true}; + /// Mutex used by history internally for synchronization + std::mutex mutex_; + /// Condition variable to signify the history being in use + std::condition_variable conditionInUse_; +}; + +/// Controller for history across the sender threads +class TransferHistoryController { + public: + /** + * Constructor for the history controller + * @param dirQueue Directory queue used by the sender + */ + explicit TransferHistoryController(DirectorySourceQueue &dirQueue); + + /** + * Add transfer history for a thread + * @param port Port being used by the sender thread + * @param threadStats Thread stats for the sender thread + */ + void addThreadHistory(int32_t port, TransferStats &threadStats); + + /// Get transfer history for a thread using its port number + ThreadTransferHistory &getTransferHistory(int32_t port); + + /// Handle version mismatch across the histories of all threads + ErrorCode handleVersionMismatch(); + + /// Handle checkpoint that was sent by the receiver + void handleGlobalCheckpoint(const Checkpoint &checkpoint); + + private: + /// Reference to the directory queue being used by the sender + DirectorySourceQueue 
&dirQueue_; + + /// Map of port (used by sender threads) and transfer history + std::unordered_map> + threadHistoriesMap_; +}; +} +} diff --git a/ThreadsController.cpp b/ThreadsController.cpp new file mode 100644 index 00000000..cca8b81c --- /dev/null +++ b/ThreadsController.cpp @@ -0,0 +1,241 @@ +#include "ThreadsController.h" +#include "WdtOptions.h" +using namespace std; +namespace facebook { +namespace wdt { +void ConditionGuardImpl::wait(int timeoutMillis) { + if (timeoutMillis <= 0) { + cv_.wait(*lock_); + return; + } + auto waitingTime = chrono::milliseconds(timeoutMillis); + cv_.wait_for(*lock_, waitingTime); +} + +void ConditionGuardImpl::notifyAll() { + cv_.notify_all(); +} + +void ConditionGuardImpl::notifyOne() { + cv_.notify_one(); +} + +ConditionGuardImpl::~ConditionGuardImpl() { + if (lock_ != nullptr) { + delete lock_; + } + cv_.notify_one(); +} + +ConditionGuardImpl::ConditionGuardImpl(mutex &guardMutex, + condition_variable &cv) + : cv_(cv) { + lock_ = new unique_lock(guardMutex); +} + +ConditionGuardImpl::ConditionGuardImpl(ConditionGuardImpl &&that) noexcept + : cv_(that.cv_) { + swap(lock_, that.lock_); +} + +ConditionGuardImpl ConditionGuard::acquire() { + return ConditionGuardImpl(mutex_, cv_); +} + +FunnelStatus Funnel::getStatus() { + unique_lock lock(mutex_); + if (status_ == FUNNEL_START) { + status_ = FUNNEL_PROGRESS; + return FUNNEL_START; + } + return status_; +} + +void Funnel::wait() { + unique_lock lock(mutex_); + if (status_ != FUNNEL_PROGRESS) { + return; + } + cv_.wait(lock); +} + +void Funnel::wait(int32_t waitingTime) { + auto waitMillis = chrono::milliseconds(waitingTime); + unique_lock lock(mutex_); + if (status_ != FUNNEL_PROGRESS) { + return; + } + cv_.wait_for(lock, waitMillis); +} + +void Funnel::notifySuccess() { + unique_lock lock(mutex_); + status_ = FUNNEL_END; + cv_.notify_all(); +} + +void Funnel::notifyFail() { + unique_lock lock(mutex_); + status_ = FUNNEL_START; + cv_.notify_one(); +} + +bool 
Barrier::checkForFinish() { + // lock should be held while calling this method + WDT_CHECK_GE(numThreads_, numHits_); + if (numHits_ == numThreads_) { + isComplete_ = true; + cv_.notify_all(); + } + return isComplete_; +} + +void Barrier::execute() { + unique_lock lock(mutex_); + WDT_CHECK(!isComplete_) << "Hitting the barrier after completion"; + ++numHits_; + if (checkForFinish()) { + return; + } + while (!isComplete_) { + cv_.wait(lock); + } +} + +void Barrier::deRegister() { + unique_lock lock(mutex_); + if (isComplete_) { + return; + } + --numThreads_; + checkForFinish(); +} + +ThreadsController::ThreadsController(int totalThreads) { + totalThreads_ = totalThreads; + for (int threadNum = 0; threadNum < totalThreads; ++threadNum) { + threadStateMap_[threadNum] = INIT; + } + execAtStart_.reset(new ExecuteOnceFunc(totalThreads_, true)); + execAtEnd_.reset(new ExecuteOnceFunc(totalThreads_, false)); +} + +void ThreadsController::registerThread(int threadIndex) { + GuardLock lock(controllerMutex_); + auto it = threadStateMap_.find(threadIndex); + WDT_CHECK(it != threadStateMap_.end()); + threadStateMap_[threadIndex] = RUNNING; +} + +void ThreadsController::deRegisterThread(int threadIndex) { + GuardLock lock(controllerMutex_); + auto it = threadStateMap_.find(threadIndex); + WDT_CHECK(it != threadStateMap_.end()); + threadStateMap_[threadIndex] = FINISHED; + // Notify all the barriers + for (auto barrier : barriers_) { + WDT_CHECK(barrier != nullptr); + barrier->deRegister(); + } +} + +ThreadStatus ThreadsController::getState(int threadIndex) { + GuardLock lock(controllerMutex_); + auto it = threadStateMap_.find(threadIndex); + WDT_CHECK(it != threadStateMap_.end()); + return it->second; +} + +void ThreadsController::markState(int threadIndex, ThreadStatus threadState) { + GuardLock lock(controllerMutex_); + threadStateMap_[threadIndex] = threadState; +} + +unordered_map ThreadsController::getThreadStates() const { + GuardLock lock(controllerMutex_); + return 
threadStateMap_; +} + +int ThreadsController::getTotalThreads() { + return totalThreads_; +} + +bool ThreadsController::hasThreads(int threadIndex, ThreadStatus threadState) { + GuardLock lock(controllerMutex_); + for (auto &threadPair : threadStateMap_) { + if (threadPair.first == threadIndex) { + continue; + } + if (threadPair.second == threadState) { + return true; + } + } + return false; +} + +shared_ptr ThreadsController::getCondition( + const uint64_t conditionIndex) { + bool isExists = (conditionGuards_.size() > conditionIndex) && + (conditionGuards_[conditionIndex] != nullptr); + WDT_CHECK(isExists) + << "Requesting for a condition wrapper that doesn't exist." + << " Request Index : " << conditionIndex + << ", num condition wrappers : " << conditionGuards_.size(); + return conditionGuards_[conditionIndex]; +} + +shared_ptr ThreadsController::getBarrier(const uint64_t barrierIndex) { + bool isExists = + (barriers_.size() > barrierIndex) && (barriers_[barrierIndex] != nullptr); + WDT_CHECK(isExists) + << "Requesting for a barrier that doesn't exist. Request index : " + << barrierIndex << ", num barriers " << barriers_.size(); + return barriers_[barrierIndex]; +} + +shared_ptr ThreadsController::getFunnel(const uint64_t funnelIndex) { + bool isExists = (funnelExecutors_.size() > funnelIndex) && + (funnelExecutors_[funnelIndex] != nullptr); + WDT_CHECK(isExists) + << "Requesting for a funnel that doesn't exist. 
Request index : " + << funnelIndex << ", num funnels " << funnelExecutors_.size(); + return funnelExecutors_[funnelIndex]; +} + +void ThreadsController::reset() { + // Only used in the case of long running mode + setNumBarriers(barriers_.size()); + setNumConditions(conditionGuards_.size()); + setNumFunnels(funnelExecutors_.size()); + execAtStart_->reset(); + execAtEnd_->reset(); + GuardLock lock(controllerMutex_); + // Restore threads back to initial state + for (auto &threadPair : threadStateMap_) { + threadPair.second = RUNNING; + } +} + +void ThreadsController::setNumBarriers(int numBarriers) { + // Meant to be called outside of threads + barriers_.clear(); + for (int i = 0; i < numBarriers; i++) { + barriers_.push_back(make_shared(getTotalThreads())); + } +} + +void ThreadsController::setNumConditions(int numConditions) { + conditionGuards_.clear(); + for (int i = 0; i < numConditions; i++) { + conditionGuards_.push_back(make_shared()); + } +} + +void ThreadsController::setNumFunnels(int numFunnels) { + funnelExecutors_.clear(); + for (int i = 0; i < numFunnels; i++) { + funnelExecutors_.push_back(make_shared()); + } +} +} +} diff --git a/ThreadsController.h b/ThreadsController.h new file mode 100644 index 00000000..927dd8dc --- /dev/null +++ b/ThreadsController.h @@ -0,0 +1,371 @@ +/** + * Copyright (c) 2014-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ +#pragma once +#include +#include +#include +#include +#include +#include "WdtThread.h" +#include "ErrorCodes.h" + +namespace facebook { +namespace wdt { + +class WdtThread; +/** + * Thread states that represent what kind of functionality + * are they executing on a higher level. 
+ * INIT - State before running at the time of construction + * RUNNING - Thread is running without any errors + * WAITING - Thread is not doing anything meaningful but + * rather waiting on other threads for something + * FINISHED - Threads have finished with/without error + */ +enum ThreadStatus { INIT, RUNNING, WAITING, FINISHED }; + +/** + * Primitive that takes a function and executes + * it only once either on the first thread + * or the last thread entrance + */ +class ExecuteOnceFunc { + public: + /// Constructor for the once only executor + ExecuteOnceFunc(int numThreads, bool execFirst) { + execFirst_ = execFirst; + numThreads_ = numThreads; + } + + /// Deleted copy constructor + ExecuteOnceFunc(const ExecuteOnceFunc &that) = delete; + + /// Deleted assignment operator + ExecuteOnceFunc &operator=(const ExecuteOnceFunc &that) = delete; + + /// Implements the main functionality of the executor + template + void execute(Func &&execFunc) { + std::unique_lock lock(mutex_); + ++numHits_; + WDT_CHECK(numHits_ <= numThreads_); + int64_t numExpected = (execFirst_) ? 1 : numThreads_; + if (numHits_ == numExpected) { + execFunc(); + } + } + + /// Reset the number of hits + void reset() { + numHits_ = 0; + } + + private: + /// Mutex for thread synchronization + std::mutex mutex_; + /// Number of times execute has been called + int numHits_{0}; + /// Function can be executed on the first + /// thread or the last thread + bool execFirst_{true}; + /// Number of total threads + int numThreads_; +}; + +/** + * A scoped locking primitive. When you get this object + * it means that you already have the lock. 
You can also + * wait, notify etc using this primitive + */ +class ConditionGuardImpl { + public: + /// Release the lock and wait for the timeout + /// After the wait is over, lock is reacquired + void wait(int timeoutMillis = -1); + /// Notify all the threads waiting on the lock + void notifyAll(); + /// Notify one thread waiting on the lock + void notifyOne(); + /// Delete the copy constructor + ConditionGuardImpl(const ConditionGuardImpl &that) = delete; + /// Delete the copy assignment operator + ConditionGuardImpl &operator=(const ConditionGuardImpl &that) = delete; + /// Move constructor for the guard + ConditionGuardImpl(ConditionGuardImpl &&that) noexcept; + /// Move assignment operator deleted + ConditionGuardImpl &operator=(ConditionGuardImpl &&that) = delete; + /// Destructor that releases the lock, you would explicitly + /// need to notify any other threads waiting in the wait() + ~ConditionGuardImpl(); + + protected: + friend class ConditionGuard; + /// Constructor that takes the shared mutex and condition + /// variable + ConditionGuardImpl(std::mutex &mutex, std::condition_variable &cv); + /// Instance of lock is made on construction with the specified mutex + std::unique_lock *lock_{nullptr}; + /// Shared condition variable + std::condition_variable &cv_; +}; + +/** + * Class for simplifying the primitive to take a lock + * in conjunction with the ability to do things + * on a condition variable based on the lock. 
+ * Use the condition guard like this + * ConditionGuard condition; + * auto guard = condition.acquire(); + * guard.wait(); + */ +class ConditionGuard { + public: + /// Caller has to call acquire before doing anything + ConditionGuardImpl acquire(); + + /// Default constructor + ConditionGuard() { + } + + /// Deleted copy constructor + ConditionGuard(const ConditionGuard &that) = delete; + + /// Deleted assignment operator + ConditionGuard &operator=(const ConditionGuard &that) = delete; + + private: + /// Mutex for the condition variable + std::mutex mutex_; + /// std condition variable to support the functionality + std::condition_variable cv_; +}; + +/** + * A barrier primitive. When called for executing + * will block the threads till all the threads registered + * call execute() + */ +class Barrier { + public: + /// Deleted copy constructor + Barrier(const Barrier &that) = delete; + + /// Deleted assignment operator + Barrier &operator=(const Barrier &that) = delete; + + /// Constructor which takes total number of threads + /// to be hit in order for the barrier to clear + explicit Barrier(int numThreads) { + numThreads_ = numThreads; + VLOG(1) << "making barrier with " << numThreads; + } + + /// Executes the main functionality of the barrier + void execute(); + + /** + * Thread controller should call this method when one thread + * has been finished, since that thread will no longer be + * participating in the barrier + */ + void deRegister(); + + private: + /// Checks for finish, need to hold a lock to call this method + bool checkForFinish(); + /// Condition variable that threads wait on + std::condition_variable cv_; + + /// Number of threads entered the execute + int64_t numHits_{0}; + + /// Total number of threads that are supposed + /// to hit the barrier + int numThreads_{0}; + + /// Thread synchronization mutex + std::mutex mutex_; + + /// Represents the completion of barrier + bool isComplete_{false}; +}; + +/** + * Different stages of the simple 
funnel + * FUNNEL_START the state of funnel at the beginning + * FUNNEL_PROGRESS is set by the first thread to enter the funnel + * and it means that funnel functionality is in progress + * FUNNEL_END means that funnel functionality has been executed + */ +enum FunnelStatus { FUNNEL_START, FUNNEL_PROGRESS, FUNNEL_END }; + +/** + * Primitive that makes the threads execute in a funnel + * manner. Only one thread gets to execute the main functionality + * while other entering threads wait (while executing a function) + */ +class Funnel { + public: + /// Deleted copy constructor + Funnel(const Funnel &that) = delete; + + /// Default constructor for funnel + Funnel() { + status_ = FUNNEL_START; + } + + /// Deleted assignment operator + Funnel &operator=(const Funnel &that) = delete; + + /** + * Get the current status of funnel. + * If the status is FUNNEL_START it gets set + * to FUNNEL_PROGRESS else it is just a get + */ + FunnelStatus getStatus(); + + /// Threads in progress can wait indefinitely + void wait(); + + /// Threads that get status as progress execute this function + void wait(int32_t waitingTime); + + /** + * The first thread that was able to start the funnel + * calls this method on successful execution + */ + void notifySuccess(); + + /// The first thread that was able to start the funnel + /// calls this method on failure in execution + void notifyFail(); + + private: + /// Status of the funnel + FunnelStatus status_; + /// Mutex for the simple funnel executor + std::mutex mutex_; + /// Condition variable on which progressing threads wait + std::condition_variable cv_; +}; + +/** + * Controller class responsible for the receiver + * and sender threads. 
Manages the states of threads and + * session information + */ +class ThreadsController { + public: + /// Constructor that takes in the total number of threads + /// to be run + explicit ThreadsController(int totalThreads); + + /// Make threads of a type Sender/Receiver + template + std::vector> makeThreads( + WdtBaseType *wdtParent, int numThreads, + const std::vector &ports) { + std::vector> threads; + for (int threadIndex = 0; threadIndex < numThreads; ++threadIndex) { + threads.emplace_back(folly::make_unique( + wdtParent, threadIndex, ports[threadIndex], this)); + } + return threads; + } + /// Mark the state of a thread + void markState(int threadIndex, ThreadStatus state); + + /// Get the status of the thread by index + ThreadStatus getState(int threadIndex); + + /// Execute a function func once, by the first thread + template + void executeAtStart(FunctionType &&fn) const { + execAtStart_->execute(fn); + } + + /// Execute a function once by the last thread + template + void executeAtEnd(FunctionType &&fn) const { + execAtEnd_->execute(fn); + } + + /// Returns a funnel executor shared between the threads + /// If the executor does not exist then it creates one + std::shared_ptr getFunnel(const uint64_t funnelIndex); + + /// Returns a barrier shared between the threads + /// If the executor does not exist then it creates one + std::shared_ptr getBarrier(const uint64_t barrierIndex); + + /// Get the condition variable wrapper + std::shared_ptr getCondition(const uint64_t conditionIndex); + + /* + * Returns the states of all the threads + */ + std::unordered_map getThreadStates() const; + + /// Register a thread, a thread registers with the state RUNNING + void registerThread(int threadIndex); + + /// De-register a thread, marks it ended + void deRegisterThread(int threadIndex); + + /// Returns true if any thread apart from the calling thread is in the state + bool hasThreads(int threadIndex, ThreadStatus threadState); + + /// Get the number of registered threads + 
int getTotalThreads(); + + /// Reset the thread controller so that same instance can be used again + void reset(); + + /// Set the total number of barriers + void setNumBarriers(int numBarriers); + + /// Set the number of condition wrappers + void setNumConditions(int numConditions); + + /// Set total number of funnel executors + void setNumFunnels(int numFunnels); + + /// Destructor for the threads controller + ~ThreadsController() { + } + + private: + /// Total number of threads managed by the thread controller + int totalThreads_; + + typedef std::unique_lock GuardLock; + + /// Mutex used in all of the thread controller methods + mutable std::mutex controllerMutex_; + + /// States of the threads + std::unordered_map threadStateMap_; + + /// Executor to execute things at the start of transfer + std::unique_ptr execAtStart_; + + /// Executor to execute things at the end of transfer + std::unique_ptr execAtEnd_; + + /// Vector of funnel executors, read/modified by get/set funnel methods + std::vector> funnelExecutors_; + + /// Vector of condition wrappers, read/modified by get/set condition methods + std::vector> conditionGuards_; + + /// Vector of barriers, can be read/modified by get/set barrier methods + std::vector> barriers_; +}; +} +} diff --git a/ThreadsControllerTest.cpp b/ThreadsControllerTest.cpp new file mode 100644 index 00000000..c9ad55e4 --- /dev/null +++ b/ThreadsControllerTest.cpp @@ -0,0 +1,122 @@ +/** + * Copyright (c) 2014-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. 
+ */ +#include "ThreadsController.h" +#include +#include +#include +#include +#include +#include +#include +using namespace std; +namespace facebook { +namespace wdt { +class ThreadUtil { + public: + template + void addThread(Fn&& func) { + threads_.emplace_back(func); + } + void joinThreads() { + for (auto& t : threads_) { + t.join(); + } + threads_.clear(); + } + + void threadSleep(int64_t millis) { + auto waitingTime = chrono::milliseconds(millis); + LOG(INFO) << "Sleeping for " << millis << " ms"; + unique_lock lock(mutex_); + cv_.wait_for(lock, waitingTime); + } + + void notifyThreads() { + cv_.notify_all(); + } + + ~ThreadUtil() { + joinThreads(); + } + + private: + vector threads_; + mutex mutex_; + condition_variable cv_; +}; + +TEST(ThreadsController, Barrier) { + int numThreads = 15; + { + Barrier barrier(numThreads); + ThreadUtil threadUtil; + for (int i = 0; i < numThreads; i++) { + threadUtil.addThread([&]() { barrier.execute(); }); + } + } + { + Barrier barrier(numThreads); + srand(time(nullptr)); + ThreadUtil threadUtil; + for (int i = 0; i < numThreads; i++) { + threadUtil.addThread([&barrier, &threadUtil, i]() { + if (i % 2 == 0) { + int seconds = folly::Random::rand32() % 5; + threadUtil.threadSleep(seconds * 100); + barrier.deRegister(); + return; + } + barrier.execute(); + }); + } + } + { + Barrier barrier(numThreads); + ThreadUtil threadUtil; + for (int i = 0; i < numThreads; i++) { + threadUtil.addThread([&barrier, &threadUtil, i]() { + switch (i) { + case 1: + threadUtil.threadSleep(1 * 100); + barrier.execute(); + return; + case 2: + threadUtil.threadSleep(2 * 100); + barrier.deRegister(); + return; + case 3: + threadUtil.threadSleep(5 * 100); + barrier.execute(); + return; + default: + barrier.execute(); + return; + }; + }); + } + } +} + +TEST(ThreadsController, ExecutOnceFunc) { + int numThreads = 8; + ExecuteOnceFunc execAtStart(numThreads, true); + ExecuteOnceFunc execAtEnd(numThreads, false); + ThreadUtil threadUtil; + int result = 0; + 
for (int i = 0; i < numThreads; i++) { + threadUtil.addThread([&]() { + execAtStart.execute([&]() { ++result; }); + execAtEnd.execute([&]() { ++result; }); + }); + } + threadUtil.joinThreads(); + EXPECT_EQ(result, 2); +} +} +} diff --git a/TransferLogManager.cpp b/TransferLogManager.cpp index f61dd068..dec58236 100644 --- a/TransferLogManager.cpp +++ b/TransferLogManager.cpp @@ -256,7 +256,6 @@ void TransferLogManager::openLog() { std::move(std::thread(&TransferLogManager::writeEntriesToDisk, this)); LOG(INFO) << "Log writer thread started"; } - return; } void TransferLogManager::closeLog() { @@ -326,7 +325,8 @@ bool TransferLogManager::verifySenderIp(const std::string &curSenderIp) { bool verifySuccessful = true; if (!options.disable_sender_verification_during_resumption) { if (senderIp_.empty()) { - LOG(INFO) << "Sender-ip empty, not verifying sender-ip"; + LOG(INFO) << "Sender-ip empty, not verifying sender-ip, new-ip: " + << curSenderIp; } else if (senderIp_ != curSenderIp) { LOG(ERROR) << "Current sender ip does not match ip in the " "transfer log " << curSenderIp << " " << senderIp_ @@ -579,6 +579,7 @@ ErrorCode LogParser::processHeaderEntry(char *buf, int size, int64_t logConfig; if (!encoderDecoder_.decodeLogHeader(buf, size, timestamp, logVersion, logRecoveryId, senderIp, logConfig)) { + LOG(ERROR) << "Couldn't decode the log header"; return INVALID_LOG; } if (logVersion != TransferLogManager::LOG_VERSION) { @@ -894,6 +895,7 @@ ErrorCode LogParser::parseLog(int fd, std::string &senderIp, return INVALID_LOG; } if (status == INVALID_LOG) { + LOG(ERROR) << "Invalid transfer log"; return status; } if (status == INCONSISTENT_DIRECTORY) { diff --git a/WdtBase.cpp b/WdtBase.cpp index 6a953502..6e595de9 100644 --- a/WdtBase.cpp +++ b/WdtBase.cpp @@ -351,6 +351,7 @@ WdtBase::WdtBase() : abortCheckerCallback_(this) { WdtBase::~WdtBase() { abortChecker_ = nullptr; + delete threadsController_; } std::vector WdtBase::genPortsVector(int32_t startPort, @@ -407,6 +408,10 
@@ void WdtBase::setThrottler(std::shared_ptr throttler) { throttler_ = throttler; } +std::shared_ptr WdtBase::getThrottler() const { + return throttler_; +} + void WdtBase::setTransferId(const std::string& transferId) { transferId_ = transferId; LOG(INFO) << "Setting transfer id " << transferId_; @@ -423,6 +428,10 @@ void WdtBase::setProtocolVersion(int64_t protocol) { LOG(INFO) << "using wdt protocol version " << protocolVersion_; } +int WdtBase::getProtocolVersion() const { + return protocolVersion_; +} + std::string WdtBase::getTransferId() { return transferId_; } diff --git a/WdtBase.h b/WdtBase.h index a1916c19..3b114c9b 100644 --- a/WdtBase.h +++ b/WdtBase.h @@ -15,6 +15,7 @@ #include "Reporting.h" #include "Throttler.h" #include "Protocol.h" +#include "WdtThread.h" #include #include #include @@ -227,6 +228,9 @@ class WdtBase { /// Sets the protocol version for the transfer void setProtocolVersion(int64_t protocolVersion); + /// Get the protocol version of the transfer + int getProtocolVersion() const; + /// Get the transfer id of the object std::string getTransferId(); @@ -243,19 +247,8 @@ class WdtBase { /// Utility to generate a random transfer id static std::string generateTransferId(); - protected: - /// Global throttler across all threads - std::shared_ptr throttler_; - - /// Holds the instance of the progress reporter default or customized - std::unique_ptr progressReporter_; - - /// Unique id for the transfer - std::string transferId_; - - /// protocol version to use, this is determined by negotiating protocol - /// version with the other side - int protocolVersion_{Protocol::protocol_version}; + /// Get the throttler + std::shared_ptr getThrottler() const; /// abort checker class passed to socket functions class AbortChecker : public IAbortChecker { @@ -271,9 +264,29 @@ class WdtBase { WdtBase* wdtBase_; }; + protected: + /// Ports that the sender/receiver is running on + std::vector ports_; + + /// Global throttler across all threads + 
std::shared_ptr throttler_; + + /// Holds the instance of the progress reporter default or customized + std::unique_ptr progressReporter_; + + /// Unique id for the transfer + std::string transferId_; + + /// protocol version to use, this is determined by negotiating protocol + /// version with the other side + int protocolVersion_{Protocol::protocol_version}; + /// abort checker passed to socket functions AbortChecker abortCheckerCallback_; + /// Controller for wdt threads shared between base and threads + ThreadsController* threadsController_{nullptr}; + private: folly::RWSpinLock abortCodeLock_; /// Internal and default abort code diff --git a/WdtConfig.h b/WdtConfig.h index 281ef12a..c0610382 100644 --- a/WdtConfig.h +++ b/WdtConfig.h @@ -7,10 +7,10 @@ #include #define WDT_VERSION_MAJOR 1 -#define WDT_VERSION_MINOR 21 -#define WDT_VERSION_BUILD 1510120 +#define WDT_VERSION_MINOR 22 +#define WDT_VERSION_BUILD 1510210 // Add -fbcode to version str -#define WDT_VERSION_STR "1.21.1510120-fbcode" +#define WDT_VERSION_STR "1.22.1510210-fbcode" // Tie minor and proto version #define WDT_PROTOCOL_VERSION WDT_VERSION_MINOR diff --git a/WdtThread.cpp b/WdtThread.cpp new file mode 100644 index 00000000..c1f0a523 --- /dev/null +++ b/WdtThread.cpp @@ -0,0 +1,48 @@ +#include "WdtThread.h" +using namespace std; +namespace facebook { +namespace wdt { +TransferStats WdtThread::moveStats() { + return std::move(threadStats_); +} + +const PerfStatReport& WdtThread::getPerfReport() const { + return perfReport_; +} + +const TransferStats& WdtThread::getTransferStats() const { + return threadStats_; +} + +void WdtThread::startThread() { + if (threadPtr_) { + WDT_CHECK(false) << "There is a already a thread running " << threadIndex_ + << " " << getPort(); + } + auto state = controller_->getState(threadIndex_); + // Check the state should be running here + WDT_CHECK_EQ(state, RUNNING); + threadPtr_.reset(new std::thread(&WdtThread::start, this)); +} + +ErrorCode WdtThread::finish() { + 
if (!threadPtr_) { + LOG(ERROR) << "Finish called on an instance while no thread has been " + << " created to do any work"; + return ERROR; + } + threadPtr_->join(); + threadPtr_.reset(); + return OK; +} + +WdtThread::~WdtThread() { + if (threadPtr_) { + LOG(INFO) << threadIndex_ + << " has an alive thread while the instance is being " + << "destructed"; + finish(); + } +} +} +} diff --git a/WdtThread.h b/WdtThread.h new file mode 100644 index 00000000..4b1b3774 --- /dev/null +++ b/WdtThread.h @@ -0,0 +1,79 @@ +/** + * Copyright (c) 2014-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ +#pragma once +#include "Reporting.h" +#include "ErrorCodes.h" +#include "ThreadsController.h" +#include +#include +namespace facebook { +namespace wdt { +class ThreadsController; +class WdtThread { + public: + /// Constructor for wdt thread + WdtThread(int threadIndex, int protocolVersion, ThreadsController *controller) + : threadIndex_(threadIndex), threadProtocolVersion_(protocolVersion) { + controller_ = controller; + } + /// Starts a thread which runs the wdt functionality + void startThread(); + + /// Get the perf stats of the transfer for this thread + const PerfStatReport &getPerfReport() const; + + /// Initializes the wdt thread before starting + virtual ErrorCode init() = 0; + + /// Conclude the thread transfer + virtual ErrorCode finish(); + + /// Moves the local stats into a new instance + TransferStats moveStats(); + + /// Get the transfer stats recorded by this thread + const TransferStats &getTransferStats() const; + + /// Reset the wdt thread + virtual void reset() = 0; + + /// Get the port this thread is running on + virtual int getPort() const = 0; + + // TODO remove this function + virtual int getNegotiatedProtocol() const { + 
return threadProtocolVersion_; + } + + virtual ~WdtThread(); + + protected: + /// The main entry point of the thread + virtual void start() = 0; + + /// Index of this thread with respect to other threads + const int threadIndex_; + + /// Copy of the protocol version that might be changed + int threadProtocolVersion_; + + /// Transfer stats for this thread + TransferStats threadStats_{true}; + + /// Perf stats report for this thread + PerfStatReport perfReport_; + + /// Thread controller for all the sender threads + ThreadsController *controller_{nullptr}; + + /// Pointer to the std::thread executing the transfer + std::unique_ptr threadPtr_{nullptr}; +}; +} +} diff --git a/wdt_global_checkpoint_test.sh b/wdt_global_checkpoint_test.sh new file mode 100644 index 00000000..36b191f8 --- /dev/null +++ b/wdt_global_checkpoint_test.sh @@ -0,0 +1,38 @@ +#! /bin/bash + +set -o pipefail + +source `dirname "$0"`/common_functions.sh + +BASEDIR=/dev/shm/wdtTest_$USER +mkdir -p "$BASEDIR" +DIR=`mktemp -d --tmpdir=$BASEDIR` +echo "Testing in $DIR" + +mkdir "$DIR/src" +# create a big file +dd if=/dev/zero of="$DIR/src/file" bs=536870912 count=1 + +TEST_COUNT=0 +# start the server +_bin/wdt/wdt -skip_writes -num_ports=2 -transfer_id=wdt \ +-connect_timeout_millis 100 -read_timeout_millis=200 > "$DIR/server${TEST_COUNT}.log" 2>&1 & +pidofreceiver=$! + +# block 22356 for small duration so that file is transferred through 22357 +blockDportByDropping 22356 +# start client +_bin/wdt/wdt -destination localhost -directory "$DIR/src" -num_ports=2 \ +-block_size_mbytes=-1 -avg_mbytes_per_sec=100 -transfer_id=wdt 2>&1 | \ +tee "$DIR/client${TEST_COUNT}.log" & +pidofsender=$! +sleep 0.1 +undoLastIpTableChange +sleep 5 +# block 22357 in the middle +blockDportByDropping 22357 +waitForTransferEnd + +echo "Test passed, deleting directory $DIR" +rm -rf "$DIR" +wdtExit 0