diff --git a/CMakeLists.txt b/CMakeLists.txt index e6467528..a4c1eecd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,7 @@ cmake_minimum_required(VERSION 3.2) # There is no C per se in WDT but if you use CXX only here many checks fail # Version is Major.Minor.YYMMDDX for up to 10 releases per day # Minor currently is also the protocol version - has to match with Protocol.cpp -project("WDT" LANGUAGES C CXX VERSION 1.21.1510120) +project("WDT" LANGUAGES C CXX VERSION 1.22.1510210) # On MacOS this requires the latest (master) CMake (and/or CMake 3.1.1/3.2) set(CMAKE_CXX_STANDARD 11) @@ -79,8 +79,13 @@ ErrorCodes.cpp FileByteSource.cpp FileCreator.cpp Protocol.cpp +WdtThread.cpp +ThreadsController.cpp +ReceiverThread.cpp Receiver.cpp Reporting.cpp +ThreadTransferHistory.cpp +SenderThread.cpp Sender.cpp ServerSocket.cpp SocketUtils.cpp diff --git a/ErrorCodes.h b/ErrorCodes.h index 1eef6654..7d7fd90a 100644 --- a/ErrorCodes.h +++ b/ErrorCodes.h @@ -9,7 +9,7 @@ #pragma once #include - +#include #include namespace facebook { diff --git a/FileCreator.cpp b/FileCreator.cpp index 2495335e..725d3a66 100644 --- a/FileCreator.cpp +++ b/FileCreator.cpp @@ -164,8 +164,7 @@ int FileCreator::openExistingFile(const string &relPathStr) { WDT_CHECK(relPathStr[0] != '/'); WDT_CHECK(relPathStr.back() != '/'); - string path(rootDir_); - path.append(relPathStr); + const string path = rootDir_ + relPathStr; int openFlags = O_WRONLY; START_PERF_TIMER @@ -184,8 +183,7 @@ int FileCreator::createFile(const string &relPathStr) { CHECK(relPathStr[0] != '/'); CHECK(relPathStr.back() != '/'); - std::string path(rootDir_); - path.append(relPathStr); + const string path = rootDir_ + relPathStr; int p = relPathStr.size(); while (p && relPathStr[p - 1] != '/') { diff --git a/Protocol.cpp b/Protocol.cpp index 018e60b9..f25a0a0d 100644 --- a/Protocol.cpp +++ b/Protocol.cpp @@ -49,7 +49,8 @@ int Protocol::negotiateProtocol(int requestedProtocolVersion, } std::ostream 
&operator<<(std::ostream &os, const Checkpoint &checkpoint) { - os << "num-blocks: " << checkpoint.numBlocks + os << "checkpoint-port: " << checkpoint.port + << " num-blocks: " << checkpoint.numBlocks << " seq-id: " << checkpoint.lastBlockSeqId << " block-offset: " << checkpoint.lastBlockOffset << " received-bytes: " << checkpoint.lastBlockReceivedBytes; diff --git a/README.md b/README.md index 0352759f..5de55502 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,10 @@ caller get the mutable object of options and set different options accordingly. When wdt is run in a standalone mode, behavior is changed through gflags in wdtCmdLine.cpp +* WdtThread.{h|cpp} +Common functionality and settings between SenderThread and ReceiverThread. +Both of these kinds of threads inherit from this base class. + * WdtBase.{h|cpp} Common functionality and settings between Sender and Receiver @@ -156,11 +160,20 @@ directory, sorted by decreasing size (as they are discovered, you can start pulling from the queue even before all the files are found, it will return the current largest file) +* ThreadTransferHistory.{h|cpp} -* Sender.{h|cpp} +Every thread maintains a transfer history so that when a connection breaks +it can talk to the receiver to find out up to where in the history has been +sent. This class encapsulates all the logic for that bookkeeping + +* SenderThread.{h|cpp} + +Implements the functionality of one sender thread, which binds to a certain port +and sends files over. 
-Formerly wdtlib.cpp - main code sending files +* Sender.{h|cpp} +Spawns multiple SenderThread threads and sends the data across to receiver ### Consuming / Receiving @@ -168,10 +181,15 @@ Formerly wdtlib.cpp - main code sending files Creates file and directories necessary for said file (mkdir -p like) -* Receiver.{h|cpp} +* ReceiverThread.{h|cpp} -Formerly wdtlib.cpp - main code receiving files +Implements the functionality of the receiver threads, responsible for listening on +a port and receiving files over the network. + +* Receiver.{h|cpp} +Parent receiver class that spawns multiple ReceiverThread threads and receives +data from a remote host ### Low level building blocks diff --git a/Receiver.cpp b/Receiver.cpp index 0934e186..48cf5196 100644 --- a/Receiver.cpp +++ b/Receiver.cpp @@ -10,7 +10,6 @@ #include "ServerSocket.h" #include "FileWriter.h" #include "SocketUtils.h" -#include "DirectorySourceQueue.h" #include #include @@ -29,82 +28,10 @@ using std::vector; namespace facebook { namespace wdt { - -const static int kTimeoutBufferMillis = 1000; -const static int kWaitTimeoutFactor = 5; - -std::ostream &operator<<(std::ostream &os, const Receiver::ThreadData &data) { - os << "Thread[" << data.threadIndex_ << ", port: " << data.socket_.getPort() - << "] "; - return os; -} - -int64_t readAtLeast(ServerSocket &s, char *buf, int64_t max, int64_t atLeast, - int64_t len) { - VLOG(4) << "readAtLeast len " << len << " max " << max << " atLeast " - << atLeast << " from " << s.getFd(); - CHECK_GE(len, 0); - CHECK_GT(atLeast, 0); - CHECK_LE(atLeast, max); - int count = 0; - while (len < atLeast) { - // because we want to process data as soon as it arrives, tryFull option for - // read is false - int64_t n = s.read(buf + len, max - len, false); - if (n < 0) { - PLOG(ERROR) << "Read error on " << s.getPort() << " after " << count; - if (len) { - return len; - } else { - return n; - } - } - if (n == 0) { - VLOG(2) << "Eof on " << s.getPort() << " after " << count << " reads 
" - << "got " << len; - return len; - } - len += n; - count++; - } - VLOG(3) << "Took " << count << " reads to get " << len - << " from fd : " << s.getFd(); - return len; -} - -int64_t readAtMost(ServerSocket &s, char *buf, int64_t max, int64_t atMost) { - const int64_t target = atMost < max ? atMost : max; - VLOG(3) << "readAtMost target " << target; - // because we want to process data as soon as it arrives, tryFull option for - // read is false - int64_t n = s.read(buf, target, false); - if (n < 0) { - PLOG(ERROR) << "Read error on " << s.getPort() << " with target " << target; - return n; - } - if (n == 0) { - LOG(WARNING) << "Eof on " << s.getFd(); - return n; - } - VLOG(3) << "readAtMost " << n << " / " << atMost << " from " << s.getFd(); - return n; -} - -const Receiver::StateFunction Receiver::stateMap_[] = { - &Receiver::listen, &Receiver::acceptFirstConnection, - &Receiver::acceptWithTimeout, &Receiver::sendLocalCheckpoint, - &Receiver::readNextCmd, &Receiver::processFileCmd, - &Receiver::processSettingsCmd, &Receiver::processDoneCmd, - &Receiver::processSizeCmd, &Receiver::sendFileChunks, - &Receiver::sendGlobalCheckpoint, &Receiver::sendDoneCmd, - &Receiver::sendAbortCmd, &Receiver::waitForFinishOrNewCheckpoint, - &Receiver::waitForFinishWithThreadError}; - void Receiver::addCheckpoint(Checkpoint checkpoint) { LOG(INFO) << "Adding global checkpoint " << checkpoint.port << " " << checkpoint.numBlocks << " " << checkpoint.lastBlockReceivedBytes; checkpoints_.emplace_back(checkpoint); - conditionAllFinished_.notify_all(); } std::vector Receiver::getNewCheckpoints(int startIndex) { @@ -126,11 +53,7 @@ Receiver::Receiver(const WdtTransferRequest &transferRequest) { } setProtocolVersion(transferRequest.protocolVersion); setDir(transferRequest.directory); - const auto &options = WdtOptions::get(); - for (int32_t portNum : transferRequest.ports) { - threadServerSockets_.emplace_back(portNum, options.backlog, - &abortCheckerCallback_); - } + ports_ = 
transferRequest.ports; } Receiver::Receiver(int port, int numSockets, const std::string &destDir) @@ -151,32 +74,78 @@ void Receiver::traverseDestinationDir( return; } -WdtTransferRequest Receiver::init() { - vector successfulSockets; - for (size_t i = 0; i < threadServerSockets_.size(); i++) { - ServerSocket socket = std::move(threadServerSockets_[i]); - int max_retries = WdtOptions::get().max_retries; - for (int retries = 0; retries < max_retries; retries++) { - if (socket.listen() == OK) { - break; - } +void Receiver::startNewGlobalSession(const std::string &peerIp) { + if (throttler_) { + // If throttler is configured/set then register this session + // in the throttler. This is guaranteed to work in either of the + // modes long running or not. We will de register from the throttler + // when the current session ends + throttler_->registerTransfer(); + } + startTime_ = Clock::now(); + const auto &options = WdtOptions::get(); + if (options.enable_download_resumption) { + bool verifySuccessful = transferLogManager_.verifySenderIp(peerIp); + if (!verifySuccessful) { + fileChunksInfo_.clear(); } - if (socket.listen() == OK) { - successfulSockets.push_back(std::move(socket)); - } else { - LOG(ERROR) << "Couldn't listen on port " << socket.getPort(); + } + hasNewTransferStarted_.store(true); + LOG(INFO) << "Starting new transfer, peerIp " << peerIp << " , transfer id " + << transferId_; +} + +bool Receiver::hasNewTransferStarted() const { + return hasNewTransferStarted_.load(); +} + +void Receiver::endCurGlobalSession() { + LOG(INFO) << "Ending the transfer " << transferId_; + if (throttler_) { + throttler_->deRegisterTransfer(); + } + checkpoints_.clear(); + fileCreator_->clearAllocationMap(); + // TODO might consider moving closing the transfer log here + hasNewTransferStarted_.store(false); +} + +WdtTransferRequest Receiver::init() { + const auto &options = WdtOptions::get(); + backlog_ = options.backlog; + bufferSize_ = options.buffer_size; + if (bufferSize_ < 
Protocol::kMaxHeader) { + // round up to even k + bufferSize_ = 2 * 1024 * ((Protocol::kMaxHeader - 1) / (2 * 1024) + 1); + LOG(INFO) << "Specified -buffer_size " << options.buffer_size + << " smaller than " << Protocol::kMaxHeader << " using " + << bufferSize_ << " instead"; + } + auto numThreads = ports_.size(); + fileCreator_.reset( + new FileCreator(destDir_, numThreads, transferLogManager_)); + threadsController_ = new ThreadsController(numThreads); + threadsController_->setNumFunnels(ReceiverThread::NUM_FUNNELS); + threadsController_->setNumBarriers(ReceiverThread::NUM_BARRIERS); + threadsController_->setNumConditions(ReceiverThread::NUM_CONDITIONS); + receiverThreads_ = threadsController_->makeThreads( + this, ports_.size(), ports_); + size_t numSuccessfulInitThreads = 0; + for (auto &receiverThread : receiverThreads_) { + ErrorCode code = receiverThread->init(); + if (code == OK) { + ++numSuccessfulInitThreads; } } - LOG(INFO) << "Registered " << successfulSockets.size() << " sockets"; + LOG(INFO) << "Registered " << numSuccessfulInitThreads + << " successful sockets"; ErrorCode code = OK; - if (threadServerSockets_.size() != successfulSockets.size()) { + if (numSuccessfulInitThreads != ports_.size()) { code = FEWER_PORTS; - if (successfulSockets.size() == 0) { + if (numSuccessfulInitThreads == 0) { code = ERROR; } } - threadServerSockets_ = std::move(successfulSockets); - // TODO: mutate input request instead, post validation WdtTransferRequest transferRequest(getPorts()); transferRequest.protocolVersion = protocolVersion_; transferRequest.transferId = transferId_; @@ -191,7 +160,7 @@ WdtTransferRequest Receiver::init() { code = ERROR; } } - transferRequest.directory = destDir_; + transferRequest.directory = getDir(); transferRequest.errorCode = code; return transferRequest; } @@ -201,6 +170,14 @@ void Receiver::setDir(const std::string &destDir) { transferLogManager_.setRootDir(destDir_); } +TransferLogManager &Receiver::getTransferLogManager() { + return 
transferLogManager_; +} + +std::unique_ptr &Receiver::getFileCreator() { + return fileCreator_; +} + const std::string &Receiver::getDir() { return destDir_; } @@ -221,12 +198,16 @@ Receiver::~Receiver() { vector Receiver::getPorts() const { vector ports; - for (const auto &socket : threadServerSockets_) { - ports.push_back(socket.getPort()); + for (const auto &receiverThread : receiverThreads_) { + ports.push_back(receiverThread->getPort()); } return ports; } +const std::vector &Receiver::getFileChunksInfo() const { + return fileChunksInfo_; +} + int64_t Receiver::getTransferConfig() const { auto &options = WdtOptions::get(); int64_t config = 0; @@ -265,10 +246,9 @@ std::unique_ptr Receiver::finish() { LOG(WARNING) << "The receiver is not joinable. The threads will never" << " finish and this method will never return"; } - for (size_t i = 0; i < receiverThreads_.size(); i++) { - receiverThreads_[i].join(); + for (auto &receiverThread : receiverThreads_) { + receiverThread->finish(); } - // A very important step to mark the transfer finished // No other transferAsync, or runForever can be called on this // instance unless the current transfer has finished @@ -279,50 +259,37 @@ std::unique_ptr Receiver::finish() { progressTrackerThread_.join(); } std::unique_ptr report = getTransferReport(); - + auto &summary = report->getSummary(); bool transferSuccess = (report->getSummary().getCombinedErrorCode() == OK); fixAndCloseTransferLog(transferSuccess); - - if (progressReporter_ && totalSenderBytes_ >= 0) { - report->setTotalFileSize(totalSenderBytes_); + auto totalSenderBytes = summary.getTotalSenderBytes(); + if (progressReporter_ && totalSenderBytes >= 0) { + report->setTotalFileSize(totalSenderBytes); report->setTotalTime(durationSeconds(Clock::now() - startTime_)); progressReporter_->end(report); } if (options.enable_perf_stat_collection) { PerfStatReport globalPerfReport; - for (auto &perfReport : perfReports_) { - globalPerfReport += perfReport; + for (auto 
&receiverThread : receiverThreads_) { + globalPerfReport += receiverThread->getPerfReport(); } LOG(INFO) << globalPerfReport; } LOG(WARNING) << "WDT receiver's transfer has been finished"; LOG(INFO) << *report; - receiverThreads_.clear(); - threadServerSockets_.clear(); - threadStats_.clear(); areThreadsJoined_ = true; return report; } std::unique_ptr Receiver::getTransferReport() { - std::unique_ptr report = - folly::make_unique(threadStats_); - const TransferStats &summary = report->getSummary(); - - if (numBlocksSend_ == -1 || numBlocksSend_ != summary.getNumBlocks()) { - // either none of the threads finished properly or not all of the blocks - // were transferred - report->setErrorCode(ERROR); - } else if (totalSenderBytes_ != -1 && - totalSenderBytes_ != summary.getEffectiveDataBytes()) { - // did not receive all the bytes - LOG(ERROR) << "Number of bytes sent and received do not match " - << totalSenderBytes_ << " " << summary.getEffectiveDataBytes(); - report->setErrorCode(ERROR); - } else { - report->setErrorCode(OK); + TransferStats globalStats; + for (const auto &receiverThread : receiverThreads_) { + globalStats += receiverThread->getTransferStats(); } + globalStats.validate(); + std::unique_ptr report = + folly::make_unique(std::move(globalStats)); return report; } @@ -385,11 +352,10 @@ void Receiver::progressTracker() { std::chrono::time_point lastUpdateTime = Clock::now(); int intervalsSinceLastUpdate = 0; double currentThroughput = 0; - LOG(INFO) << "Progress reporter updating every " << progressReportIntervalMillis << " ms"; auto waitingTime = std::chrono::milliseconds(progressReportIntervalMillis); - int64_t totalSenderBytes; + int64_t totalSenderBytes = -1; while (true) { { std::unique_lock lock(mutex_); @@ -397,14 +363,18 @@ void Receiver::progressTracker() { if (transferFinished_ || getCurAbortCode() != OK) { break; } - if (totalSenderBytes_ == -1) { - continue; - } - totalSenderBytes = totalSenderBytes_; } double totalTime = 
durationSeconds(Clock::now() - startTime_); + TransferStats globalStats; + for (const auto &receiverThread : receiverThreads_) { + globalStats += receiverThread->getTransferStats(); + } + totalSenderBytes = globalStats.getTotalSenderBytes(); + if (totalSenderBytes == -1) { + continue; + } auto transferReport = folly::make_unique( - threadStats_, totalTime, totalSenderBytes); + std::move(globalStats), totalTime, totalSenderBytes); intervalsSinceLastUpdate++; if (intervalsSinceLastUpdate >= throughputUpdateInterval) { auto curTime = Clock::now(); @@ -417,7 +387,6 @@ void Receiver::progressTracker() { intervalsSinceLastUpdate = 0; } transferReport->setCurrentThroughput(currentThroughput); - progressReporter_->progress(transferReport); } } @@ -428,21 +397,11 @@ void Receiver::start() { LOG(WARNING) << "There is an existing transfer in progress on this object"; } areThreadsJoined_ = false; - numActiveThreads_ = threadServerSockets_.size(); LOG(INFO) << "Starting (receiving) server on ports [ " << getPorts() << "] Target dir : " << destDir_; markTransferFinished(false); const auto &options = WdtOptions::get(); - int64_t bufferSize = options.buffer_size; - if (bufferSize < Protocol::kMaxHeader) { - // round up to even k - bufferSize = 2 * 1024 * ((Protocol::kMaxHeader - 1) / (2 * 1024) + 1); - LOG(INFO) << "Specified -buffer_size " << options.buffer_size - << " smaller than " << Protocol::kMaxHeader << " using " - << bufferSize << " instead"; - } - fileCreator_.reset(new FileCreator(destDir_, threadServerSockets_.size(), - transferLogManager_)); + // TODO do the init stuff here if (options.enable_download_resumption) { WDT_CHECK(!options.skip_writes) << "Can not skip transfers with download resumption turned on"; @@ -458,21 +417,27 @@ void Receiver::start() { traverseDestinationDir(fileChunksInfo_); } } - perfReports_.resize(threadServerSockets_.size()); - const int64_t numSockets = threadServerSockets_.size(); - for (int64_t i = 0; i < numSockets; i++) { - 
threadStats_.emplace_back(true); - } if (!throttler_) { configureThrottler(); } else { LOG(INFO) << "Throttler set externally. Throttler : " << *throttler_; } - - for (int64_t i = 0; i < numSockets; i++) { - receiverThreads_.emplace_back(&Receiver::receiveOne, this, i, - std::ref(threadServerSockets_[i]), bufferSize, - std::ref(threadStats_[i])); + while (true) { + for (auto &receiverThread : receiverThreads_) { + receiverThread->startThread(); + } + if (!isJoinable_) { + // If it is long running mode, finish the threads + // processing the current transfer and re spawn them again + // with the same sockets + for (auto &receiverThread : receiverThreads_) { + receiverThread->finish(); + receiverThread->reset(); + } + threadsController_->reset(); + continue; + } + break; } if (isJoinable_) { if (progressReporter_) { @@ -483,780 +448,6 @@ void Receiver::start() { } } -bool Receiver::areAllThreadsFinished(bool checkpointAdded) { - const int64_t numSockets = threadServerSockets_.size(); - bool finished = (failedThreadCount_ + waitingThreadCount_ + - waitingWithErrorThreadCount_) == numSockets; - if (checkpointAdded) { - // The thread has added a global checkpoint. So, - // even if all the threads are waiting, the session does no end. However, - // if all the threads are waiting with an error, then we must end the - // session. because none of the waiting threads can send the global - // checkpoint back to the sender - finished &= (waitingThreadCount_ == 0); - } - return finished; -} - -void Receiver::endCurGlobalSession() { - WDT_CHECK(transferFinishedCount_ + 1 == transferStartedCount_); - LOG(INFO) << "Received done for all threads. 
Transfer session " - << transferStartedCount_ << " finished"; - if (throttler_) { - throttler_->deRegisterTransfer(); - } - transferFinishedCount_++; - waitingThreadCount_ = 0; - waitingWithErrorThreadCount_ = 0; - checkpoints_.clear(); - fileCreator_->clearAllocationMap(); - conditionAllFinished_.notify_all(); -} - -void Receiver::incrFailedThreadCountAndCheckForSessionEnd(ThreadData &data) { - std::unique_lock lock(mutex_); - failedThreadCount_++; - if (areAllThreadsFinished(false) && - transferStartedCount_ > transferFinishedCount_) { - endCurGlobalSession(); - } -} - -bool Receiver::hasNewSessionStarted(ThreadData &data) { - bool started = transferStartedCount_ > data.transferStartedCount_; - if (started) { - WDT_CHECK(transferStartedCount_ == data.transferStartedCount_ + 1); - } - return started; -} - -void Receiver::startNewGlobalSession(ThreadData &data) { - WDT_CHECK(transferStartedCount_ == transferFinishedCount_); - const auto &options = WdtOptions::get(); - auto &socket = data.socket_; - if (throttler_) { - // If throttler is configured/set then register this session - // in the throttler. This is guaranteed to work in either of the - // modes long running or not. 
We will de register from the throttler - // when the current session ends - throttler_->registerTransfer(); - } - transferStartedCount_++; - startTime_ = Clock::now(); - - if (options.enable_download_resumption) { - bool verifySuccessful = - transferLogManager_.verifySenderIp(socket.getPeerIp()); - if (!verifySuccessful) { - fileChunksInfo_.clear(); - } - } - - LOG(INFO) << "New transfer started " << transferStartedCount_; -} - -bool Receiver::hasCurSessionFinished(ThreadData &data) { - return transferFinishedCount_ > data.transferFinishedCount_; -} - -void Receiver::startNewThreadSession(ThreadData &data) { - WDT_CHECK(data.transferStartedCount_ == data.transferFinishedCount_); - data.transferStartedCount_++; -} - -void Receiver::endCurThreadSession(ThreadData &data) { - WDT_CHECK(data.transferStartedCount_ == data.transferFinishedCount_ + 1); - data.transferFinishedCount_++; -} - -/***LISTEN STATE***/ -Receiver::ReceiverState Receiver::listen(ThreadData &data) { - VLOG(1) << data << " entered LISTEN state "; - const auto &options = WdtOptions::get(); - const bool doActualWrites = !options.skip_writes; - auto &socket = data.socket_; - auto &threadStats = data.threadStats_; - - int32_t port = socket.getPort(); - VLOG(1) << "Server Thread for port " << port << " with backlog " - << socket.getBackLog() << " on " << destDir_ - << " writes= " << doActualWrites; - for (int i = 1; i < options.max_retries; ++i) { - ErrorCode code = socket.listen(); - if (code == OK) { - break; - } else if (code == CONN_ERROR) { - threadStats.setErrorCode(code); - incrFailedThreadCountAndCheckForSessionEnd(data); - return FAILED; - } - LOG(INFO) << "Sleeping after failed attempt " << i; - /* sleep override */ - usleep(options.sleep_millis * 1000); - } - // one more/last try (stays true if it worked above) - if (socket.listen() != OK) { - LOG(ERROR) << "Unable to listen/bind despite retries"; - threadStats.setErrorCode(CONN_ERROR); - incrFailedThreadCountAndCheckForSessionEnd(data); - 
return FAILED; - } - return ACCEPT_FIRST_CONNECTION; -} - -/***ACCEPT_FIRST_CONNECTION***/ -Receiver::ReceiverState Receiver::acceptFirstConnection(ThreadData &data) { - VLOG(1) << data << " entered ACCEPT_FIRST_CONNECTION state "; - const auto &options = WdtOptions::get(); - auto &socket = data.socket_; - auto &threadStats = data.threadStats_; - auto &curConnectionVerified = data.curConnectionVerified_; - - data.reset(); - socket.closeCurrentConnection(); - auto timeout = options.accept_timeout_millis; - int acceptAttempts = 0; - while (true) { - { - std::lock_guard lock(mutex_); - if (hasNewSessionStarted(data)) { - startNewThreadSession(data); - return ACCEPT_WITH_TIMEOUT; - } - } - if (isJoinable_ && acceptAttempts == options.max_accept_retries) { - LOG(ERROR) << "unable to accept after " << acceptAttempts << " attempts"; - threadStats.setErrorCode(CONN_ERROR); - incrFailedThreadCountAndCheckForSessionEnd(data); - return FAILED; - } - - if (getCurAbortCode() != OK) { - LOG(ERROR) << "Thread marked to abort while trying to accept first" - << " connection. Num attempts " << acceptAttempts; - // Even though there is a transition FAILED here - // getCurAbortCode() is going to be checked again in the receiveOne. 
- // So this is pretty much irrelevant - return FAILED; - } - - ErrorCode code = - socket.acceptNextConnection(timeout, curConnectionVerified); - if (code == OK) { - break; - } - acceptAttempts++; - } - - std::lock_guard lock(mutex_); - if (!hasNewSessionStarted(data)) { - // this thread has the first connection - startNewGlobalSession(data); - } - startNewThreadSession(data); - return READ_NEXT_CMD; -} - -/***ACCEPT_WITH_TIMEOUT STATE***/ -Receiver::ReceiverState Receiver::acceptWithTimeout(ThreadData &data) { - LOG(INFO) << data << " entered ACCEPT_WITH_TIMEOUT state "; - const auto &options = WdtOptions::get(); - auto &socket = data.socket_; - auto &threadStats = data.threadStats_; - auto &senderReadTimeout = data.senderReadTimeout_; - auto &senderWriteTimeout = data.senderWriteTimeout_; - auto &doneSendFailure = data.doneSendFailure_; - auto &curConnectionVerified = data.curConnectionVerified_; - socket.closeCurrentConnection(); - - auto timeout = options.accept_window_millis; - if (senderReadTimeout > 0) { - // transfer is in progress and we have already got sender settings - timeout = - std::max(senderReadTimeout, senderWriteTimeout) + kTimeoutBufferMillis; - } - - ErrorCode code = socket.acceptNextConnection(timeout, curConnectionVerified); - curConnectionVerified = false; - if (code != OK) { - LOG(ERROR) << "accept() failed with timeout " << timeout; - threadStats.setErrorCode(code); - if (doneSendFailure) { - // if SEND_DONE_CMD state had already been reached, we do not need to - // wait for other threads to end - return END; - } - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; - } - - if (doneSendFailure) { - // no need to reset any session variables in this case - return SEND_LOCAL_CHECKPOINT; - } - - data.numRead_ = data.off_ = 0; - data.pendingCheckpointIndex_ = data.checkpointIndex_; - ReceiverState nextState = READ_NEXT_CMD; - if (threadStats.getErrorCode() != OK) { - nextState = SEND_LOCAL_CHECKPOINT; - } - // reset thread status - 
threadStats.setErrorCode(OK); - return nextState; -} - -/***SEND_LOCAL_CHECKPOINT STATE***/ -Receiver::ReceiverState Receiver::sendLocalCheckpoint(ThreadData &data) { - LOG(INFO) << data << " entered SEND_LOCAL_CHECKPOINT state "; - auto &socket = data.socket_; - auto &threadStats = data.threadStats_; - auto &doneSendFailure = data.doneSendFailure_; - auto &checkpoint = data.checkpoint_; - int32_t protocolVersion = data.threadProtocolVersion_; - char *buf = data.getBuf(); - - std::vector checkpoints; - if (doneSendFailure) { - // in case SEND_DONE failed, a special checkpoint(-1) is sent to signal this - // condition - Checkpoint localCheckpoint(socket.getPort()); - localCheckpoint.numBlocks = -1; - checkpoints.emplace_back(localCheckpoint); - } else { - checkpoints.emplace_back(checkpoint); - } - - int64_t off = 0; - const int checkpointLen = - Protocol::getMaxLocalCheckpointLength(protocolVersion); - Protocol::encodeCheckpoints(protocolVersion, buf, off, checkpointLen, - checkpoints); - int written = socket.write(buf, checkpointLen); - if (written != checkpointLen) { - LOG(ERROR) << "unable to write local checkpoint. 
write mismatch " - << checkpointLen << " " << written; - threadStats.setErrorCode(SOCKET_WRITE_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - threadStats.addHeaderBytes(checkpointLen); - if (doneSendFailure) { - return SEND_DONE_CMD; - } - return READ_NEXT_CMD; -} - -/***READ_NEXT_CMD***/ -Receiver::ReceiverState Receiver::readNextCmd(ThreadData &data) { - VLOG(1) << data << " entered READ_NEXT_CMD state "; - auto &socket = data.socket_; - auto &threadStats = data.threadStats_; - char *buf = data.getBuf(); - auto &numRead = data.numRead_; - auto &off = data.off_; - auto &oldOffset = data.oldOffset_; - auto bufferSize = data.bufferSize_; - - oldOffset = off; - numRead = readAtLeast(socket, buf + off, bufferSize - off, - Protocol::kMinBufLength, numRead); - if (numRead < Protocol::kMinBufLength) { - LOG(ERROR) << "socket read failure " << Protocol::kMinBufLength << " " - << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf[off++]; - if (cmd == Protocol::DONE_CMD) { - return PROCESS_DONE_CMD; - } - if (cmd == Protocol::FILE_CMD) { - return PROCESS_FILE_CMD; - } - if (cmd == Protocol::SETTINGS_CMD) { - return PROCESS_SETTINGS_CMD; - } - if (cmd == Protocol::SIZE_CMD) { - return PROCESS_SIZE_CMD; - } - LOG(ERROR) << "received an unknown cmd"; - threadStats.setErrorCode(PROTOCOL_ERROR); - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; -} - -/***PROCESS_SETTINGS_CMD***/ -Receiver::ReceiverState Receiver::processSettingsCmd(ThreadData &data) { - VLOG(1) << data << " entered PROCESS_SETTINGS_CMD state "; - char *buf = data.getBuf(); - auto &off = data.off_; - auto &oldOffset = data.oldOffset_; - auto &numRead = data.numRead_; - auto &senderReadTimeout = data.senderReadTimeout_; - auto &senderWriteTimeout = data.senderWriteTimeout_; - auto &threadStats = data.threadStats_; - auto &enableChecksum = data.enableChecksum_; - auto &threadProtocolVersion = data.threadProtocolVersion_; - auto 
&curConnectionVerified = data.curConnectionVerified_; - auto &isBlockMode = data.isBlockMode_; - Settings settings; - int senderProtocolVersion; - - bool success = Protocol::decodeVersion( - buf, off, oldOffset + Protocol::kMaxVersion, senderProtocolVersion); - if (!success) { - LOG(ERROR) << "Unable to decode version " << data.threadIndex_; - threadStats.setErrorCode(PROTOCOL_ERROR); - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; - } - if (senderProtocolVersion != threadProtocolVersion) { - LOG(ERROR) << "Receiver and sender protocol version mismatch " - << senderProtocolVersion << " " << threadProtocolVersion; - int negotiatedProtocol = Protocol::negotiateProtocol(senderProtocolVersion, - threadProtocolVersion); - if (negotiatedProtocol == 0) { - LOG(WARNING) << "Can not support sender with version " - << senderProtocolVersion << ", aborting!"; - threadStats.setErrorCode(VERSION_INCOMPATIBLE); - return SEND_ABORT_CMD; - } else { - LOG_IF(INFO, threadProtocolVersion != negotiatedProtocol) - << "Changing receiver protocol version to " << negotiatedProtocol; - threadProtocolVersion = negotiatedProtocol; - if (negotiatedProtocol != senderProtocolVersion) { - threadStats.setErrorCode(VERSION_MISMATCH); - return SEND_ABORT_CMD; - } - } - } - - success = Protocol::decodeSettings( - threadProtocolVersion, buf, off, - oldOffset + Protocol::kMaxVersion + Protocol::kMaxSettings, settings); - if (!success) { - LOG(ERROR) << "Unable to decode settings cmd " << data.threadIndex_; - threadStats.setErrorCode(PROTOCOL_ERROR); - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; - } - auto senderId = settings.transferId; - if (transferId_ != senderId) { - LOG(ERROR) << "Receiver and sender id mismatch " << senderId << " " - << transferId_; - threadStats.setErrorCode(ID_MISMATCH); - return SEND_ABORT_CMD; - } - senderReadTimeout = settings.readTimeoutMillis; - senderWriteTimeout = settings.writeTimeoutMillis; - enableChecksum = settings.enableChecksum; - isBlockMode = 
!settings.blockModeDisabled; - curConnectionVerified = true; - - if (settings.sendFileChunks) { - // We only move to SEND_FILE_CHUNKS state, if download resumption is enabled - // in the sender side - numRead = off = 0; - return SEND_FILE_CHUNKS; - } - auto msgLen = off - oldOffset; - numRead -= msgLen; - return READ_NEXT_CMD; -} - -/***PROCESS_FILE_CMD***/ -Receiver::ReceiverState Receiver::processFileCmd(ThreadData &data) { - VLOG(1) << data << " entered PROCESS_FILE_CMD state "; - const auto &options = WdtOptions::get(); - auto &socket = data.socket_; - auto &threadIndex = data.threadIndex_; - auto &threadStats = data.threadStats_; - char *buf = data.getBuf(); - auto &numRead = data.numRead_; - auto &off = data.off_; - auto &oldOffset = data.oldOffset_; - auto bufferSize = data.bufferSize_; - auto &checkpointIndex = data.checkpointIndex_; - auto &pendingCheckpointIndex = data.pendingCheckpointIndex_; - auto &enableChecksum = data.enableChecksum_; - auto &protocolVersion = data.threadProtocolVersion_; - auto &checkpoint = data.checkpoint_; - auto &isBlockMode = data.isBlockMode_; - - // following block needs to be executed for the first file cmd. There is no - // harm in executing it more than once. number of blocks equal to 0 is a good - // approximation for first file cmd. 
Did not want to introduce another boolean - if (options.enable_download_resumption && threadStats.getNumBlocks() == 0) { - std::lock_guard lock(mutex_); - if (sendChunksStatus_ != SENT) { - // sender is not in resumption mode - addTransferLogHeader(isBlockMode, /* sender not resuming */ false); - sendChunksStatus_ = SENT; - } - } - - checkpoint.resetLastBlockDetails(); - BlockDetails blockDetails; - - auto guard = folly::makeGuard([&socket, &threadStats] { - if (threadStats.getErrorCode() != OK) { - threadStats.incrFailedAttempts(); - } - }); - - ErrorCode transferStatus = (ErrorCode)buf[off++]; - if (transferStatus != OK) { - // TODO: use this status information to implement fail fast mode - VLOG(1) << "sender entered into error state " - << errorCodeToStr(transferStatus); - } - int16_t headerLen = folly::loadUnaligned(buf + off); - headerLen = folly::Endian::little(headerLen); - VLOG(2) << "Processing FILE_CMD, header len " << headerLen; - - if (headerLen > numRead) { - int64_t end = oldOffset + numRead; - numRead = - readAtLeast(socket, buf + end, bufferSize - end, headerLen, numRead); - } - if (numRead < headerLen) { - LOG(ERROR) << "Unable to read full header " << headerLen << " " << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - off += sizeof(int16_t); - bool success = Protocol::decodeHeader(protocolVersion, buf, off, - numRead + oldOffset, blockDetails); - int64_t headerBytes = off - oldOffset; - // transferred header length must match decoded header length - WDT_CHECK_EQ(headerLen, headerBytes) << " " << blockDetails.fileName << " " - << blockDetails.seqId << " " - << protocolVersion; - threadStats.addHeaderBytes(headerBytes); - if (!success) { - LOG(ERROR) << "Error decoding at" - << " ooff:" << oldOffset << " off: " << off - << " numRead: " << numRead; - threadStats.setErrorCode(PROTOCOL_ERROR); - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; - } - - // received a well formed file cmd, apply the pending checkpoint 
update - checkpointIndex = pendingCheckpointIndex; - VLOG(1) << "Read id:" << blockDetails.fileName - << " size:" << blockDetails.dataSize << " ooff:" << oldOffset - << " off: " << off << " numRead: " << numRead; - - FileWriter writer(threadIndex, &blockDetails, fileCreator_.get()); - auto writtenGuard = folly::makeGuard([&] { - if (protocolVersion >= Protocol::CHECKPOINT_OFFSET_VERSION) { - // considering partially written block contents as valid, this bypasses - // checksum verification - // TODO: Make sure checksum verification work with checkpoint offsets - checkpoint.setLastBlockDetails(blockDetails.seqId, blockDetails.offset, - writer.getTotalWritten()); - threadStats.addEffectiveBytes(headerBytes, writer.getTotalWritten()); - } - }); - - if (writer.open() != OK) { - threadStats.setErrorCode(FILE_WRITE_ERROR); - return SEND_ABORT_CMD; - } - int32_t checksum = 0; - int64_t remainingData = numRead + oldOffset - off; - int64_t toWrite = remainingData; - WDT_CHECK(remainingData >= 0); - if (remainingData >= blockDetails.dataSize) { - toWrite = blockDetails.dataSize; - } - threadStats.addDataBytes(toWrite); - if (enableChecksum) { - checksum = folly::crc32c((const uint8_t *)(buf + off), toWrite, checksum); - } - if (throttler_) { - // We might be reading more than we require for this file but - // throttling should make sense for any additional bytes received - // on the network - throttler_->limit(toWrite + headerBytes); - } - ErrorCode code = writer.write(buf + off, toWrite); - if (code != OK) { - threadStats.setErrorCode(code); - return SEND_ABORT_CMD; - } - off += toWrite; - remainingData -= toWrite; - // also means no leftOver so it's ok we use buf from start - while (writer.getTotalWritten() < blockDetails.dataSize) { - if (getCurAbortCode() != OK) { - LOG(ERROR) << "Thread marked for abort while processing a file." 
- << " port : " << socket.getPort(); - return FAILED; - } - int64_t nres = readAtMost(socket, buf, bufferSize, - blockDetails.dataSize - writer.getTotalWritten()); - if (nres <= 0) { - break; - } - if (throttler_) { - // We only know how much we have read after we are done calling - // readAtMost. Call throttler with the bytes read off the wire. - throttler_->limit(nres); - } - threadStats.addDataBytes(nres); - if (enableChecksum) { - checksum = folly::crc32c((const uint8_t *)buf, nres, checksum); - } - code = writer.write(buf, nres); - if (code != OK) { - threadStats.setErrorCode(code); - return SEND_ABORT_CMD; - } - } - if (writer.getTotalWritten() != blockDetails.dataSize) { - // This can only happen if there are transmission errors - // Write errors to disk are already taken care of above - LOG(ERROR) << "could not read entire content for " << blockDetails.fileName - << " port " << socket.getPort(); - threadStats.setErrorCode(SOCKET_READ_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - writtenGuard.dismiss(); - VLOG(2) << "completed " << blockDetails.fileName << " off: " << off - << " numRead: " << numRead; - // Transfer of the file is complete here, mark the bytes effective - WDT_CHECK(remainingData >= 0) << "Negative remainingData " << remainingData; - if (remainingData > 0) { - // if we need to read more anyway, let's move the data - numRead = remainingData; - if ((remainingData < Protocol::kMaxHeader) && (off > (bufferSize / 2))) { - // rare so inefficient is ok - VLOG(3) << "copying extra " << remainingData << " leftover bytes @ " - << off; - memmove(/* dst */ buf, - /* from */ buf + off, - /* how much */ remainingData); - off = 0; - } else { - // otherwise just continue from the offset - VLOG(3) << "Using remaining extra " << remainingData - << " leftover bytes starting @ " << off; - } - } else { - numRead = off = 0; - } - if (enableChecksum) { - // have to read footer cmd - oldOffset = off; - numRead = readAtLeast(socket, buf + off, bufferSize - off, - 
Protocol::kMinBufLength, numRead); - if (numRead < Protocol::kMinBufLength) { - LOG(ERROR) << "socket read failure " << Protocol::kMinBufLength << " " - << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf[off++]; - if (cmd != Protocol::FOOTER_CMD) { - LOG(ERROR) << "Expecting footer cmd, but received " << cmd; - threadStats.setErrorCode(PROTOCOL_ERROR); - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; - } - int32_t receivedChecksum; - bool success = Protocol::decodeFooter( - buf, off, oldOffset + Protocol::kMaxFooter, receivedChecksum); - if (!success) { - LOG(ERROR) << "Unable to decode footer cmd"; - threadStats.setErrorCode(PROTOCOL_ERROR); - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; - } - if (checksum != receivedChecksum) { - LOG(ERROR) << "Checksum mismatch " << checksum << " " << receivedChecksum - << " port " << socket.getPort() << " file " - << blockDetails.fileName; - threadStats.setErrorCode(CHECKSUM_MISMATCH); - return ACCEPT_WITH_TIMEOUT; - } - int64_t msgLen = off - oldOffset; - numRead -= msgLen; - } - if (options.isLogBasedResumption()) { - transferLogManager_.addBlockWriteEntry( - blockDetails.seqId, blockDetails.offset, blockDetails.dataSize); - } - threadStats.addEffectiveBytes(headerBytes, blockDetails.dataSize); - threadStats.incrNumBlocks(); - checkpoint.incrNumBlocks(); - return READ_NEXT_CMD; -} - -Receiver::ReceiverState Receiver::processDoneCmd(ThreadData &data) { - VLOG(1) << data << " entered PROCESS_DONE_CMD state "; - auto &numRead = data.numRead_; - auto &threadStats = data.threadStats_; - auto &checkpointIndex = data.checkpointIndex_; - auto &pendingCheckpointIndex = data.pendingCheckpointIndex_; - auto &off = data.off_; - auto &oldOffset = data.oldOffset_; - int protocolVersion = data.threadProtocolVersion_; - char *buf = data.getBuf(); - - if (numRead != Protocol::kMinBufLength) { - LOG(ERROR) << "Unexpected state for done command" - << " off: " 
<< off << " numRead: " << numRead; - threadStats.setErrorCode(PROTOCOL_ERROR); - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; - } - - ErrorCode senderStatus = (ErrorCode)buf[off++]; - bool success; - { - std::lock_guard lock(mutex_); - success = Protocol::decodeDone(protocolVersion, buf, off, - oldOffset + Protocol::kMaxDone, - numBlocksSend_, totalSenderBytes_); - } - if (!success) { - LOG(ERROR) << "Unable to decode done cmd"; - threadStats.setErrorCode(PROTOCOL_ERROR); - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; - } - threadStats.setRemoteErrorCode(senderStatus); - - // received a valid command, applying pending checkpoint write update - checkpointIndex = pendingCheckpointIndex; - return WAIT_FOR_FINISH_OR_NEW_CHECKPOINT; -} - -Receiver::ReceiverState Receiver::processSizeCmd(ThreadData &data) { - VLOG(1) << data << " entered PROCESS_SIZE_CMD state "; - auto &threadStats = data.threadStats_; - auto &numRead = data.numRead_; - auto &off = data.off_; - auto &oldOffset = data.oldOffset_; - char *buf = data.getBuf(); - std::lock_guard lock(mutex_); - bool success = Protocol::decodeSize(buf, off, oldOffset + Protocol::kMaxSize, - totalSenderBytes_); - if (!success) { - LOG(ERROR) << "Unable to decode size cmd"; - threadStats.setErrorCode(PROTOCOL_ERROR); - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; - } - VLOG(1) << "Number of bytes to receive " << totalSenderBytes_; - auto msgLen = off - oldOffset; - numRead -= msgLen; - return READ_NEXT_CMD; -} - -Receiver::ReceiverState Receiver::sendFileChunks(ThreadData &data) { - LOG(INFO) << data << " entered SEND_FILE_CHUNKS state "; - char *buf = data.getBuf(); - auto bufferSize = data.bufferSize_; - auto &socket = data.socket_; - auto &threadStats = data.threadStats_; - auto &senderReadTimeout = data.senderReadTimeout_; - auto &isBlockMode = data.isBlockMode_; - int64_t toWrite; - int64_t written; - std::unique_lock lock(mutex_); - while (true) { - switch (sendChunksStatus_) { - case SENT: { - lock.unlock(); - buf[0] = 
Protocol::ACK_CMD; - toWrite = 1; - written = socket.write(buf, toWrite); - if (written != toWrite) { - LOG(ERROR) << "Socket write error " << toWrite << " " << written; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - threadStats.addHeaderBytes(toWrite); - return READ_NEXT_CMD; - } - case IN_PROGRESS: { - lock.unlock(); - buf[0] = Protocol::WAIT_CMD; - toWrite = 1; - written = socket.write(buf, toWrite); - if (written != toWrite) { - LOG(ERROR) << "Socket write error " << toWrite << " " << written; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - threadStats.addHeaderBytes(toWrite); - WDT_CHECK(senderReadTimeout > 0); // must have received settings - int timeoutMillis = senderReadTimeout / kWaitTimeoutFactor; - auto waitingTime = std::chrono::milliseconds(timeoutMillis); - lock.lock(); - conditionFileChunksSent_.wait_for(lock, waitingTime); - continue; - } - case NOT_STARTED: { - // This thread has to send file chunks - sendChunksStatus_ = IN_PROGRESS; - lock.unlock(); - auto guard = folly::makeGuard([&] { - lock.lock(); - sendChunksStatus_ = NOT_STARTED; - conditionFileChunksSent_.notify_one(); - }); - int64_t off = 0; - buf[off++] = Protocol::CHUNKS_CMD; - const int64_t numParsedChunksInfo = fileChunksInfo_.size(); - Protocol::encodeChunksCmd(buf, off, bufferSize, numParsedChunksInfo); - written = socket.write(buf, off); - if (written > 0) { - threadStats.addHeaderBytes(written); - } - if (written != off) { - LOG(ERROR) << "Socket write error " << off << " " << written; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - int64_t numEntriesWritten = 0; - // we try to encode as many chunks as possible in the buffer. If a - // single - // chunk can not fit in the buffer, it is ignored. Format of encoding : - // ... 
- while (numEntriesWritten < numParsedChunksInfo) { - off = sizeof(int32_t); - int64_t numEntriesEncoded = Protocol::encodeFileChunksInfoList( - buf, off, bufferSize, numEntriesWritten, fileChunksInfo_); - int32_t dataSize = folly::Endian::little(off - sizeof(int32_t)); - folly::storeUnaligned(buf, dataSize); - written = socket.write(buf, off); - if (written > 0) { - threadStats.addHeaderBytes(written); - } - if (written != off) { - break; - } - numEntriesWritten += numEntriesEncoded; - } - if (numEntriesWritten != numParsedChunksInfo) { - LOG(ERROR) << "Could not write all the file chunks " - << numParsedChunksInfo << " " << numEntriesWritten; - threadStats.setErrorCode(SOCKET_WRITE_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - // try to read ack - int64_t toRead = 1; - int64_t numRead = socket.read(buf, toRead); - if (numRead != toRead) { - LOG(ERROR) << "Socket read error " << toRead << " " << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return ACCEPT_WITH_TIMEOUT; - } - guard.dismiss(); - lock.lock(); - sendChunksStatus_ = SENT; - addTransferLogHeader(isBlockMode, /* sender resuming */ true); - conditionFileChunksSent_.notify_all(); - return READ_NEXT_CMD; - } - } - } -} - void Receiver::addTransferLogHeader(bool isBlockMode, bool isSenderResuming) { const auto &options = WdtOptions::get(); if (!options.enable_download_resumption) { @@ -1309,244 +500,5 @@ void Receiver::fixAndCloseTransferLog(bool transferSuccess) { transferLogManager_.unlink(); } } - -Receiver::ReceiverState Receiver::sendGlobalCheckpoint(ThreadData &data) { - LOG(INFO) << data << " entered SEND_GLOBAL_CHECKPOINTS state "; - char *buf = data.getBuf(); - auto &off = data.off_; - auto &newCheckpoints = data.newCheckpoints_; - auto &socket = data.socket_; - auto &checkpointIndex = data.checkpointIndex_; - auto &pendingCheckpointIndex = data.pendingCheckpointIndex_; - auto &threadStats = data.threadStats_; - auto &numRead = data.numRead_; - auto bufferSize = data.bufferSize_; - int32_t 
protocolVersion = data.threadProtocolVersion_; - - buf[0] = Protocol::ERR_CMD; - off = 1; - // leave space for length - off += sizeof(int16_t); - auto oldOffset = off; - Protocol::encodeCheckpoints(protocolVersion, buf, off, bufferSize, - newCheckpoints); - int16_t length = off - oldOffset; - folly::storeUnaligned(buf + 1, folly::Endian::little(length)); - - auto written = socket.write(buf, off); - if (written != off) { - LOG(ERROR) << "unable to write error checkpoints"; - threadStats.setErrorCode(SOCKET_WRITE_ERROR); - return ACCEPT_WITH_TIMEOUT; - } else { - threadStats.addHeaderBytes(off); - pendingCheckpointIndex = checkpointIndex + newCheckpoints.size(); - numRead = off = 0; - return READ_NEXT_CMD; - } -} - -Receiver::ReceiverState Receiver::sendAbortCmd(ThreadData &data) { - LOG(INFO) << data << " entered SEND_ABORT_CMD state "; - auto &threadStats = data.threadStats_; - char *buf = data.getBuf(); - auto &socket = data.socket_; - int32_t protocolVersion = data.threadProtocolVersion_; - int64_t offset = 0; - buf[offset++] = Protocol::ABORT_CMD; - Protocol::encodeAbort(buf, offset, protocolVersion, - threadStats.getErrorCode(), threadStats.getNumFiles()); - socket.write(buf, offset); - // No need to check if we were successful in sending ABORT - // This thread will simply disconnect and sender thread on the - // other side will timeout - socket.closeCurrentConnection(); - threadStats.addHeaderBytes(offset); - if (threadStats.getErrorCode() == VERSION_MISMATCH) { - // Receiver should try again expecting sender to have changed its version - return ACCEPT_WITH_TIMEOUT; - } - return WAIT_FOR_FINISH_WITH_THREAD_ERROR; -} - -Receiver::ReceiverState Receiver::sendDoneCmd(ThreadData &data) { - VLOG(1) << data << " entered SEND_DONE_CMD state "; - char *buf = data.getBuf(); - auto &socket = data.socket_; - auto &threadStats = data.threadStats_; - auto &doneSendFailure = data.doneSendFailure_; - - buf[0] = Protocol::DONE_CMD; - if (socket.write(buf, 1) != 1) { - 
PLOG(ERROR) << "unable to send DONE " << data.threadIndex_; - doneSendFailure = true; - return ACCEPT_WITH_TIMEOUT; - } - - threadStats.addHeaderBytes(1); - - auto read = socket.read(buf, 1); - if (read != 1 || buf[0] != Protocol::DONE_CMD) { - LOG(ERROR) << data << " did not receive ack for DONE"; - doneSendFailure = true; - return ACCEPT_WITH_TIMEOUT; - } - - read = socket.read(buf, Protocol::kMinBufLength); - if (read != 0) { - LOG(ERROR) << data << " EOF not found where expected"; - doneSendFailure = true; - return ACCEPT_WITH_TIMEOUT; - } - socket.closeCurrentConnection(); - LOG(INFO) << data << " got ack for DONE. Transfer finished"; - return END; -} - -Receiver::ReceiverState Receiver::waitForFinishWithThreadError( - ThreadData &data) { - LOG(INFO) << data << " entered WAIT_FOR_FINISH_WITH_THREAD_ERROR state "; - auto &threadStats = data.threadStats_; - auto &socket = data.socket_; - auto &checkpoint = data.checkpoint_; - // should only be in this state if there is some error - WDT_CHECK(threadStats.getErrorCode() != OK); - - // close the socket, so that sender receives an error during connect - socket.closeAll(); - - std::unique_lock lock(mutex_); - addCheckpoint(checkpoint); - waitingWithErrorThreadCount_++; - - if (areAllThreadsFinished(true)) { - endCurGlobalSession(); - } else { - // wait for session end - while (!hasCurSessionFinished(data)) { - conditionAllFinished_.wait(lock); - } - } - endCurThreadSession(data); - return END; -} - -Receiver::ReceiverState Receiver::waitForFinishOrNewCheckpoint( - ThreadData &data) { - VLOG(1) << data << " entered WAIT_FOR_FINISH_OR_NEW_CHECKPOINT state "; - auto &threadStats = data.threadStats_; - auto &senderReadTimeout = data.senderReadTimeout_; - auto &checkpointIndex = data.checkpointIndex_; - auto &newCheckpoints = data.newCheckpoints_; - char *buf = data.getBuf(); - auto &socket = data.socket_; - // should only be called if there are no errors - WDT_CHECK(threadStats.getErrorCode() == OK); - - std::unique_lock 
lock(mutex_); - // we have to check for checkpoints before checking to see if session ended or - // not. because if some checkpoints have not been sent back to the sender, - // session should not end - newCheckpoints = getNewCheckpoints(checkpointIndex); - if (!newCheckpoints.empty()) { - return SEND_GLOBAL_CHECKPOINTS; - } - - waitingThreadCount_++; - if (areAllThreadsFinished(false)) { - endCurGlobalSession(); - endCurThreadSession(data); - return SEND_DONE_CMD; - } - - // we must send periodic wait cmd to keep the sender thread alive - while (true) { - WDT_CHECK(senderReadTimeout > 0); // must have received settings - int timeoutMillis = senderReadTimeout / kWaitTimeoutFactor; - auto waitingTime = std::chrono::milliseconds(timeoutMillis); - START_PERF_TIMER - conditionAllFinished_.wait_for(lock, waitingTime); - RECORD_PERF_RESULT(PerfStatReport::RECEIVER_WAIT_SLEEP) - - // check if transfer finished or not - if (hasCurSessionFinished(data)) { - endCurThreadSession(data); - return SEND_DONE_CMD; - } - - // check to see if any new checkpoints were added - newCheckpoints = getNewCheckpoints(checkpointIndex); - if (!newCheckpoints.empty()) { - waitingThreadCount_--; - return SEND_GLOBAL_CHECKPOINTS; - } - - // must unlock because socket write could block for long time, as long as - // the write timeout, which is 5sec by default - lock.unlock(); - - // send WAIT cmd to keep sender thread alive - buf[0] = Protocol::WAIT_CMD; - if (socket.write(buf, 1) != 1) { - PLOG(ERROR) << data << " unable to write WAIT "; - threadStats.setErrorCode(SOCKET_WRITE_ERROR); - lock.lock(); - // we again have to check if the session has finished or not. 
while - // writing WAIT cmd, some other thread could have ended the session, so - // going back to ACCEPT_WITH_TIMEOUT state would be wrong - if (!hasCurSessionFinished(data)) { - waitingThreadCount_--; - return ACCEPT_WITH_TIMEOUT; - } - endCurThreadSession(data); - return END; - } - threadStats.addHeaderBytes(1); - lock.lock(); - } -} - -void Receiver::receiveOne(int threadIndex, ServerSocket &socket, - int64_t bufferSize, TransferStats &threadStats) { - INIT_PERF_STAT_REPORT - auto guard = folly::makeGuard([&] { - perfReports_[threadIndex] = *perfStatReport; // copy when done - std::unique_lock lock(mutex_); - numActiveThreads_--; - if (numActiveThreads_ == 0) { - LOG(WARNING) << "Last thread finished. Duration of the transfer " - << durationSeconds(Clock::now() - startTime_); - transferFinished_ = true; - } - }); - ThreadData data(threadIndex, socket, threadStats, protocolVersion_, - bufferSize); - if (!data.getBuf()) { - LOG(ERROR) << "error allocating " << bufferSize; - threadStats.setErrorCode(MEMORY_ALLOCATION_ERROR); - return; - } - ReceiverState state = LISTEN; - while (true) { - ErrorCode abortCode = getCurAbortCode(); - if (abortCode != OK) { - LOG(ERROR) << "Transfer aborted " << socket.getPort() << " " - << errorCodeToStr(abortCode); - threadStats.setErrorCode(ABORT); - incrFailedThreadCountAndCheckForSessionEnd(data); - break; - } - if (state == FAILED) { - return; - } - if (state == END) { - if (isJoinable_) { - return; - } - state = ACCEPT_FIRST_CONNECTION; - } - state = (this->*stateMap_[state])(data); - } -} } } // namespace facebook::wdt diff --git a/Receiver.h b/Receiver.h index a24cc779..a9aa6f7b 100644 --- a/Receiver.h +++ b/Receiver.h @@ -17,7 +17,9 @@ #include "Protocol.h" #include "Writer.h" #include "Throttler.h" +#include "ReceiverThread.h" #include "TransferLogManager.h" +#include "ThreadsController.h" #include #include #include @@ -98,7 +100,9 @@ class Receiver : public WdtBase { */ std::vector getPorts() const; - private: + protected: 
+ friend class ReceiverThread; + /** * @param isFinished Mark transfer active/inactive */ @@ -111,321 +115,18 @@ class Receiver : public WdtBase { */ void traverseDestinationDir(std::vector &fileChunksInfo); - /** - * Wdt receiver has logic to maintain the consistency of the - * transfers through connection errors. All threads are run by the logic - * defined as a state machine. These are the all the states in that - * state machine - */ - enum ReceiverState { - LISTEN, - ACCEPT_FIRST_CONNECTION, - ACCEPT_WITH_TIMEOUT, - SEND_LOCAL_CHECKPOINT, - READ_NEXT_CMD, - PROCESS_FILE_CMD, - PROCESS_SETTINGS_CMD, - PROCESS_DONE_CMD, - PROCESS_SIZE_CMD, - SEND_FILE_CHUNKS, - SEND_GLOBAL_CHECKPOINTS, - SEND_DONE_CMD, - SEND_ABORT_CMD, - WAIT_FOR_FINISH_OR_NEW_CHECKPOINT, - WAIT_FOR_FINISH_WITH_THREAD_ERROR, - FAILED, - END - }; - - /** - * Structure to pass data to the state machine and also to share data between - * different state - */ - struct ThreadData { - /// Index of the thread that this data belongs to - const int threadIndex_; - - /** - * Server socket object that provides functionality such as listen() - * accept, read, write on the socket - */ - ServerSocket &socket_; - - /// Statistics of the transfer for this thread - TransferStats &threadStats_; - - /// protocol version for this thread. This per thread protocol version is - /// kept separately from the global one to avoid locking - int threadProtocolVersion_; - - /// Buffer that receivers reads data into from the network - std::unique_ptr buf_; - /// Maximum size of the buffer - const int64_t bufferSize_; - - /// Marks the number of bytes already read in the buffer - int64_t numRead_{0}; - - /// Following two are markers to mark how much data has been read/parsed - int64_t off_{0}; - int64_t oldOffset_{0}; - - /// number of checkpoints already transferred - int checkpointIndex_{0}; - - /** - * Pending value of checkpoint count. 
Since write call success does not - * guarantee actual transfer, we do not apply checkpoint count update after - * the write. Only after receiving next cmd from sender, we apply the - * update - */ - int pendingCheckpointIndex_{0}; - - /// a counter incremented each time a new session starts for this thread - int64_t transferStartedCount_{0}; - - /// a counter incremented each time a new session ends for this thread - int64_t transferFinishedCount_{0}; - - /// read timeout for sender - int64_t senderReadTimeout_{-1}; - - /// write timeout for sender - int64_t senderWriteTimeout_{-1}; - - /// whether checksum verification is enabled or not - bool enableChecksum_{false}; - - /** - * Whether SEND_DONE_CMD state has already failed for this session or not. - * This has to be separately handled, because session barrier is - * implemented before sending done cmd - */ - bool doneSendFailure_{false}; - - /// Checkpoints that have not been sent back to the sender - std::vector newCheckpoints_; - - /// whether settings has been received and verified for the current - /// connection. This is used to determine round robin order for polling in - /// the server socket - bool curConnectionVerified_{false}; - - /// whether the transfer is in block mode or not - bool isBlockMode_{true}; - - Checkpoint checkpoint_; - - /// Constructor for thread data - ThreadData(int threadIndex, ServerSocket &socket, - TransferStats &threadStats, int protocolVersion, - int64_t bufferSize) - : threadIndex_(threadIndex), - socket_(socket), - threadStats_(threadStats), - threadProtocolVersion_(protocolVersion), - bufferSize_(bufferSize), - checkpoint_(socket.getPort()) { - buf_.reset(new char[bufferSize_]); - } - - /** - * In long running mode, we need to reset thread variables after each - * session. Before starting each session, reset() has to called to do that. 
- */ - void reset() { - numRead_ = off_ = 0; - checkpointIndex_ = pendingCheckpointIndex_ = 0; - doneSendFailure_ = false; - senderReadTimeout_ = senderWriteTimeout_ = -1; - curConnectionVerified_ = false; - threadStats_.reset(); - } - - /// Get the raw pointer to the buffer - char *getBuf() { - return buf_.get(); - } - }; - - /// Overloaded operator for printing thread info - friend std::ostream &operator<<(std::ostream &os, const ThreadData &data); - - typedef ReceiverState (Receiver::*StateFunction)(ThreadData &data); + /// Get the transferred file chunks info + const std::vector &getFileChunksInfo() const; - /** - * Tries to listen/bind to port. If this fails, thread is considered failed. - * Previous states : n/a (start state) - * Next states : ACCEPT_FIRST_CONNECTION(success), - * FAILED(failure) - */ - ReceiverState listen(ThreadData &data); - /** - * Tries to accept first connection of a new session. Periodically checks - * whether a new session has started or not. If a new session has started then - * goes to ACCEPT_WITH_TIMEOUT state. Also does session initialization. In - * joinable mode, tries to accept for a limited number of user specified - * retries. - * Previous states : LISTEN, - * END(if in long running mode) - * Next states : ACCEPT_WITH_TIMEOUT(if a new transfer has started and this - * thread has not received a connection), - * WAIT_FOR_FINISH_WITH_THREAD_ERROR(if did not receive a - * connection in specified number of retries), - * READ_NEXT_CMD(if a connection was received) - */ - ReceiverState acceptFirstConnection(ThreadData &data); - /** - * Tries to accept a connection with timeout. There are 2 kinds of timeout. At - * the beginning of the session, it uses accept window as the timeout. Later - * when sender settings are known it uses max(readTimeOut, writeTimeout)) + - * buffer(500) as the timeout. 
- * Previous states : Almost all states(for any network errors during transfer, - * we transition to this state), - * Next states : READ_NEXT_CMD(if there are no previous errors and accept - * was successful), - * SEND_LOCAL_CHECKPOINT(if there were previous errors and - * accept was successful), - * WAIT_FOR_FINISH_WITH_THREAD_ERROR(if accept failed and - * transfer previously failed during SEND_DONE_CMD state. Thus - * case needs special handling to ensure that we do not mess up - * session variables), - * END(if accept fails otherwise) - */ - ReceiverState acceptWithTimeout(ThreadData &data); - /** - * Sends local checkpoint to the sender. In case of previous error during - * SEND_LOCAL_CHECKPOINT state, we send -1 as the checkpoint. - * Previous states : ACCEPT_WITH_TIMEOUT - * Next states : ACCEPT_WITH_TIMEOUT(if sending fails), - * SEND_DONE_CMD(if send is successful and we have previous - * SEND_DONE_CMD error), - * READ_NEXT_CMD(if send is successful otherwise) - */ - ReceiverState sendLocalCheckpoint(ThreadData &data); - /** - * Reads next cmd and transitions to the state accordingly. - * Previous states : SEND_LOCAL_CHECKPOINT, - * ACCEPT_FIRST_CONNECTION, - * ACCEPT_WITH_TIMEOUT, - * PROCESS_SETTINGS_CMD, - * PROCESS_FILE_CMD, - * SEND_GLOBAL_CHECKPOINTS, - * Next states : PROCESS_FILE_CMD, - * PROCESS_DONE_CMD, - * PROCESS_SETTINGS_CMD, - * PROCESS_SIZE_CMD, - * ACCEPT_WITH_TIMEOUT(in case of read failure), - * WAIT_FOR_FINISH_WITH_THREAD_ERROR(in case of protocol errors) - */ - ReceiverState readNextCmd(ThreadData &data); - /** - * Processes file cmd. Logic of how we write the file to the destination - * directory is defined here. - * Previous states : READ_NEXT_CMD - * Next states : READ_NEXT_CMD(success), - * WAIT_FOR_FINISH_WITH_THREAD_ERROR(protocol error), - * ACCEPT_WITH_TIMEOUT(socket read failure) - */ - ReceiverState processFileCmd(ThreadData &data); - /** - * Processes settings cmd. 
Settings has a connection settings, - * protocol version, transfer id, etc. For more info check Protocol.h - * Previous states : READ_NEXT_CMD, - * Next states : READ_NEXT_CMD(success), - * WAIT_FOR_FINISH_WITH_THREAD_ERROR(protocol error), - * ACCEPT_WITH_TIMEOUT(socket read failure), - * SEND_FILE_CHUNKS(If the sender wants to resume transfer) - */ - ReceiverState processSettingsCmd(ThreadData &data); - /** - * Processes done cmd. Also checks to see if there are any new global - * checkpoints or not - * Previous states : READ_NEXT_CMD, - * Next states : WAIT_FOR_FINISH_OR_ERROR(protocol error), - * WAIT_FOR_FINISH_OR_NEW_CHECKPOINT(success), - * SEND_GLOBAL_CHECKPOINTS(if there are global errors) - */ - ReceiverState processDoneCmd(ThreadData &data); - /** - * Processes size cmd. Sets the value of totalSenderBytes_ - * Previous states : READ_NEXT_CMD, - * Next states : READ_NEXT_CMD(success), - * WAIT_FOR_FINISH_WITH_THREAD_ERROR(protocol error) - */ - ReceiverState processSizeCmd(ThreadData &data); - /** - * Sends file chunks that were received successfully in any previous transfer, - * this is the first step in download resumption. - * Checks to see if they have already been transferred or not. - * If yes, send ACK. If some other thread is sending it, sends wait cmd - * and checks again later. Otherwise, breaks the entire data into bufferSIze_ - * chunks and sends it. - * Previous states: PROCESS_SETTINGS_CMD, - * Next states : ACCEPT_WITH_TIMEOUT(network error), - * READ_NEXT_CMD(success) - */ - ReceiverState sendFileChunks(ThreadData &data); - /** - * Sends global checkpoints to sender - * Previous states : PROCESS_DONE_CMD, - * WAIT_FOR_FINISH_OR_ERROR - * Next states : READ_NEXT_CMD(success), - * ACCEPT_WITH_TIMEOUT(socket write failure) - */ - ReceiverState sendGlobalCheckpoint(ThreadData &data); - /** - * Sends DONE to sender, also tries to read back ack. If anything fails during - * this state, doneSendFailure_ thread variable is set. 
This flag makes the - * state machine behave differently, effectively bypassing all session related - * things. - * Previous states : SEND_LOCAL_CHECKPOINT, - * WAIT_FOR_FINISH_OR_ERROR - * Next states : END(success), - * ACCEPT_WITH_TIMEOUT(failure) - */ - ReceiverState sendDoneCmd(ThreadData &data); - - /** - * Sends ABORT cmd back to the sender - * Previous states : PROCESS_FILE_CMD - * Next states : WAIT_FOR_FINISH_WITH_THREAD_ERROR - */ - ReceiverState sendAbortCmd(ThreadData &data); - - /** - * Waits for transfer to finish or new checkpoints. This state first - * increments waitingThreadCount_. Then, it - * waits till all the threads have finished. It sends periodic WAIT signal to - * prevent sender from timing out. If a new checkpoint is found, we move to - * SEND_GLOBAL_CHECKPOINTS state. - * Previous states : PROCESS_DONE_CMD - * Next states : SEND_DONE_CMD(all threads finished), - * SEND_GLOBAL_CHECKPOINTS(if new checkpoints are found), - * ACCEPT_WITH_TIMEOUT(if socket write fails) - */ - ReceiverState waitForFinishOrNewCheckpoint(ThreadData &data); - - /** - * Waits for transfer to finish. Only called when there is an error for the - * thread. It adds a checkpoint to the global list of checkpoints if a - * connection was received. It increments waitingWithErrorThreadCount_ and - * waits till the session ends. - * Previous states : Almost all states - * Next states : END - */ - ReceiverState waitForFinishWithThreadError(ThreadData &data); + /// Get file creator, used by receiver threads + std::unique_ptr &getFileCreator(); - /// Mapping from receiver states to state functions - static const StateFunction stateMap_[]; + /// Get the ref to transfer log manager + TransferLogManager &getTransferLogManager(); /// Responsible for basic setup and starting threads void start(); - /// This method is the entry point for each thread. 
- void receiveOne(int threadIndex, ServerSocket &s, int64_t bufferSize, - TransferStats &threadStats); - /** * Periodically calculates current transfer report and send it to progress * reporter. This only works in the single transfer mode. @@ -445,45 +146,20 @@ class Receiver : public WdtBase { */ std::vector getNewCheckpoints(int startIndex); - /// Returns true if all threads finished for this session - bool areAllThreadsFinished(bool checkpointAdded); - - /// Ends current global session - void endCurGlobalSession(); - - /** - * Returns if a new session has started and the thread is not aware of it - * A thread must hold lock on mutex_ before calling this - */ - bool hasNewSessionStarted(ThreadData &data); + /// Does the steps needed before a new transfer is started + void startNewGlobalSession(const std::string &peerIp); - /** - * Start new transfer by incrementing transferStartedCount_ - * A thread must hold lock on mutex_ before calling this - */ - void startNewGlobalSession(ThreadData &data); + /// Returns true if at least one thread has accepted connection + bool hasNewTransferStarted() const; - /** - * Returns whether the current session has finished or not. 
- * A thread must hold lock on mutex_ before calling this - */ - bool hasCurSessionFinished(ThreadData &data); - - /// Starts a new session for the thread - void startNewThreadSession(ThreadData &data); - - /// Ends current thread session - void endCurThreadSession(ThreadData &data); - - /// Increments failed thread count, does not wait for transfer to finish - void incrFailedThreadCountAndCheckForSessionEnd(ThreadData &data); + /// Has steps to do when the current transfer is ended + void endCurGlobalSession(); /// adds log header and also a directory invalidation entry if needed void addTransferLogHeader(bool isBlockMode, bool isSenderResuming); /// fix and close transfer log void fixAndCloseTransferLog(bool transferSuccess); - /** * Get transfer report, meant to be called after threads have been finished * This method is not thread safe @@ -495,6 +171,7 @@ class Receiver : public WdtBase { /// The thread that is responsible for calling running the progress tracker std::thread progressTrackerThread_; + /** * Flags that represents if a transfer has finished. Threads on completion * set this flag. This is always accurate even if you don't call finish() @@ -509,7 +186,7 @@ class Receiver : public WdtBase { std::string destDir_; /// Responsible for writing files on the disk - std::unique_ptr fileCreator_; + std::unique_ptr fileCreator_{nullptr}; /** * Unique-id used to verify transfer log. This value must be same for @@ -531,64 +208,14 @@ class Receiver : public WdtBase { * it has to be made sure that these threads are joined at least before * the destruction of this object. 
*/ - std::vector receiverThreads_; - - /** - * start() gives each thread the instance of the serversocket, these - * sockets can be closed and changed completely by the progress tracker - * thread thus ending any hope of receiver threads doing any further - * successful transfer - */ - std::vector threadServerSockets_; - - /** - * Bunch of stats objects given to each thread by the root thread - * so that finish() can summarize the result at the end of joining. - */ - std::vector threadStats_; - - /// Per thread perf report - std::vector perfReports_; + std::vector> receiverThreads_; /// Transfer log manager TransferLogManager transferLogManager_; - /// Enum representing status of file chunks transfer - enum SendChunkStatus { NOT_STARTED, IN_PROGRESS, SENT }; - - /// State of the receiver when sending file chunks in sendFileChunksCmd - SendChunkStatus sendChunksStatus_{NOT_STARTED}; - - /** - * All threads coordinate with each other to send previously received file - * chunks using this condition variable. 
- */ - mutable std::condition_variable conditionFileChunksSent_; - - /// Number of blocks sent by the sender - int64_t numBlocksSend_{-1}; - /// Global list of checkpoints std::vector checkpoints_; - /// Number of threads which failed in the transfer - int failedThreadCount_{0}; - - /// Number of threads which are waiting for finish or new checkpoint - int waitingThreadCount_{0}; - - /// Number of threads which are waiting with an error - int waitingWithErrorThreadCount_{0}; - - /// Counter that is incremented each time a new session starts - int64_t transferStartedCount_{0}; - - /// Counter that is incremented each time a new session ends - int64_t transferFinishedCount_{0}; - - /// Total number of data bytes sender wants to transfer - int64_t totalSenderBytes_{-1}; - /// Start time of the session std::chrono::time_point startTime_; @@ -598,8 +225,8 @@ class Receiver : public WdtBase { /// Mutex to guard all the shared variables mutable std::mutex mutex_; - /// Condition variable to coordinate transfer finish - mutable std::condition_variable conditionAllFinished_; + /// Marks when a new transfer has started + std::atomic hasNewTransferStarted_{false}; /** * Returns true if threads have been joined (done in finish()) @@ -607,14 +234,17 @@ class Receiver : public WdtBase { */ bool areThreadsJoined_{false}; - /// Number of active threads, decremented every time a thread is finished - int32_t numActiveThreads_{0}; - /** * Mutex for the management of this instance, specifically to keep the * instance sane for multi threaded public API calls */ std::mutex instanceManagementMutex_; + + /// Buffer size used by this receiver + int64_t bufferSize_; + + /// Backlog used by the sockets + int backlog_; }; } } // namespace facebook::wdt diff --git a/ReceiverThread.cpp b/ReceiverThread.cpp new file mode 100644 index 00000000..902b72c6 --- /dev/null +++ b/ReceiverThread.cpp @@ -0,0 +1,905 @@ +/** + * Copyright (c) 2014-present, Facebook, Inc. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ +#include "ReceiverThread.h" +#include "FileWriter.h" +#include +#include +#include +#include +#include +#include + +namespace facebook { +namespace wdt { + +const static int kTimeoutBufferMillis = 1000; +const static int kWaitTimeoutFactor = 5; +std::ostream &operator<<(std::ostream &os, + const ReceiverThread &receiverThread) { + os << "Thread[" << receiverThread.threadIndex_ + << ", port: " << receiverThread.socket_.getPort() << "] "; + return os; +} + +int64_t readAtLeast(ServerSocket &s, char *buf, int64_t max, int64_t atLeast, + int64_t len) { + VLOG(4) << "readAtLeast len " << len << " max " << max << " atLeast " + << atLeast << " from " << s.getFd(); + CHECK_GE(len, 0); + CHECK_GT(atLeast, 0); + CHECK_LE(atLeast, max); + int count = 0; + while (len < atLeast) { + // because we want to process data as soon as it arrives, tryFull option for + // read is false + int64_t n = s.read(buf + len, max - len, false); + if (n < 0) { + PLOG(ERROR) << "Read error on " << s.getPort() << " after " << count; + if (len) { + return len; + } else { + return n; + } + } + if (n == 0) { + VLOG(2) << "Eof on " << s.getPort() << " after " << count << " reads " + << "got " << len; + return len; + } + len += n; + count++; + } + VLOG(3) << "Took " << count << " reads to get " << len + << " from fd : " << s.getFd(); + return len; +} + +int64_t readAtMost(ServerSocket &s, char *buf, int64_t max, int64_t atMost) { + const int64_t target = atMost < max ? 
atMost : max; + VLOG(3) << "readAtMost target " << target; + // because we want to process data as soon as it arrives, tryFull option for + // read is false + int64_t n = s.read(buf, target, false); + if (n < 0) { + PLOG(ERROR) << "Read error on " << s.getPort() << " with target " << target; + return n; + } + if (n == 0) { + LOG(WARNING) << "Eof on " << s.getFd(); + return n; + } + VLOG(3) << "readAtMost " << n << " / " << atMost << " from " << s.getFd(); + return n; +} + +const ReceiverThread::StateFunction ReceiverThread::stateMap_[] = { + &ReceiverThread::listen, &ReceiverThread::acceptFirstConnection, + &ReceiverThread::acceptWithTimeout, &ReceiverThread::sendLocalCheckpoint, + &ReceiverThread::readNextCmd, &ReceiverThread::processFileCmd, + &ReceiverThread::processSettingsCmd, &ReceiverThread::processDoneCmd, + &ReceiverThread::processSizeCmd, &ReceiverThread::sendFileChunks, + &ReceiverThread::sendGlobalCheckpoint, &ReceiverThread::sendDoneCmd, + &ReceiverThread::sendAbortCmd, + &ReceiverThread::waitForFinishOrNewCheckpoint, + &ReceiverThread::finishWithError}; + +ReceiverThread::ReceiverThread(Receiver *wdtParent, int threadIndex, + int32_t port, ThreadsController *controller) + : WdtThread(threadIndex, wdtParent->getProtocolVersion(), controller), + wdtParent_(wdtParent), + socket_(port, wdtParent->backlog_, &(wdtParent->abortCheckerCallback_)), + bufferSize_(wdtParent->bufferSize_) { + controller_->registerThread(threadIndex_); + buf_ = new char[bufferSize_]; +} + +/**LISTEN STATE***/ +ReceiverState ReceiverThread::listen() { + VLOG(1) << *this << " entered LISTEN state "; + const auto &options = WdtOptions::get(); + const bool doActualWrites = !options.skip_writes; + int32_t port = socket_.getPort(); + VLOG(1) << "Server Thread for port " << port << " with backlog " + << socket_.getBackLog() << " on " << wdtParent_->getDir() + << " writes = " << doActualWrites; + + for (int retry = 1; retry < options.max_retries; ++retry) { + ErrorCode code = 
socket_.listen(); + if (code == OK) { + break; + } else if (code == CONN_ERROR) { + threadStats_.setErrorCode(code); + return FAILED; + } + LOG(INFO) << "Sleeping after failed attempt " << retry; + /* sleep override */ + usleep(options.sleep_millis * 1000); + } + // one more/last try (stays true if it worked above) + if (socket_.listen() != OK) { + LOG(ERROR) << "Unable to listen/bind despite retries"; + threadStats_.setErrorCode(CONN_ERROR); + return FAILED; + } + return ACCEPT_FIRST_CONNECTION; +} + +/***ACCEPT_FIRST_CONNECTION***/ +ReceiverState ReceiverThread::acceptFirstConnection() { + VLOG(1) << *this << " entered ACCEPT_FIRST_CONNECTION state "; + const auto &options = WdtOptions::get(); + reset(); + socket_.closeCurrentConnection(); + auto timeout = options.accept_timeout_millis; + int acceptAttempts = 0; + while (true) { + // Move to timeout state if some other thread was successful + // in getting a connection + if (wdtParent_->hasNewTransferStarted()) { + return ACCEPT_WITH_TIMEOUT; + } + if (acceptAttempts == options.max_accept_retries) { + LOG(ERROR) << "unable to accept after " << acceptAttempts << " attempts"; + threadStats_.setErrorCode(CONN_ERROR); + return FAILED; + } + if (wdtParent_->getCurAbortCode() != OK) { + LOG(ERROR) << "Thread marked to abort while trying to accept first" + << " connection. Num attempts " << acceptAttempts; + // Even though there is a transition FAILED here + // getCurAbortCode() is going to be checked again in the receiveOne. + // So this is pretty much irrelevant + return FAILED; + } + ErrorCode code = + socket_.acceptNextConnection(timeout, curConnectionVerified_); + if (code == OK) { + break; + } + ++acceptAttempts; + } + // Make the parent start new global session. 
This is executed + // only by the first thread that calls this function + controller_->executeAtStart( + [&]() { wdtParent_->startNewGlobalSession(socket_.getPeerIp()); }); + return READ_NEXT_CMD; +} + +/***ACCEPT_WITH_TIMEOUT STATE***/ +ReceiverState ReceiverThread::acceptWithTimeout() { + LOG(INFO) << *this << " entered ACCEPT_WITH_TIMEOUT state "; + const auto &options = WdtOptions::get(); + socket_.closeCurrentConnection(); + auto timeout = options.accept_window_millis; + if (senderReadTimeout_ > 0) { + // transfer is in progress and we have already got sender settings + timeout = std::max(senderReadTimeout_, senderWriteTimeout_) + + kTimeoutBufferMillis; + } + ErrorCode code = + socket_.acceptNextConnection(timeout, curConnectionVerified_); + curConnectionVerified_ = false; + if (code != OK) { + LOG(ERROR) << "accept() failed with timeout " << timeout; + threadStats_.setErrorCode(code); + if (doneSendFailure_) { + // if SEND_DONE_CMD state had already been reached, we do not need to + // wait for other threads to end + return END; + } + return FINISH_WITH_ERROR; + } + + if (doneSendFailure_) { + // no need to reset any session variables in this case + return SEND_LOCAL_CHECKPOINT; + } + + numRead_ = off_ = 0; + pendingCheckpointIndex_ = checkpointIndex_; + ReceiverState nextState = READ_NEXT_CMD; + if (threadStats_.getErrorCode() != OK) { + nextState = SEND_LOCAL_CHECKPOINT; + } + // reset thread status + threadStats_.setErrorCode(OK); + return nextState; +} + +/***SEND_LOCAL_CHECKPOINT STATE***/ +ReceiverState ReceiverThread::sendLocalCheckpoint() { + LOG(INFO) << *this << " entered SEND_LOCAL_CHECKPOINT state "; + std::vector checkpoints; + if (doneSendFailure_) { + // in case SEND_DONE failed, a special checkpoint(-1) is sent to signal this + // condition + Checkpoint localCheckpoint(socket_.getPort()); + localCheckpoint.numBlocks = -1; + checkpoints.emplace_back(localCheckpoint); + } else { + checkpoints.emplace_back(checkpoint_); + } + + int64_t off = 0; 
+ const int checkpointLen = + Protocol::getMaxLocalCheckpointLength(threadProtocolVersion_); + Protocol::encodeCheckpoints(threadProtocolVersion_, buf_, off, checkpointLen, + checkpoints); + int written = socket_.write(buf_, checkpointLen); + if (written != checkpointLen) { + LOG(ERROR) << "unable to write local checkpoint. write mismatch " + << checkpointLen << " " << written; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + return ACCEPT_WITH_TIMEOUT; + } + threadStats_.addHeaderBytes(checkpointLen); + if (doneSendFailure_) { + return SEND_DONE_CMD; + } + return READ_NEXT_CMD; +} + +/***READ_NEXT_CMD***/ +ReceiverState ReceiverThread::readNextCmd() { + VLOG(1) << *this << " entered READ_NEXT_CMD state "; + oldOffset_ = off_; + numRead_ = readAtLeast(socket_, buf_ + off_, bufferSize_ - off_, + Protocol::kMinBufLength, numRead_); + if (numRead_ < Protocol::kMinBufLength) { + LOG(ERROR) << "socket read failure " << Protocol::kMinBufLength << " " + << numRead_; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return ACCEPT_WITH_TIMEOUT; + } + Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf_[off_++]; + if (cmd == Protocol::DONE_CMD) { + return PROCESS_DONE_CMD; + } + if (cmd == Protocol::FILE_CMD) { + return PROCESS_FILE_CMD; + } + if (cmd == Protocol::SETTINGS_CMD) { + return PROCESS_SETTINGS_CMD; + } + if (cmd == Protocol::SIZE_CMD) { + return PROCESS_SIZE_CMD; + } + LOG(ERROR) << "received an unknown cmd"; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return FINISH_WITH_ERROR; +} + +/***PROCESS_SETTINGS_CMD***/ +ReceiverState ReceiverThread::processSettingsCmd() { + VLOG(1) << *this << " entered PROCESS_SETTINGS_CMD state "; + Settings settings; + int senderProtocolVersion; + + bool success = Protocol::decodeVersion( + buf_, off_, oldOffset_ + Protocol::kMaxVersion, senderProtocolVersion); + if (!success) { + LOG(ERROR) << "Unable to decode version " << threadIndex_; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return FINISH_WITH_ERROR; + } + if 
(senderProtocolVersion != threadProtocolVersion_) { + LOG(ERROR) << "Receiver and sender protocol version mismatch " + << senderProtocolVersion << " " << threadProtocolVersion_; + int negotiatedProtocol = Protocol::negotiateProtocol( + senderProtocolVersion, threadProtocolVersion_); + if (negotiatedProtocol == 0) { + LOG(WARNING) << "Can not support sender with version " + << senderProtocolVersion << ", aborting!"; + threadStats_.setErrorCode(VERSION_INCOMPATIBLE); + return SEND_ABORT_CMD; + } else { + LOG_IF(INFO, threadProtocolVersion_ != negotiatedProtocol) + << "Changing receiver protocol version to " << negotiatedProtocol; + threadProtocolVersion_ = negotiatedProtocol; + if (negotiatedProtocol != senderProtocolVersion) { + threadStats_.setErrorCode(VERSION_MISMATCH); + return SEND_ABORT_CMD; + } + } + } + + success = Protocol::decodeSettings( + threadProtocolVersion_, buf_, off_, + oldOffset_ + Protocol::kMaxVersion + Protocol::kMaxSettings, settings); + if (!success) { + LOG(ERROR) << *this << "Unable to decode settings cmd "; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return FINISH_WITH_ERROR; + } + auto senderId = settings.transferId; + auto transferId = wdtParent_->getTransferId(); + if (transferId != senderId) { + LOG(ERROR) << "Receiver and sender id mismatch " << senderId << " " + << transferId; + threadStats_.setErrorCode(ID_MISMATCH); + return SEND_ABORT_CMD; + } + senderReadTimeout_ = settings.readTimeoutMillis; + senderWriteTimeout_ = settings.writeTimeoutMillis; + enableChecksum_ = settings.enableChecksum; + isBlockMode_ = !settings.blockModeDisabled; + curConnectionVerified_ = true; + if (settings.sendFileChunks) { + // We only move to SEND_FILE_CHUNKS state, if download resumption is enabled + // in the sender side + numRead_ = off_ = 0; + return SEND_FILE_CHUNKS; + } + auto msgLen = off_ - oldOffset_; + numRead_ -= msgLen; + return READ_NEXT_CMD; +} + +/***PROCESS_FILE_CMD***/ +ReceiverState ReceiverThread::processFileCmd() { + VLOG(1) << 
*this << " entered PROCESS_FILE_CMD state "; + const auto &options = WdtOptions::get(); + // following block needs to be executed for the first file cmd. There is no + // harm in executing it more than once. number of blocks equal to 0 is a good + // approximation for first file cmd. Did not want to introduce another boolean + if (options.enable_download_resumption && threadStats_.getNumBlocks() == 0) { + auto sendChunksFunnel = controller_->getFunnel(SEND_FILE_CHUNKS_FUNNEL); + auto state = sendChunksFunnel->getStatus(); + if (state == FUNNEL_START) { + // sender is not in resumption mode + wdtParent_->addTransferLogHeader(isBlockMode_, + /* sender not resuming */ false); + sendChunksFunnel->notifySuccess(); + } + } + checkpoint_.resetLastBlockDetails(); + BlockDetails blockDetails; + auto guard = folly::makeGuard([&] { + if (threadStats_.getErrorCode() != OK) { + threadStats_.incrFailedAttempts(); + } + }); + + ErrorCode transferStatus = (ErrorCode)buf_[off_++]; + if (transferStatus != OK) { + // TODO: use this status information to implement fail fast mode + VLOG(1) << "sender entered into error state " + << errorCodeToStr(transferStatus); + } + int16_t headerLen = folly::loadUnaligned(buf_ + off_); + headerLen = folly::Endian::little(headerLen); + VLOG(2) << "Processing FILE_CMD, header len " << headerLen; + + if (headerLen > numRead_) { + // readAtLeast() appends at buf + len, so buf must be the start of this cmd + numRead_ = readAtLeast(socket_, buf_ + oldOffset_, bufferSize_ - oldOffset_, + headerLen, numRead_); + } + if (numRead_ < headerLen) { + LOG(ERROR) << "Unable to read full header " << headerLen << " " << numRead_; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return ACCEPT_WITH_TIMEOUT; + } + off_ += sizeof(int16_t); + bool success = Protocol::decodeHeader(threadProtocolVersion_, buf_, off_, + numRead_ + oldOffset_, blockDetails); + int64_t headerBytes = off_ - oldOffset_; + // transferred header length must match decoded header length + WDT_CHECK_EQ(headerLen, headerBytes) << " " <<
blockDetails.fileName << " " + << blockDetails.seqId << " " + << threadProtocolVersion_; + threadStats_.addHeaderBytes(headerBytes); + if (!success) { + LOG(ERROR) << "Error decoding at" + << " ooff:" << oldOffset_ << " off_: " << off_ + << " numRead_: " << numRead_; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return FINISH_WITH_ERROR; + } + + // received a well formed file cmd, apply the pending checkpoint update + checkpointIndex_ = pendingCheckpointIndex_; + VLOG(1) << "Read id:" << blockDetails.fileName + << " size:" << blockDetails.dataSize << " ooff:" << oldOffset_ + << " off_: " << off_ << " numRead_: " << numRead_; + auto &fileCreator = wdtParent_->getFileCreator(); + FileWriter writer(threadIndex_, &blockDetails, fileCreator.get()); + auto writtenGuard = folly::makeGuard([&] { + if (threadProtocolVersion_ >= Protocol::CHECKPOINT_OFFSET_VERSION) { + // considering partially written block contents as valid, this bypasses + // checksum verification + // TODO: Make sure checksum verification work with checkpoint offsets + checkpoint_.setLastBlockDetails(blockDetails.seqId, blockDetails.offset, + writer.getTotalWritten()); + threadStats_.addEffectiveBytes(headerBytes, writer.getTotalWritten()); + } + }); + if (writer.open() != OK) { + threadStats_.setErrorCode(FILE_WRITE_ERROR); + return SEND_ABORT_CMD; + } + int32_t checksum = 0; + int64_t remainingData = numRead_ + oldOffset_ - off_; + int64_t toWrite = remainingData; + WDT_CHECK(remainingData >= 0); + if (remainingData >= blockDetails.dataSize) { + toWrite = blockDetails.dataSize; + } + threadStats_.addDataBytes(toWrite); + if (enableChecksum_) { + checksum = folly::crc32c((const uint8_t *)(buf_ + off_), toWrite, checksum); + } + auto throttler = wdtParent_->getThrottler(); + if (throttler) { + // We might be reading more than we require for this file but + // throttling should make sense for any additional bytes received + // on the network + throttler->limit(toWrite + headerBytes); + } + ErrorCode code 
= writer.write(buf_ + off_, toWrite); + if (code != OK) { + threadStats_.setErrorCode(code); + return SEND_ABORT_CMD; + } + off_ += toWrite; + remainingData -= toWrite; + // also means no leftOver so it's ok we use buf_ from start + while (writer.getTotalWritten() < blockDetails.dataSize) { + if (wdtParent_->getCurAbortCode() != OK) { + LOG(ERROR) << "Thread marked for abort while processing a file." + << " port : " << socket_.getPort(); + return FAILED; + } + int64_t nres = readAtMost(socket_, buf_, bufferSize_, + blockDetails.dataSize - writer.getTotalWritten()); + if (nres <= 0) { + break; + } + if (throttler) { + // We only know how much we have read after we are done calling + // readAtMost. Call throttler with the bytes read off_ the wire. + throttler->limit(nres); + } + threadStats_.addDataBytes(nres); + if (enableChecksum_) { + checksum = folly::crc32c((const uint8_t *)buf_, nres, checksum); + } + code = writer.write(buf_, nres); + if (code != OK) { + threadStats_.setErrorCode(code); + return SEND_ABORT_CMD; + } + } + if (writer.getTotalWritten() != blockDetails.dataSize) { + // This can only happen if there are transmission errors + // Write errors to disk are already taken care of above + LOG(ERROR) << "could not read entire content for " << blockDetails.fileName + << " port " << socket_.getPort(); + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return ACCEPT_WITH_TIMEOUT; + } + writtenGuard.dismiss(); + VLOG(2) << "completed " << blockDetails.fileName << " off: " << off_ + << " numRead: " << numRead_; + // Transfer of the file is complete here, mark the bytes effective + WDT_CHECK(remainingData >= 0) << "Negative remainingData " << remainingData; + if (remainingData > 0) { + // if we need to read more anyway, let's move the data + numRead_ = remainingData; + if ((remainingData < Protocol::kMaxHeader) && (off_ > (bufferSize_ / 2))) { + // rare so inefficient is ok + VLOG(3) << "copying extra " << remainingData << " leftover bytes @ " + << off_; + 
memmove(/* dst */ buf_, + /* from */ buf_ + off_, + /* how much */ remainingData); + off_ = 0; + } else { + // otherwise just continue from the offset + VLOG(3) << "Using remaining extra " << remainingData + << " leftover bytes starting @ " << off_; + } + } else { + numRead_ = off_ = 0; + } + if (enableChecksum_) { + // have to read footer cmd + oldOffset_ = off_; + numRead_ = readAtLeast(socket_, buf_ + off_, bufferSize_ - off_, + Protocol::kMinBufLength, numRead_); + if (numRead_ < Protocol::kMinBufLength) { + LOG(ERROR) << "socket read failure " << Protocol::kMinBufLength << " " + << numRead_; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return ACCEPT_WITH_TIMEOUT; + } + Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf_[off_++]; + if (cmd != Protocol::FOOTER_CMD) { + LOG(ERROR) << "Expecting footer cmd, but received " << cmd; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return FINISH_WITH_ERROR; + } + int32_t receivedChecksum; + bool success = Protocol::decodeFooter( + buf_, off_, oldOffset_ + Protocol::kMaxFooter, receivedChecksum); + if (!success) { + LOG(ERROR) << "Unable to decode footer cmd"; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return FINISH_WITH_ERROR; + } + if (checksum != receivedChecksum) { + LOG(ERROR) << "Checksum mismatch " << checksum << " " << receivedChecksum + << " port " << socket_.getPort() << " file " + << blockDetails.fileName; + threadStats_.setErrorCode(CHECKSUM_MISMATCH); + return ACCEPT_WITH_TIMEOUT; + } + int64_t msgLen = off_ - oldOffset_; + numRead_ -= msgLen; + } + auto &transferLogManager = wdtParent_->getTransferLogManager(); + if (options.isLogBasedResumption()) { + transferLogManager.addBlockWriteEntry( + blockDetails.seqId, blockDetails.offset, blockDetails.dataSize); + } + threadStats_.addEffectiveBytes(headerBytes, blockDetails.dataSize); + threadStats_.incrNumBlocks(); + checkpoint_.incrNumBlocks(); + return READ_NEXT_CMD; +} + +ReceiverState ReceiverThread::processDoneCmd() { + VLOG(1) << *this << " 
entered PROCESS_DONE_CMD state "; + if (numRead_ != Protocol::kMinBufLength) { + LOG(ERROR) << "Unexpected state for done command" + << " off_: " << off_ << " numRead_: " << numRead_; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return FINISH_WITH_ERROR; + } + + ErrorCode senderStatus = (ErrorCode)buf_[off_++]; + int64_t numBlocksSend = -1; + int64_t totalSenderBytes = -1; + bool success = Protocol::decodeDone(threadProtocolVersion_, buf_, off_, + oldOffset_ + Protocol::kMaxDone, + numBlocksSend, totalSenderBytes); + if (!success) { + LOG(ERROR) << "Unable to decode done cmd"; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return FINISH_WITH_ERROR; + } + threadStats_.setNumBlocksSend(numBlocksSend); + threadStats_.setTotalSenderBytes(totalSenderBytes); + threadStats_.setRemoteErrorCode(senderStatus); + + // received a valid command, applying pending checkpoint write update + checkpointIndex_ = pendingCheckpointIndex_; + return WAIT_FOR_FINISH_OR_NEW_CHECKPOINT; +} + +ReceiverState ReceiverThread::processSizeCmd() { + VLOG(1) << *this << " entered PROCESS_SIZE_CMD state "; + int64_t totalSenderBytes; + bool success = Protocol::decodeSize( + buf_, off_, oldOffset_ + Protocol::kMaxSize, totalSenderBytes); + if (!success) { + LOG(ERROR) << "Unable to decode size cmd"; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return FINISH_WITH_ERROR; + } + VLOG(1) << "Number of bytes to receive " << totalSenderBytes; + threadStats_.setTotalSenderBytes(totalSenderBytes); + auto msgLen = off_ - oldOffset_; + numRead_ -= msgLen; + return READ_NEXT_CMD; +} + +ReceiverState ReceiverThread::sendFileChunks() { + LOG(INFO) << *this << " entered SEND_FILE_CHUNKS state "; + WDT_CHECK(senderReadTimeout_ > 0); // must have received settings + int waitingTimeMillis = senderReadTimeout_ / kWaitTimeoutFactor; + auto execFunnel = controller_->getFunnel(SEND_FILE_CHUNKS_FUNNEL); + while (true) { + auto status = execFunnel->getStatus(); + switch (status) { + case FUNNEL_END: { + buf_[0] = 
Protocol::ACK_CMD; + int toWrite = 1; + int written = socket_.write(buf_, toWrite); + if (written != toWrite) { + LOG(ERROR) << *this << " socket write error " << toWrite << " " + << written; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + return ACCEPT_WITH_TIMEOUT; + } + threadStats_.addHeaderBytes(toWrite); + return READ_NEXT_CMD; + } + case FUNNEL_PROGRESS: { + buf_[0] = Protocol::WAIT_CMD; + int toWrite = 1; + int written = socket_.write(buf_, toWrite); + if (written != toWrite) { + LOG(ERROR) << *this << " socket write error " << toWrite << " " + << written; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + return ACCEPT_WITH_TIMEOUT; + } + threadStats_.addHeaderBytes(toWrite); + execFunnel->wait(waitingTimeMillis); + break; + } + case FUNNEL_START: { + int64_t off = 0; + buf_[off++] = Protocol::CHUNKS_CMD; + const auto &fileChunksInfo = wdtParent_->getFileChunksInfo(); + const int64_t numParsedChunksInfo = fileChunksInfo.size(); + Protocol::encodeChunksCmd(buf_, off, bufferSize_, numParsedChunksInfo); + int written = socket_.write(buf_, off); + if (written > 0) { + threadStats_.addHeaderBytes(written); + } + if (written != off) { + LOG(ERROR) << "Socket write error " << off << " " << written; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + execFunnel->notifyFail(); + return ACCEPT_WITH_TIMEOUT; + } + int64_t numEntriesWritten = 0; + // we try to encode as many chunks as possible in the buffer. If a + // single + // chunk can not fit in the buffer, it is ignored. Format of encoding : + // ...
+ while (numEntriesWritten < numParsedChunksInfo) { + off = sizeof(int32_t); + int64_t numEntriesEncoded = Protocol::encodeFileChunksInfoList( + buf_, off, bufferSize_, numEntriesWritten, fileChunksInfo); + int32_t dataSize = folly::Endian::little(off - sizeof(int32_t)); + folly::storeUnaligned(buf_, dataSize); + written = socket_.write(buf_, off); + if (written > 0) { + threadStats_.addHeaderBytes(written); + } + if (written != off) { + break; + } + numEntriesWritten += numEntriesEncoded; + } + if (numEntriesWritten != numParsedChunksInfo) { + LOG(ERROR) << "Could not write all the file chunks " + << numParsedChunksInfo << " " << numEntriesWritten; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + execFunnel->notifyFail(); + return ACCEPT_WITH_TIMEOUT; + } + // try to read ack + int64_t toRead = 1; + int64_t numRead = socket_.read(buf_, toRead); + if (numRead != toRead) { + LOG(ERROR) << "Socket read error " << toRead << " " << numRead; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + execFunnel->notifyFail(); + return ACCEPT_WITH_TIMEOUT; + } + wdtParent_->addTransferLogHeader(isBlockMode_, + /* sender resuming */ true); + execFunnel->notifySuccess(); + return READ_NEXT_CMD; + } + } + } +} + +ReceiverState ReceiverThread::sendGlobalCheckpoint() { + LOG(INFO) << *this << " entered SEND_GLOBAL_CHECKPOINTS state"; + buf_[0] = Protocol::ERR_CMD; + off_ = 1; + // leave space for length + off_ += sizeof(int16_t); + auto oldOffset = off_; + Protocol::encodeCheckpoints(threadProtocolVersion_, buf_, off_, bufferSize_, + newCheckpoints_); + int16_t length = off_ - oldOffset; + folly::storeUnaligned(buf_ + 1, folly::Endian::little(length)); + + auto written = socket_.write(buf_, off_); + if (written != off_) { + LOG(ERROR) << "unable to write error checkpoints"; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + return ACCEPT_WITH_TIMEOUT; + } else { + threadStats_.addHeaderBytes(off_); + pendingCheckpointIndex_ = checkpointIndex_ + newCheckpoints_.size(); + numRead_ = 
off_ = 0; + return READ_NEXT_CMD; + } +} + +ReceiverState ReceiverThread::sendAbortCmd() { + LOG(INFO) << *this << " entered SEND_ABORT_CMD state "; + int64_t offset = 0; + buf_[offset++] = Protocol::ABORT_CMD; + Protocol::encodeAbort(buf_, offset, threadProtocolVersion_, + threadStats_.getErrorCode(), + threadStats_.getNumFiles()); + socket_.write(buf_, offset); + // No need to check if we were successful in sending ABORT + // This thread will simply disconnect and sender thread on the + // other side will timeout + socket_.closeCurrentConnection(); + threadStats_.addHeaderBytes(offset); + if (threadStats_.getErrorCode() == VERSION_MISMATCH) { + // Receiver should try again expecting sender to have changed its version + return ACCEPT_WITH_TIMEOUT; + } + return FINISH_WITH_ERROR; +} + +ReceiverState ReceiverThread::sendDoneCmd() { + VLOG(1) << *this << " entered SEND_DONE_CMD state "; + buf_[0] = Protocol::DONE_CMD; + if (socket_.write(buf_, 1) != 1) { + PLOG(ERROR) << "unable to send DONE " << threadIndex_; + doneSendFailure_ = true; + return ACCEPT_WITH_TIMEOUT; + } + + threadStats_.addHeaderBytes(1); + + auto read = socket_.read(buf_, 1); + if (read != 1 || buf_[0] != Protocol::DONE_CMD) { + LOG(ERROR) << *this << " did not receive ack for DONE"; + doneSendFailure_ = true; + return ACCEPT_WITH_TIMEOUT; + } + + read = socket_.read(buf_, Protocol::kMinBufLength); + if (read != 0) { + LOG(ERROR) << *this << " EOF not found where expected"; + doneSendFailure_ = true; + return ACCEPT_WITH_TIMEOUT; + } + socket_.closeCurrentConnection(); + LOG(INFO) << *this << " got ack for DONE. 
Transfer finished"; + return END; +} + +ReceiverState ReceiverThread::finishWithError() { + LOG(INFO) << *this << " entered FINISH_WITH_ERROR state "; + // should only be in this state if there is some error + WDT_CHECK(threadStats_.getErrorCode() != OK); + + // close the socket, so that sender receives an error during connect + socket_.closeAll(); + auto cv = controller_->getCondition(WAIT_FOR_FINISH_OR_CHECKPOINT_CV); + auto guard = cv->acquire(); + wdtParent_->addCheckpoint(checkpoint_); + controller_->markState(threadIndex_, FINISHED); + // guard deletion notifies one thread + return END; +} + +ReceiverState ReceiverThread::checkForFinishOrNewCheckpoints() { + auto checkpoints = wdtParent_->getNewCheckpoints(checkpointIndex_); + if (!checkpoints.empty()) { + newCheckpoints_ = std::move(checkpoints); + controller_->markState(threadIndex_, RUNNING); + return SEND_GLOBAL_CHECKPOINTS; + } + bool existActiveThreads = controller_->hasThreads(threadIndex_, RUNNING); + if (!existActiveThreads) { + controller_->markState(threadIndex_, FINISHED); + return SEND_DONE_CMD; + } + return WAIT_FOR_FINISH_OR_NEW_CHECKPOINT; +} + +ReceiverState ReceiverThread::waitForFinishOrNewCheckpoint() { + LOG(INFO) << *this << " entered WAIT_FOR_FINISH_OR_NEW_CHECKPOINT state "; + // should only be called if there are no errors + WDT_CHECK(threadStats_.getErrorCode() == OK); + auto cv = controller_->getCondition(WAIT_FOR_FINISH_OR_CHECKPOINT_CV); + int timeoutMillis = senderReadTimeout_ / kWaitTimeoutFactor; + controller_->markState(threadIndex_, WAITING); + while (true) { + WDT_CHECK(senderReadTimeout_ > 0); // must have received settings + { + auto guard = cv->acquire(); + auto state = checkForFinishOrNewCheckpoints(); + if (state != WAIT_FOR_FINISH_OR_NEW_CHECKPOINT) { + // guard automatically notifies one + return state; + } + START_PERF_TIMER + guard.wait(timeoutMillis); + RECORD_PERF_RESULT(PerfStatReport::RECEIVER_WAIT_SLEEP) + state = checkForFinishOrNewCheckpoints(); + if (state !=
WAIT_FOR_FINISH_OR_NEW_CHECKPOINT) { + guard.notifyOne(); + return state; + } + } + // send WAIT cmd to keep sender thread alive + buf_[0] = Protocol::WAIT_CMD; + if (socket_.write(buf_, 1) != 1) { + PLOG(ERROR) << *this << " unable to write WAIT "; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + controller_->markState(threadIndex_, RUNNING); + return ACCEPT_WITH_TIMEOUT; + } + threadStats_.addHeaderBytes(1); + } +} + +void ReceiverThread::start() { + INIT_PERF_STAT_REPORT + auto guard = folly::makeGuard([&] { + perfReport_ = *perfStatReport; + LOG(INFO) << *this << threadStats_; + controller_->deRegisterThread(threadIndex_); + controller_->executeAtEnd([&]() { wdtParent_->endCurGlobalSession(); }); + }); + if (!buf_) { + LOG(ERROR) << "error allocating " << bufferSize_; + threadStats_.setErrorCode(MEMORY_ALLOCATION_ERROR); + return; + } + ReceiverState state = LISTEN; + while (true) { + ErrorCode abortCode = wdtParent_->getCurAbortCode(); + if (abortCode != OK) { + LOG(ERROR) << "Transfer aborted " << socket_.getPort() << " " + << errorCodeToStr(abortCode); + threadStats_.setErrorCode(ABORT); + break; + } + if (state == FAILED) { + return; + } + if (state == END) { + return; + } + state = (this->*stateMap_[state])(); + } +} + +int32_t ReceiverThread::getPort() const { + return socket_.getPort(); +} + +ErrorCode ReceiverThread::init() { + int max_retries = WdtOptions::get().max_retries; + for (int retries = 0; retries < max_retries; retries++) { + if (socket_.listen() == OK) { + break; + } + } + if (socket_.listen() != OK) { + LOG(ERROR) << *this << "Couldn't listen on port " << socket_.getPort(); + return ERROR; + } + checkpoint_.port = socket_.getPort(); + LOG(INFO) << "Listening on port " << socket_.getPort(); + return OK; +} + +void ReceiverThread::reset() { + numRead_ = off_ = 0; + checkpointIndex_ = pendingCheckpointIndex_ = 0; + doneSendFailure_ = false; + senderReadTimeout_ = senderWriteTimeout_ = -1; + curConnectionVerified_ = false; + 
threadStats_.reset(); +} + +ReceiverThread::~ReceiverThread() { + delete[] buf_; +} +} +} diff --git a/ReceiverThread.h b/ReceiverThread.h new file mode 100644 index 00000000..9fbc8cc0 --- /dev/null +++ b/ReceiverThread.h @@ -0,0 +1,341 @@ +/** + * Copyright (c) 2014-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ +#pragma once +#include "WdtBase.h" +#include "WdtThread.h" +#include "Receiver.h" +#include "ServerSocket.h" +namespace facebook { +namespace wdt { +class Receiver; +/** + * Wdt receiver has logic to maintain the consistency of the + * transfers through connection errors. All threads are run by the logic + * defined as a state machine. These are the all the states in that + * state machine + */ +enum ReceiverState { + LISTEN, + ACCEPT_FIRST_CONNECTION, + ACCEPT_WITH_TIMEOUT, + SEND_LOCAL_CHECKPOINT, + READ_NEXT_CMD, + PROCESS_FILE_CMD, + PROCESS_SETTINGS_CMD, + PROCESS_DONE_CMD, + PROCESS_SIZE_CMD, + SEND_FILE_CHUNKS, + SEND_GLOBAL_CHECKPOINTS, + SEND_DONE_CMD, + SEND_ABORT_CMD, + WAIT_FOR_FINISH_OR_NEW_CHECKPOINT, + FINISH_WITH_ERROR, + FAILED, + END +}; + +/** + * This class represents a receiver thread. It contains + * all the logic for a thread to bind on a port and + * receive data from the wdt sender. 
All the receiver threads + * share modules like threads controller, throttler etc + */ +class ReceiverThread : public WdtThread { + public: + /// Identifiers for the funnels that this thread will use + enum RECEIVER_FUNNELS { SEND_FILE_CHUNKS_FUNNEL, NUM_FUNNELS }; + + /// Identifiers for the condition variable wrappers used in the thread + enum RECEIVER_CONDITIONS { WAIT_FOR_FINISH_OR_CHECKPOINT_CV, NUM_CONDITIONS }; + + /// Identifiers for the barriers used in the thread + enum RECEIVER_BARRIERS { NUM_BARRIERS }; + /** + * Constructor for receiver thread. + * @param wdtParent Pointer back to the parent receiver for meta + * information + * @param threadIndex Every thread is identified by unique index + * @param port Port this thread will listen on + * @param controller Thread controller for all the instances of the + * receiver threads. All the receiver thread objects + * need to share the same instance of the controller + */ + ReceiverThread(Receiver *wdtParent, int threadIndex, int port, + ThreadsController *controller); + + /// Initializes the receiver thread before starting + ErrorCode init() override; + + /** + * In long running mode, we need to reset thread variables after each + * session. Before starting each session, reset() has to be called to do that. + */ + void reset() override; + + /// Destructor of Receiver thread + ~ReceiverThread(); + + /// Get the port this receiver thread is listening on + int32_t getPort() const override; + + private: + /// Overloaded operator for printing thread info + friend std::ostream &operator<<(std::ostream &os, + const ReceiverThread &receiverThread); + typedef ReceiverState (ReceiverThread::*StateFunction)(); + + /// Parent shared among all the threads for meta information + Receiver *wdtParent_; + + /** + * Tries to listen/bind to port. If this fails, thread is considered failed. 
+ * Previous states : n/a (start state) + * Next states : ACCEPT_FIRST_CONNECTION(success), + * FAILED(failure) + */ + ReceiverState listen(); + + /** + * Tries to accept first connection of a new session. Periodically checks + * whether a new session has started or not. If a new session has started then + * goes to ACCEPT_WITH_TIMEOUT state. Also does session initialization. In + * joinable mode, tries to accept for a limited number of user specified + * retries. + * Previous states : LISTEN, + * END(if in long running mode) + * Next states : ACCEPT_WITH_TIMEOUT(if a new transfer has started and this + * thread has not received a connection), + * FINISH_WITH_ERROR(if did not receive a + * connection in specified number of retries), + * READ_NEXT_CMD(if a connection was received) + */ + ReceiverState acceptFirstConnection(); + + /** + * Tries to accept a connection with timeout. There are 2 kinds of timeout. At + * the beginning of the session, it uses accept window as the timeout. Later + * when sender settings are known it uses max(readTimeOut, writeTimeout) + + * buffer(500) as the timeout. + * Previous states : Almost all states(for any network errors during transfer, + * we transition to this state), + * Next states : READ_NEXT_CMD(if there are no previous errors and accept + * was successful), + * SEND_LOCAL_CHECKPOINT(if there were previous errors and + * accept was successful), + * FINISH_WITH_ERROR(if accept failed and + * transfer previously failed during SEND_DONE_CMD state. This + * case needs special handling to ensure that we do not mess up + * session variables), + * END(if accept fails otherwise) + */ + ReceiverState acceptWithTimeout(); + /** + * Sends local checkpoint to the sender. In case of previous error during + * SEND_LOCAL_CHECKPOINT state, we send -1 as the checkpoint. 
+ * Previous states : ACCEPT_WITH_TIMEOUT + * Next states : ACCEPT_WITH_TIMEOUT(if sending fails), + * SEND_DONE_CMD(if send is successful and we have previous + * SEND_DONE_CMD error), + * READ_NEXT_CMD(if send is successful otherwise) + */ + ReceiverState sendLocalCheckpoint(); + /** + * Reads next cmd and transitions to the state accordingly. + * Previous states : SEND_LOCAL_CHECKPOINT, + * ACCEPT_FIRST_CONNECTION, + * ACCEPT_WITH_TIMEOUT, + * PROCESS_SETTINGS_CMD, + * PROCESS_FILE_CMD, + * SEND_GLOBAL_CHECKPOINTS, + * Next states : PROCESS_FILE_CMD, + * PROCESS_DONE_CMD, + * PROCESS_SETTINGS_CMD, + * PROCESS_SIZE_CMD, + * ACCEPT_WITH_TIMEOUT(in case of read failure), + * FINISH_WITH_ERROR(in case of protocol errors) + */ + ReceiverState readNextCmd(); + /** + * Processes file cmd. Logic of how we write the file to the destination + * directory is defined here. + * Previous states : READ_NEXT_CMD + * Next states : READ_NEXT_CMD(success), + * FINISH_WITH_ERROR(protocol error), + * ACCEPT_WITH_TIMEOUT(socket read failure) + */ + ReceiverState processFileCmd(); + /** + * Processes settings cmd. Settings has a connection settings, + * protocol version, transfer id, etc. For more info check Protocol.h + * Previous states : READ_NEXT_CMD, + * Next states : READ_NEXT_CMD(success), + * FINISH_WITH_ERROR(protocol error), + * ACCEPT_WITH_TIMEOUT(socket read failure), + * SEND_FILE_CHUNKS(If the sender wants to resume transfer) + */ + ReceiverState processSettingsCmd(); + /** + * Processes done cmd. Also checks to see if there are any new global + * checkpoints or not + * Previous states : READ_NEXT_CMD, + * Next states : FINISH_WITH_ERROR(protocol error), + * WAIT_FOR_FINISH_OR_NEW_CHECKPOINT(success), + * SEND_GLOBAL_CHECKPOINTS(if there are global errors) + */ + ReceiverState processDoneCmd(); + /** + * Processes size cmd. 
Sets the value of totalSenderBytes_ + * Previous states : READ_NEXT_CMD, + * Next states : READ_NEXT_CMD(success), + * FINISH_WITH_ERROR(protocol error) + */ + ReceiverState processSizeCmd(); + /** + * Sends file chunks that were received successfully in any previous transfer, + * this is the first step in download resumption. + * Checks to see if they have already been transferred or not. + * If yes, send ACK. If some other thread is sending it, sends wait cmd + * and checks again later. Otherwise, breaks the entire data into bufferSize_ + * chunks and sends it. + * Previous states: PROCESS_SETTINGS_CMD, + * Next states : ACCEPT_WITH_TIMEOUT(network error), + * READ_NEXT_CMD(success) + */ + ReceiverState sendFileChunks(); + /** + * Sends global checkpoints to sender + * Previous states : PROCESS_DONE_CMD, + * WAIT_FOR_FINISH_OR_NEW_CHECKPOINT + * Next states : READ_NEXT_CMD(success), + * ACCEPT_WITH_TIMEOUT(socket write failure) + */ + ReceiverState sendGlobalCheckpoint(); + /** + * Sends DONE to sender, also tries to read back ack. If anything fails during + * this state, doneSendFailure_ thread variable is set. This flag makes the + * state machine behave differently, effectively bypassing all session related + * things. + * Previous states : SEND_LOCAL_CHECKPOINT, + * WAIT_FOR_FINISH_OR_NEW_CHECKPOINT + * Next states : END(success), + * ACCEPT_WITH_TIMEOUT(failure) + */ + ReceiverState sendDoneCmd(); + + /** + * Sends ABORT cmd back to the sender + * Previous states : PROCESS_FILE_CMD + * Next states : FINISH_WITH_ERROR + */ + ReceiverState sendAbortCmd(); + + /** + * Internal implementation of waitForFinishOrNewCheckpoint + * Returns : + * SEND_GLOBAL_CHECKPOINTS if there are checkpoints + * SEND_DONE_CMD if there are no checkpoints and + * there are no active threads + * WAIT_FOR_FINISH_OR_NEW_CHECKPOINT in all other cases + */ + ReceiverState checkForFinishOrNewCheckpoints(); + + /** + * Waits for transfer to finish or new checkpoints. This state first + * marks the thread as WAITING in the threads controller. 
Then, it + * waits till all the threads have finished. It sends periodic WAIT signal to + * prevent sender from timing out. If a new checkpoint is found, we move to + * SEND_GLOBAL_CHECKPOINTS state. + * Previous states : PROCESS_DONE_CMD + * Next states : SEND_DONE_CMD(all threads finished), + * SEND_GLOBAL_CHECKPOINTS(if new checkpoints are found), + * ACCEPT_WITH_TIMEOUT(if socket write fails) + */ + ReceiverState waitForFinishOrNewCheckpoint(); + + /** + * Ends the thread after an error. Only called when there is an error for + * the thread. It closes the socket, adds the checkpoint of this thread to + * the parent receiver and marks the thread as FINISHED in the threads + * controller. + * Previous states : Almost all states + * Next states : END + */ + ReceiverState finishWithError(); + + /// Mapping from receiver states to state functions + static const StateFunction stateMap_[]; + + /// Main entry point for the thread, starts the state machine + void start() override; + + /** + * Server socket object that provides functionality such as listen() + * accept, read, write on the socket + */ + ServerSocket socket_; + + /// Buffer that data read from the network is stored into + char *buf_{nullptr}; + + /// Size of the buffer + const int64_t bufferSize_; + + /// Marks the number of bytes already read in the buffer + int64_t numRead_{0}; + + /// Following two are markers to mark how much data has been read/parsed + int64_t off_{0}; + int64_t oldOffset_{0}; + + /// Number of checkpoints already transferred + int checkpointIndex_{0}; + + /// Checkpoints saved for this thread + std::vector checkpoints_; + + /** + * Pending value of checkpoint count. since write call success does not + * guarantee actual transfer, we do not apply checkpoint count update after + * the write. 
Only after receiving next cmd from sender, we apply the + * update + */ + int pendingCheckpointIndex_{0}; + + /// read timeout for sender + int64_t senderReadTimeout_{-1}; + + /// write timeout for sender + int64_t senderWriteTimeout_{-1}; + + /// whether the transfer is in block mode or not + bool isBlockMode_{true}; + + /// Checkpoint local to the thread, updated regularly + Checkpoint checkpoint_; + + /// whether checksum verification is enabled or not + bool enableChecksum_{false}; + + /// whether settings have been received and verified for the current + /// connection. This is used to determine round robin order for polling in + /// the server socket + bool curConnectionVerified_{false}; + + /** + * Whether SEND_DONE_CMD state has already failed for this session or not. + * This has to be separately handled, because session barrier is + * implemented before sending done cmd + */ + bool doneSendFailure_{false}; + + /// Checkpoints that have not been sent back to the sender + std::vector newCheckpoints_; +}; +} +} diff --git a/Reporting.cpp b/Reporting.cpp index f39bc5a0..585bdb1a 100644 --- a/Reporting.cpp +++ b/Reporting.cpp @@ -30,6 +30,24 @@ TransferStats& TransferStats::operator+=(const TransferStats& stats) { numFiles_ += stats.numFiles_; numBlocks_ += stats.numBlocks_; failedAttempts_ += stats.failedAttempts_; + if (numBlocksSend_ == -1) { + numBlocksSend_ = stats.numBlocksSend_; + } else if (stats.numBlocksSend_ != -1 && + numBlocksSend_ != stats.numBlocksSend_) { + LOG_IF(ERROR, errCode_ == OK) << "Mismatch in the numBlocksSend " + << numBlocksSend_ << " " + << stats.numBlocksSend_; + errCode_ = ERROR; + } + if (totalSenderBytes_ == -1) { + totalSenderBytes_ = stats.totalSenderBytes_; + } else if (stats.totalSenderBytes_ != -1 && + totalSenderBytes_ != stats.totalSenderBytes_) { + LOG_IF(ERROR, errCode_ == OK) << "Mismatch in the total sender bytes " + << totalSenderBytes_ << " " + << stats.totalSenderBytes_; + errCode_ = ERROR; + } if (stats.errCode_ 
!= OK) { if (errCode_ == OK) { // First error. Setting this as the error code @@ -123,6 +141,17 @@ TransferReport::TransferReport( summary_.setNumFiles(numTransferredFiles); } +TransferReport::TransferReport(TransferStats&& globalStats, double totalTime, + int64_t totalFileSize) + : TransferReport(std::move(globalStats)) { + totalTime_ = totalTime; + totalFileSize_ = totalFileSize; +} + +TransferReport::TransferReport(TransferStats&& globalStats) { + summary_ = std::move(globalStats); +} + TransferReport::TransferReport(const std::vector& threadStats, double totalTime, int64_t totalFileSize) : totalTime_(totalTime), totalFileSize_(totalFileSize) { diff --git a/Reporting.h b/Reporting.h index dfc2abf5..39da302e 100644 --- a/Reporting.h +++ b/Reporting.h @@ -52,7 +52,7 @@ std::ostream &operator<<(std::ostream &os, const std::vector &v) { std::copy(v.begin(), v.end(), std::ostream_iterator(os, " ")); return os; } - +// TODO rename to ThreadResult /// class representing statistics related to file transfer class TransferStats { private: @@ -75,6 +75,12 @@ class TransferStats { /// number of failed transfers int64_t failedAttempts_ = 0; + /// Total number of blocks sent by sender + int64_t numBlocksSend_{-1}; + + /// Total number of bytes sent by sender + int64_t totalSenderBytes_{-1}; + /// status of the transfer ErrorCode errCode_ = OK; @@ -114,6 +120,36 @@ class TransferStats { errCode_ = remoteErrCode_ = OK; } + /// Validates the global transfer stats. Only call this method + /// on the accumulated stats. 
Can only be called for receiver + void validate() { + folly::RWSpinLock::ReadHolder lock(mutex_.get()); + if (numBlocksSend_ == -1) { + LOG(ERROR) << "Negative number of blocks sent by the sender"; + errCode_ = ERROR; + } else if (totalSenderBytes_ != -1 && + totalSenderBytes_ != effectiveDataBytes_) { + // did not receive all the bytes + LOG(ERROR) << "Number of bytes sent and received do not match " + << totalSenderBytes_ << " " << effectiveDataBytes_; + errCode_ = ERROR; + } else { + errCode_ = OK; + } + } + + /// @return the number of blocks sent by sender + int64_t getNumBlocksSend() const { + folly::RWSpinLock::ReadHolder lock(mutex_.get()); + return numBlocksSend_; + } + + /// @return the total sender bytes + int64_t getTotalSenderBytes() const { + folly::RWSpinLock::ReadHolder lock(mutex_.get()); + return totalSenderBytes_; + } + /// @return number of header bytes transferred int64_t getHeaderBytes() const { folly::RWSpinLock::ReadHolder lock(mutex_.get()); @@ -225,6 +261,18 @@ class TransferStats { headerBytes_ += count; } + /// @param set num blocks send + void setNumBlocksSend(int64_t numBlocksSend) { + folly::RWSpinLock::WriteHolder lock(mutex_.get()); + numBlocksSend_ = numBlocksSend; + } + + /// @param set total sender bytes + void setTotalSenderBytes(int64_t totalSenderBytes) { + folly::RWSpinLock::WriteHolder lock(mutex_.get()); + totalSenderBytes_ = totalSenderBytes; + } + /// one more file transfer failed void incrFailedAttempts() { folly::RWSpinLock::WriteHolder lock(mutex_.get()); @@ -311,6 +359,9 @@ class TransferReport { TransferReport(const std::vector &threadStats, double totalTime, int64_t totalFileSize); + TransferReport(TransferStats &&stats, double totalTime, + int64_t totalFileSize); + explicit TransferReport(TransferStats &&stats); /// constructor used by receiver, does move the thread stats explicit TransferReport(std::vector &threadStats); /// @return summary of the report diff --git a/Sender.cpp b/Sender.cpp index 8eb2e538..ffb356ae 
100644 --- a/Sender.cpp +++ b/Sender.cpp @@ -9,6 +9,7 @@ #include "Sender.h" #include "ClientSocket.h" +#include "SenderThread.h" #include "Throttler.h" #include "SocketUtils.h" @@ -24,209 +25,26 @@ namespace facebook { namespace wdt { - -ThreadTransferHistory::ThreadTransferHistory(DirectorySourceQueue &queue, - TransferStats &threadStats) - : queue_(queue), threadStats_(threadStats) { -} - -std::string ThreadTransferHistory::getSourceId(int64_t index) { - folly::SpinLockGuard guard(lock_); - std::string sourceId; - const int64_t historySize = history_.size(); - if (index >= 0 && index < historySize) { - sourceId = history_[index]->getIdentifier(); - } else { - LOG(WARNING) << "Trying to read out of bounds data " << index << " " - << history_.size(); - } - return sourceId; -} - -bool ThreadTransferHistory::addSource(std::unique_ptr &source) { - folly::SpinLockGuard guard(lock_); - if (globalCheckpoint_) { - // already received an error for this thread - VLOG(1) << "adding source after global checkpoint is received. 
Returning " - "the source to the queue"; - markSourceAsFailed(source, lastCheckpoint_.get()); - lastCheckpoint_.reset(); - queue_.returnToQueue(source); - return false; - } - history_.emplace_back(std::move(source)); - return true; -} - -ErrorCode ThreadTransferHistory::setCheckpointAndReturnToQueue( - const Checkpoint &checkpoint, bool globalCheckpoint) { - folly::SpinLockGuard guard(lock_); - const int64_t historySize = history_.size(); - int64_t numReceivedSources = checkpoint.numBlocks; - int64_t lastBlockReceivedBytes = checkpoint.lastBlockReceivedBytes; - if (numReceivedSources > historySize) { - LOG(ERROR) - << "checkpoint is greater than total number of sources transfered " - << history_.size() << " " << numReceivedSources; - return INVALID_CHECKPOINT; - } - ErrorCode errCode = validateCheckpoint(checkpoint, globalCheckpoint); - if (errCode == INVALID_CHECKPOINT) { - return INVALID_CHECKPOINT; - } - globalCheckpoint_ |= globalCheckpoint; - lastCheckpoint_ = folly::make_unique(checkpoint); - int64_t numFailedSources = historySize - numReceivedSources; - if (numFailedSources == 0 && lastBlockReceivedBytes > 0) { - if (!globalCheckpoint) { - // no block to apply checkpoint offset. This can happen if we receive same - // local checkpoint without adding anything to the history - LOG(WARNING) << "Local checkpoint has received bytes for last block, but " - "there are no unacked blocks in the history. Ignoring."; - } - } - numAcknowledged_ = numReceivedSources; - std::vector> sourcesToReturn; - for (int64_t i = 0; i < numFailedSources; i++) { - std::unique_ptr source = std::move(history_.back()); - history_.pop_back(); - const Checkpoint *checkpointPtr = - (i == numFailedSources - 1 ? 
&checkpoint : nullptr); - markSourceAsFailed(source, checkpointPtr); - sourcesToReturn.emplace_back(std::move(source)); - } - queue_.returnToQueue(sourcesToReturn); - LOG(INFO) << numFailedSources - << " number of sources returned to queue, checkpoint: " - << checkpoint; - return errCode; -} - -std::vector ThreadTransferHistory::popAckedSourceStats() { - const int64_t historySize = history_.size(); - WDT_CHECK(numAcknowledged_ == historySize); - // no locking needed, as this should be called after transfer has finished - std::vector sourceStats; - while (!history_.empty()) { - sourceStats.emplace_back(std::move(history_.back()->getTransferStats())); - history_.pop_back(); - } - return sourceStats; -} - -void ThreadTransferHistory::markAllAcknowledged() { - folly::SpinLockGuard guard(lock_); - numAcknowledged_ = history_.size(); -} - -void ThreadTransferHistory::returnUnackedSourcesToQueue() { - Checkpoint checkpoint; - checkpoint.numBlocks = numAcknowledged_; - setCheckpointAndReturnToQueue(checkpoint, false); -} - -ErrorCode ThreadTransferHistory::validateCheckpoint( - const Checkpoint &checkpoint, bool globalCheckpoint) { - if (lastCheckpoint_ == nullptr) { - return OK; - } - if (checkpoint.numBlocks < lastCheckpoint_->numBlocks) { - LOG(ERROR) << "Current checkpoint must be higher than previous checkpoint, " - "Last checkpoint: " << *lastCheckpoint_ - << ", Current checkpoint: " << checkpoint; - return INVALID_CHECKPOINT; - } - if (checkpoint.numBlocks > lastCheckpoint_->numBlocks) { - return OK; - } - bool noProgress = false; - // numBlocks same - if (checkpoint.lastBlockSeqId == lastCheckpoint_->lastBlockSeqId && - checkpoint.lastBlockOffset == lastCheckpoint_->lastBlockOffset) { - // same block - if (checkpoint.lastBlockReceivedBytes != - lastCheckpoint_->lastBlockReceivedBytes) { - LOG(ERROR) << "Current checkpoint has different received bytes, but all " - "other fields are same, Last checkpoint " - << *lastCheckpoint_ << ", Current checkpoint: " << 
checkpoint; - return INVALID_CHECKPOINT; - } - noProgress = true; - } else { - // different block - WDT_CHECK(checkpoint.lastBlockReceivedBytes >= 0); - if (checkpoint.lastBlockReceivedBytes == 0) { - noProgress = true; - } - } - if (noProgress && !globalCheckpoint) { - // we can get same global checkpoint multiple times, so no need to check for - // progress - LOG(WARNING) << "No progress since last checkpoint, Last checkpoint: " - << *lastCheckpoint_ << ", Current checkpoint: " << checkpoint; - return NO_PROGRESS; +void Sender::endCurTransfer() { + endTime_ = Clock::now(); + LOG(INFO) << "Last thread finished " << durationSeconds(endTime_ - startTime_) + << " for transfer id " << transferId_; + transferFinished_ = true; + if (throttler_) { + throttler_->deRegisterTransfer(); } - return OK; } -void ThreadTransferHistory::markSourceAsFailed( - std::unique_ptr &source, const Checkpoint *checkpoint) { - auto metadata = source->getMetaData(); - bool validCheckpoint = false; - if (checkpoint != nullptr) { - if (checkpoint->hasSeqId) { - if ((checkpoint->lastBlockSeqId == metadata.seqId) && - (checkpoint->lastBlockOffset == source->getOffset())) { - validCheckpoint = true; - } else { - LOG(WARNING) - << "Checkpoint block does not match history block. Checkpoint: " - << checkpoint->lastBlockSeqId << ", " << checkpoint->lastBlockOffset - << " History: " << metadata.seqId << ", " << source->getOffset(); - } - } else { - // Receiver at lower version! - // checkpoint does not have seq-id. We have to blindly trust - // lastBlockReceivedBytes. If we do not, transfer will fail because of - // number of bytes mismatch. Even if an error happens because of this, - // Receiver will fail. - validCheckpoint = true; - } - } - int64_t receivedBytes = - (validCheckpoint ? 
checkpoint->lastBlockReceivedBytes : 0); - TransferStats &sourceStats = source->getTransferStats(); - if (sourceStats.getErrorCode() != OK) { - // already marked as failed - sourceStats.addEffectiveBytes(0, receivedBytes); - threadStats_.addEffectiveBytes(0, receivedBytes); - } else { - auto dataBytes = source->getSize(); - auto headerBytes = sourceStats.getEffectiveHeaderBytes(); - int64_t wastedBytes = dataBytes - receivedBytes; - sourceStats.subtractEffectiveBytes(headerBytes, wastedBytes); - sourceStats.decrNumBlocks(); - sourceStats.setErrorCode(SOCKET_WRITE_ERROR); - sourceStats.incrFailedAttempts(); - - threadStats_.subtractEffectiveBytes(headerBytes, wastedBytes); - threadStats_.decrNumBlocks(); - threadStats_.incrFailedAttempts(); +void Sender::startNewTransfer() { + if (throttler_) { + throttler_->registerTransfer(); } - source->advanceOffset(receivedBytes); + LOG(INFO) << "Starting a new transfer " << transferId_ << " to " << destHost_; } -const Sender::StateFunction Sender::stateMap_[] = { - &Sender::connect, &Sender::readLocalCheckPoint, &Sender::sendSettings, - &Sender::sendBlocks, &Sender::sendDoneCmd, &Sender::sendSizeCmd, - &Sender::checkForAbort, &Sender::readFileChunks, &Sender::readReceiverCmd, - &Sender::processDoneCmd, &Sender::processWaitCmd, &Sender::processErrCmd, - &Sender::processAbortCmd, &Sender::processVersionMismatch}; - Sender::Sender(const std::string &destHost, const std::string &srcDir) - : queueAbortChecker_(this) { + : queueAbortChecker_(this), destHost_(destHost) { LOG(INFO) << "WDT Sender " << Protocol::getFullVersion(); - destHost_ = destHost; srcDir_ = srcDir; transferFinished_ = true; const auto &options = WdtOptions::get(); @@ -262,6 +80,8 @@ Sender::Sender(const std::string &destHost, const std::string &srcDir, : Sender(destHost, srcDir) { ports_ = ports; dirQueue_->setFileInfo(srcFileInfo); + transferHistoryController_ = + folly::make_unique(*dirQueue_); } WdtTransferRequest Sender::init() { @@ -317,6 +137,43 @@ const 
std::string &Sender::getSrcDir() const { return srcDir_; } +ProtoNegotiationStatus Sender::getNegotiationStatus() { + return protoNegotiationStatus_; +} + +std::vector Sender::getNegotiatedProtocols() const { + std::vector ret; + for (const auto &senderThread : senderThreads_) { + ret.push_back(senderThread->getNegotiatedProtocol()); + } + return ret; +} + +void Sender::setProtoNegotiationStatus(ProtoNegotiationStatus status) { + protoNegotiationStatus_ = status; +} + +bool Sender::isSendFileChunks() const { + return (downloadResumptionEnabled_ && + protocolVersion_ >= Protocol::DOWNLOAD_RESUMPTION_VERSION); +} + +bool Sender::isFileChunksReceived() const { + std::lock_guard lock(mutex_); + return fileChunksReceived_; +} + +void Sender::setFileChunksInfo( + std::vector &fileChunksInfoList) { + std::lock_guard lock(mutex_); + if (fileChunksReceived_) { + LOG(WARNING) << "File chunks list received multiple times"; + return; + } + dirQueue_->setPreviouslyReceivedChunks(fileChunksInfoList); + fileChunksReceived_ = true; +} + const std::string &Sender::getDestination() const { return destHost_; } @@ -324,8 +181,9 @@ const std::string &Sender::getDestination() const { std::unique_ptr Sender::getTransferReport() { int64_t totalFileSize = dirQueue_->getTotalSize(); double totalTime = durationSeconds(Clock::now() - startTime_); + auto globalStats = getGlobalTransferStats(); std::unique_ptr transferReport = - folly::make_unique(globalThreadStats_, totalTime, + folly::make_unique(std::move(globalStats), totalTime, totalFileSize); return transferReport; } @@ -339,6 +197,26 @@ Clock::time_point Sender::getEndTime() { return endTime_; } +TransferStats Sender::getGlobalTransferStats() const { + TransferStats globalStats; + for (const auto &thread : senderThreads_) { + globalStats += thread->getTransferStats(); + } + return globalStats; +} + +ErrorCode Sender::verifyVersionMismatchStats() const { + for (const auto &senderThread : senderThreads_) { + const auto &threadStats = 
senderThread->getTransferStats(); + if (threadStats.getErrorCode() == OK) { + LOG(ERROR) << "Found a thread that completed transfer successfully " + << "despite version mismatch. " << senderThread->getPort(); + return ERROR; + } + } + return OK; +} + std::unique_ptr Sender::finish() { std::unique_lock instanceLock(instanceManagementMutex_); VLOG(1) << "Sender::finish()"; @@ -351,9 +229,8 @@ std::unique_ptr Sender::finish() { const bool twoPhases = options.two_phases; bool progressReportEnabled = progressReporter_ && progressReportIntervalMillis_ > 0; - const int64_t numPorts = ports_.size(); - for (int64_t i = 0; i < numPorts; i++) { - senderThreads_[i].join(); + for (auto &senderThread : senderThreads_) { + senderThread->finish(); } if (!twoPhases) { dirThread_.join(); @@ -366,9 +243,13 @@ std::unique_ptr Sender::finish() { } progressReporterThread_.join(); } - + std::vector threadStats; + for (auto &senderThread : senderThreads_) { + threadStats.push_back(std::move(senderThread->moveStats())); + } bool allSourcesAcked = false; - for (auto &stats : globalThreadStats_) { + for (auto &senderThread : senderThreads_) { + auto &stats = senderThread->getTransferStats(); if (stats.getErrorCode() == OK) { // at least one thread finished correctly // that means all transferred sources are acked @@ -378,7 +259,9 @@ std::unique_ptr Sender::finish() { } std::vector transferredSourceStats; - for (auto &transferHistory : transferHistories_) { + for (auto port : ports_) { + auto &transferHistory = + transferHistoryController_->getTransferHistory(port); if (allSourcesAcked) { transferHistory.markAllAcknowledged(); } else { @@ -391,18 +274,16 @@ std::unique_ptr Sender::finish() { std::make_move_iterator(stats.end())); } } - if (WdtOptions::get().full_reporting) { validateTransferStats(transferredSourceStats, - dirQueue_->getFailedSourceStats(), - globalThreadStats_); + dirQueue_->getFailedSourceStats()); } int64_t totalFileSize = dirQueue_->getTotalSize(); double totalTime = 
durationSeconds(endTime_ - startTime_); std::unique_ptr transferReport = folly::make_unique( transferredSourceStats, dirQueue_->getFailedSourceStats(), - globalThreadStats_, dirQueue_->getFailedDirectories(), totalTime, + threadStats, dirQueue_->getFailedDirectories(), totalTime, totalFileSize, dirQueue_->getCount()); if (progressReportEnabled) { @@ -410,8 +291,8 @@ std::unique_ptr Sender::finish() { } if (options.enable_perf_stat_collection) { PerfStatReport report; - for (auto &perfReport : perfReports_) { - report += perfReport; + for (auto &senderThread : senderThreads_) { + report += senderThread->getPerfReport(); } LOG(INFO) << report; } @@ -461,23 +342,14 @@ ErrorCode Sender::start() { } else { configureThrottler(); } - - // WARNING: Do not MERGE the following two loops. ThreadTransferHistory keeps - // a reference of TransferStats. And, any emplace operation on a vector - // invalidates all its references - const int64_t numPorts = ports_.size(); - for (int64_t i = 0; i < numPorts; i++) { - globalThreadStats_.emplace_back(true); - } - for (int64_t i = 0; i < numPorts; i++) { - transferHistories_.emplace_back(*dirQueue_, globalThreadStats_[i]); - } - perfReports_.resize(numPorts); - negotiatedProtocolVersions_.resize(numPorts, 0); - numActiveThreads_ = numPorts; - for (int64_t i = 0; i < numPorts; i++) { - globalThreadStats_[i].setId(folly::to(i)); - senderThreads_.emplace_back(&Sender::sendOne, this, i); + threadsController_ = new ThreadsController(ports_.size()); + threadsController_->setNumBarriers(SenderThread::NUM_BARRIERS); + threadsController_->setNumFunnels(SenderThread::NUM_FUNNELS); + threadsController_->setNumConditions(SenderThread::NUM_CONDITIONS); + senderThreads_ = threadsController_->makeThreads( + this, ports_.size(), ports_); + for (auto &senderThread : senderThreads_) { + senderThread->startThread(); } if (progressReportEnabled) { progressReporter_->start(); @@ -489,8 +361,7 @@ ErrorCode Sender::start() { void 
Sender::validateTransferStats( const std::vector &transferredSourceStats, - const std::vector &failedSourceStats, - const std::vector &threadStats) { + const std::vector &failedSourceStats) { int64_t sourceFailedAttempts = 0; int64_t sourceDataBytes = 0; int64_t sourceEffectiveDataBytes = 0; @@ -513,7 +384,8 @@ void Sender::validateTransferStats( sourceEffectiveDataBytes += stat.getEffectiveDataBytes(); sourceNumBlocks += stat.getNumBlocks(); } - for (const auto &stat : threadStats) { + for (const auto &senderThread : senderThreads_) { + const auto &stat = senderThread->getTransferStats(); threadFailedAttempts += stat.getFailedAttempts(); threadDataBytes += stat.getDataBytes(); threadEffectiveDataBytes += stat.getEffectiveDataBytes(); @@ -530,19 +402,18 @@ void Sender::setSocketCreator(const SocketCreator socketCreator) { socketCreator_ = socketCreator; } -std::unique_ptr Sender::connectToReceiver(const int port, - ErrorCode &errCode) { +std::unique_ptr Sender::connectToReceiver( + const int port, IAbortChecker const *abortChecker, ErrorCode &errCode) { auto startTime = Clock::now(); const auto &options = WdtOptions::get(); int connectAttempts = 0; std::unique_ptr socket; + std::string portStr = folly::to(port); if (!socketCreator_) { // socket creator not set, creating ClientSocket - socket = folly::make_unique( - destHost_, folly::to(port), &abortCheckerCallback_); + socket = folly::make_unique(destHost_, portStr, abortChecker); } else { - socket = socketCreator_(destHost_, folly::to(port), - &abortCheckerCallback_); + socket = socketCreator_(destHost_, portStr, abortChecker); } double retryInterval = options.sleep_millis; int maxRetries = options.max_retries; @@ -582,678 +453,6 @@ std::unique_ptr Sender::connectToReceiver(const int port, return socket; } -Sender::SenderState Sender::connect(ThreadData &data) { - VLOG(1) << "entered CONNECT state " << data.threadIndex_; - int port = ports_[data.threadIndex_]; - TransferStats &threadStats = data.threadStats_; - 
auto &socket = data.socket_; - auto &numReconnectWithoutProgress = data.numReconnectWithoutProgress_; - auto &options = WdtOptions::get(); - - if (socket) { - socket->close(); - } - if (numReconnectWithoutProgress >= options.max_transfer_retries) { - LOG(ERROR) << "Sender thread reconnected " << numReconnectWithoutProgress - << " times without making any progress, giving up. port: " - << socket->getPort(); - threadStats.setErrorCode(NO_PROGRESS); - return END; - } - - ErrorCode code; - socket = connectToReceiver(port, code); - if (code == ABORT) { - threadStats.setErrorCode(ABORT); - if (getCurAbortCode() == VERSION_MISMATCH) { - return PROCESS_VERSION_MISMATCH; - } - return END; - } - if (code != OK) { - threadStats.setErrorCode(code); - return END; - } - // clearing the totalSizeSent_ flag. This way if anything breaks, we resend - // the total size. - data.totalSizeSent_ = false; - auto nextState = - threadStats.getErrorCode() == OK ? SEND_SETTINGS : READ_LOCAL_CHECKPOINT; - // clear the error code, as this is a new transfer - threadStats.setErrorCode(OK); - return nextState; -} - -Sender::SenderState Sender::readLocalCheckPoint(ThreadData &data) { - LOG(INFO) << "entered READ_LOCAL_CHECKPOINT state " << data.threadIndex_; - int port = ports_[data.threadIndex_]; - TransferStats &threadStats = data.threadStats_; - ThreadTransferHistory &transferHistory = data.getTransferHistory(); - auto &numReconnectWithoutProgress = data.numReconnectWithoutProgress_; - - std::vector checkpoints; - int64_t decodeOffset = 0; - char *buf = data.buf_; - int checkpointLen = Protocol::getMaxLocalCheckpointLength(protocolVersion_); - int64_t numRead = data.socket_->read(buf, checkpointLen); - if (numRead != checkpointLen) { - LOG(ERROR) << "read mismatch during reading local checkpoint " - << checkpointLen << " " << numRead << " port " << port; - threadStats.setErrorCode(SOCKET_READ_ERROR); - numReconnectWithoutProgress++; - return CONNECT; - } - if 
(!Protocol::decodeCheckpoints(protocolVersion_, buf, decodeOffset, - checkpointLen, checkpoints)) { - LOG(ERROR) << "checkpoint decode failure " - << folly::humanify(std::string(buf, checkpointLen)); - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; - } - if (checkpoints.size() != 1 || checkpoints[0].port != port) { - LOG(ERROR) << "illegal local checkpoint " - << folly::humanify(std::string(buf, checkpointLen)); - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; - } - const Checkpoint &checkpoint = checkpoints[0]; - auto numBlocks = checkpoint.numBlocks; - VLOG(1) << "received local checkpoint, port " << port << " num-blocks " - << numBlocks << " seq-id " << checkpoint.lastBlockSeqId << " offset " - << checkpoint.lastBlockOffset << " received-bytes " - << checkpoint.lastBlockReceivedBytes; - - if (numBlocks == -1) { - // Receiver failed while sending DONE cmd - return READ_RECEIVER_CMD; - } - - ErrorCode errCode = - transferHistory.setCheckpointAndReturnToQueue(checkpoint, false); - if (errCode == INVALID_CHECKPOINT) { - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; - } - if (errCode == NO_PROGRESS) { - numReconnectWithoutProgress++; - } else { - numReconnectWithoutProgress = 0; - } - return SEND_SETTINGS; -} - -Sender::SenderState Sender::sendSettings(ThreadData &data) { - VLOG(1) << "entered SEND_SETTINGS state " << data.threadIndex_; - - TransferStats &threadStats = data.threadStats_; - char *buf = data.buf_; - auto &socket = data.socket_; - auto &options = WdtOptions::get(); - int64_t readTimeoutMillis = options.read_timeout_millis; - int64_t writeTimeoutMillis = options.write_timeout_millis; - int64_t off = 0; - buf[off++] = Protocol::SETTINGS_CMD; - bool sendFileChunks; - { - std::lock_guard lock(mutex_); - sendFileChunks = - (downloadResumptionEnabled_ && - protocolVersion_ >= Protocol::DOWNLOAD_RESUMPTION_VERSION); - } - Settings settings; - settings.readTimeoutMillis = readTimeoutMillis; - settings.writeTimeoutMillis = 
writeTimeoutMillis; - settings.transferId = transferId_; - settings.enableChecksum = options.enable_checksum; - settings.sendFileChunks = sendFileChunks; - settings.blockModeDisabled = (options.block_size_mbytes <= 0); - Protocol::encodeSettings(protocolVersion_, buf, off, Protocol::kMaxSettings, - settings); - int64_t toWrite = sendFileChunks ? Protocol::kMinBufLength : off; - int64_t written = socket->write(buf, toWrite); - if (written != toWrite) { - LOG(ERROR) << "Socket write failure " << written << " " << toWrite; - threadStats.setErrorCode(SOCKET_WRITE_ERROR); - return CONNECT; - } - threadStats.addHeaderBytes(toWrite); - return sendFileChunks ? READ_FILE_CHUNKS : SEND_BLOCKS; -} - -Sender::SenderState Sender::sendBlocks(ThreadData &data) { - VLOG(1) << "entered SEND_BLOCKS state " << data.threadIndex_; - TransferStats &threadStats = data.threadStats_; - ThreadTransferHistory &transferHistory = data.getTransferHistory(); - auto &totalSizeSent = data.totalSizeSent_; - - if (protocolVersion_ >= Protocol::RECEIVER_PROGRESS_REPORT_VERSION && - !totalSizeSent && dirQueue_->fileDiscoveryFinished()) { - return SEND_SIZE_CMD; - } - - ErrorCode transferStatus; - std::unique_ptr source = dirQueue_->getNextSource(transferStatus); - if (!source) { - return SEND_DONE_CMD; - } - WDT_CHECK(!source->hasError()); - TransferStats transferStats = - sendOneByteSource(data.socket_, source, transferStatus); - threadStats += transferStats; - source->addTransferStats(transferStats); - source->close(); - if (!transferHistory.addSource(source)) { - // global checkpoint received for this thread. 
no point in - // continuing - LOG(ERROR) << "global checkpoint received, no point in continuing"; - threadStats.setErrorCode(CONN_ERROR); - return END; - } - - if (transferStats.getErrorCode() != OK) { - return CHECK_FOR_ABORT; - } - return SEND_BLOCKS; -} - -Sender::SenderState Sender::sendSizeCmd(ThreadData &data) { - VLOG(1) << "entered SEND_SIZE_CMD state " << data.threadIndex_; - TransferStats &threadStats = data.threadStats_; - char *buf = data.buf_; - auto &socket = data.socket_; - auto &totalSizeSent = data.totalSizeSent_; - int64_t off = 0; - buf[off++] = Protocol::SIZE_CMD; - - Protocol::encodeSize(buf, off, Protocol::kMaxSize, dirQueue_->getTotalSize()); - int64_t written = socket->write(buf, off); - if (written != off) { - LOG(ERROR) << "Socket write error " << off << " " << written; - threadStats.setErrorCode(SOCKET_WRITE_ERROR); - return CHECK_FOR_ABORT; - } - threadStats.addHeaderBytes(off); - totalSizeSent = true; - return SEND_BLOCKS; -} - -Sender::SenderState Sender::sendDoneCmd(ThreadData &data) { - VLOG(1) << "entered SEND_DONE_CMD state " << data.threadIndex_; - TransferStats &threadStats = data.threadStats_; - char *buf = data.buf_; - auto &socket = data.socket_; - int64_t off = 0; - buf[off++] = Protocol::DONE_CMD; - - auto pair = dirQueue_->getNumBlocksAndStatus(); - int64_t numBlocksDiscovered = pair.first; - ErrorCode transferStatus = pair.second; - buf[off++] = transferStatus; - - Protocol::encodeDone(protocolVersion_, buf, off, Protocol::kMaxDone, - numBlocksDiscovered, dirQueue_->getTotalSize()); - - int toWrite = Protocol::kMinBufLength; - int64_t written = socket->write(buf, toWrite); - if (written != toWrite) { - LOG(ERROR) << "Socket write failure " << written << " " << toWrite; - threadStats.setErrorCode(SOCKET_WRITE_ERROR); - return CHECK_FOR_ABORT; - } - threadStats.addHeaderBytes(toWrite); - VLOG(1) << "Wrote done cmd on " << socket->getFd() << " waiting for reply..."; - return READ_RECEIVER_CMD; -} - -Sender::SenderState 
Sender::checkForAbort(ThreadData &data) { - LOG(INFO) << "entered CHECK_FOR_ABORT state " << data.threadIndex_; - char *buf = data.buf_; - auto &threadStats = data.threadStats_; - auto &socket = data.socket_; - - auto numRead = socket->read(buf, 1); - if (numRead != 1) { - VLOG(1) << "No abort cmd found"; - return CONNECT; - } - Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf[0]; - if (cmd != Protocol::ABORT_CMD) { - VLOG(1) << "Unexpected result found while reading for abort " << buf[0]; - return CONNECT; - } - threadStats.addHeaderBytes(1); - return PROCESS_ABORT_CMD; -} - -Sender::SenderState Sender::readFileChunks(ThreadData &data) { - LOG(INFO) << "entered READ_FILE_CHUNKS state " << data.threadIndex_; - char *buf = data.buf_; - auto &socket = data.socket_; - auto &threadStats = data.threadStats_; - int64_t numRead = socket->read(buf, 1); - if (numRead != 1) { - LOG(ERROR) << "Socket read error 1 " << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return CHECK_FOR_ABORT; - } - threadStats.addHeaderBytes(numRead); - Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf[0]; - if (cmd == Protocol::ABORT_CMD) { - return PROCESS_ABORT_CMD; - } - if (cmd == Protocol::WAIT_CMD) { - return READ_FILE_CHUNKS; - } - if (cmd == Protocol::ACK_CMD) { - { - std::lock_guard lock(mutex_); - if (!fileChunksReceived_) { - LOG(ERROR) << "Sender has not yet received file chunks, but receiver " - "thinks it has already sent it"; - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; - } - } - return SEND_BLOCKS; - } - if (cmd != Protocol::CHUNKS_CMD) { - LOG(ERROR) << "Unexpected cmd " << cmd; - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; - } - int64_t toRead = Protocol::kChunksCmdLen; - numRead = socket->read(buf, toRead); - if (numRead != toRead) { - LOG(ERROR) << "Socket read error " << toRead << " " << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return CHECK_FOR_ABORT; - } - threadStats.addHeaderBytes(numRead); - int64_t off = 0; - 
int64_t bufSize, numFiles; - Protocol::decodeChunksCmd(buf, off, bufSize, numFiles); - LOG(INFO) << "File chunk list has " << numFiles - << " entries and is broken in buffers of length " << bufSize; - std::unique_ptr chunkBuffer(new char[bufSize]); - std::vector fileChunksInfoList; - while (true) { - int64_t numFileChunks = fileChunksInfoList.size(); - if (numFileChunks > numFiles) { - // We should never be able to read more file chunks than mentioned in the - // chunks cmd. Chunks cmd has buffer size used to transfer chunks and also - // number of chunks. This chunks are read and parsed and added to - // fileChunksInfoList. Number of chunks we decode should match with the - // number mentioned in the Chunks cmd. - LOG(ERROR) << "Number of file chunks received is more than the number " - "mentioned in CHUNKS_CMD " << numFileChunks << " " - << numFiles; - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; - } - if (numFileChunks == numFiles) { - break; - } - toRead = sizeof(int32_t); - numRead = socket->read(buf, toRead); - if (numRead != toRead) { - LOG(ERROR) << "Socket read error " << toRead << " " << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return CHECK_FOR_ABORT; - } - toRead = folly::loadUnaligned(buf); - toRead = folly::Endian::little(toRead); - numRead = socket->read(chunkBuffer.get(), toRead); - if (numRead != toRead) { - LOG(ERROR) << "Socket read error " << toRead << " " << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return CHECK_FOR_ABORT; - } - threadStats.addHeaderBytes(numRead); - off = 0; - // decode function below adds decoded file chunks to fileChunksInfoList - bool success = Protocol::decodeFileChunksInfoList( - chunkBuffer.get(), off, toRead, fileChunksInfoList); - if (!success) { - LOG(ERROR) << "Unable to decode file chunks list"; - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; - } - } - { - std::lock_guard lock(mutex_); - if (fileChunksReceived_) { - LOG(WARNING) << "File chunks list received 
multiple times"; - } else { - dirQueue_->setPreviouslyReceivedChunks(fileChunksInfoList); - fileChunksReceived_ = true; - } - } - // send ack for file chunks list - buf[0] = Protocol::ACK_CMD; - int64_t toWrite = 1; - int64_t written = socket->write(buf, toWrite); - if (toWrite != written) { - LOG(ERROR) << "Socket write error " << toWrite << " " << written; - threadStats.setErrorCode(SOCKET_WRITE_ERROR); - return CHECK_FOR_ABORT; - } - threadStats.addHeaderBytes(written); - return SEND_BLOCKS; -} - -Sender::SenderState Sender::readReceiverCmd(ThreadData &data) { - VLOG(1) << "entered READ_RECEIVER_CMD state " << data.threadIndex_; - int port = ports_[data.threadIndex_]; - TransferStats &threadStats = data.threadStats_; - char *buf = data.buf_; - int64_t numRead = data.socket_->read(buf, 1); - if (numRead != 1) { - LOG(ERROR) << "READ unexpected " << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return CONNECT; - } - Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf[0]; - if (cmd == Protocol::ERR_CMD) { - return PROCESS_ERR_CMD; - } - if (cmd == Protocol::WAIT_CMD) { - return PROCESS_WAIT_CMD; - } - if (cmd == Protocol::DONE_CMD) { - return PROCESS_DONE_CMD; - } - if (cmd == Protocol::ABORT_CMD) { - return PROCESS_ABORT_CMD; - } - if (cmd == Protocol::LOCAL_CHECKPOINT_CMD) { - int checkpointLen = Protocol::getMaxLocalCheckpointLength(protocolVersion_); - int64_t toRead = checkpointLen - 1; - numRead = data.socket_->read(buf + 1, toRead); - if (numRead != toRead) { - LOG(ERROR) << "Could not read possible local checkpoint " << toRead << " " - << numRead << " " << port; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return CONNECT; - } - int64_t offset = 0; - std::vector checkpoints; - if (Protocol::decodeCheckpoints(protocolVersion_, buf, offset, - checkpointLen, checkpoints)) { - if (checkpoints.size() == 1 && checkpoints[0].port == port && - checkpoints[0].numBlocks == 0 && - checkpoints[0].lastBlockReceivedBytes == 0) { - // In a spurious local 
checkpoint, number of blocks and offset must both - // be zero - // Ignore the checkpoint - LOG(WARNING) - << "Received valid but unexpected local checkpoint, ignoring " - << port; - return READ_RECEIVER_CMD; - } - } - LOG(ERROR) << "Failed to verify spurious local checkpoint, port " << port; - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; - } - LOG(ERROR) << "Read unexpected receiver cmd " << cmd << " port " << port; - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; -} - -Sender::SenderState Sender::processDoneCmd(ThreadData &data) { - VLOG(1) << "entered PROCESS_DONE_CMD state " << data.threadIndex_; - int port = ports_[data.threadIndex_]; - char *buf = data.buf_; - auto &socket = data.socket_; - ThreadTransferHistory &transferHistory = data.getTransferHistory(); - transferHistory.markAllAcknowledged(); - - // send ack for DONE - buf[0] = Protocol::DONE_CMD; - socket->write(buf, 1); - - socket->shutdown(); - auto numRead = socket->read(buf, Protocol::kMinBufLength); - if (numRead != 0) { - LOG(WARNING) << "EOF not found when expected"; - return END; - } - VLOG(1) << "done with transfer, port " << port; - return END; -} - -Sender::SenderState Sender::processWaitCmd(ThreadData &data) { - LOG(INFO) << "entered PROCESS_WAIT_CMD state " << data.threadIndex_; - int port = ports_[data.threadIndex_]; - ThreadTransferHistory &transferHistory = data.getTransferHistory(); - VLOG(1) << "received WAIT_CMD, port " << port; - transferHistory.markAllAcknowledged(); - return READ_RECEIVER_CMD; -} - -Sender::SenderState Sender::processErrCmd(ThreadData &data) { - int port = ports_[data.threadIndex_]; - LOG(INFO) << "entered PROCESS_ERR_CMD state " << data.threadIndex_ << " port " - << port; - ThreadTransferHistory &transferHistory = data.getTransferHistory(); - TransferStats &threadStats = data.threadStats_; - auto &transferHistories = data.transferHistories_; - auto &socket = data.socket_; - char *buf = data.buf_; - - int64_t toRead = sizeof(int16_t); - int64_t 
numRead = socket->read(buf, toRead); - if (numRead != toRead) { - LOG(ERROR) << "read unexpected " << toRead << " " << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return CONNECT; - } - - int16_t checkpointsLen = folly::loadUnaligned(buf); - checkpointsLen = folly::Endian::little(checkpointsLen); - char checkpointBuf[checkpointsLen]; - numRead = socket->read(checkpointBuf, checkpointsLen); - if (numRead != checkpointsLen) { - LOG(ERROR) << "read unexpected " << checkpointsLen << " " << numRead; - threadStats.setErrorCode(SOCKET_READ_ERROR); - return CONNECT; - } - - std::vector checkpoints; - int64_t decodeOffset = 0; - if (!Protocol::decodeCheckpoints(protocolVersion_, checkpointBuf, - decodeOffset, checkpointsLen, checkpoints)) { - LOG(ERROR) << "checkpoint decode failure " - << folly::humanify(std::string(checkpointBuf, checkpointsLen)); - threadStats.setErrorCode(PROTOCOL_ERROR); - return END; - } - transferHistory.markAllAcknowledged(); - for (auto &checkpoint : checkpoints) { - auto errPort = checkpoint.port; - auto errPoint = checkpoint.numBlocks; - auto it = std::find(ports_.begin(), ports_.end(), errPort); - if (it == ports_.end()) { - LOG(ERROR) << "Invalid checkpoint " << errPoint - << ". 
No sender thread running on port " << errPort; - continue; - } - auto errThread = it - ports_.begin(); - VLOG(1) << "received global checkpoint " << errThread << " -> " << errPoint - << ", " << checkpoint.lastBlockSeqId << ", " - << checkpoint.lastBlockOffset << ", " - << checkpoint.lastBlockReceivedBytes; - transferHistories[errThread].setCheckpointAndReturnToQueue(checkpoint, - true); - } - return SEND_BLOCKS; -} - -Sender::SenderState Sender::processAbortCmd(ThreadData &data) { - LOG(INFO) << "entered PROCESS_ABORT_CMD state " << data.threadIndex_; - char *buf = data.buf_; - auto &threadStats = data.threadStats_; - auto &socket = data.socket_; - ThreadTransferHistory &transferHistory = data.getTransferHistory(); - - threadStats.setErrorCode(ABORT); - int toRead = Protocol::kAbortLength; - auto numRead = socket->read(buf, toRead); - if (numRead != toRead) { - // can not read checkpoint, but still must exit because of ABORT - LOG(ERROR) << "Error while trying to read ABORT cmd " << numRead << " " - << toRead; - return END; - } - int64_t offset = 0; - int32_t negotiatedProtocol; - ErrorCode remoteError; - int64_t checkpoint; - Protocol::decodeAbort(buf, offset, negotiatedProtocol, remoteError, - checkpoint); - threadStats.setRemoteErrorCode(remoteError); - std::string failedFileName = transferHistory.getSourceId(checkpoint); - LOG(WARNING) << "Received abort on " << data.threadIndex_ - << " remote protocol version " << negotiatedProtocol - << " remote error code " << errorCodeToStr(remoteError) - << " file " << failedFileName << " checkpoint " << checkpoint; - abort(remoteError); - if (remoteError == VERSION_MISMATCH) { - if (Protocol::negotiateProtocol(negotiatedProtocol, protocolVersion_) == - negotiatedProtocol) { - // sender can support this negotiated version - negotiatedProtocolVersions_[data.threadIndex_] = negotiatedProtocol; - return PROCESS_VERSION_MISMATCH; - } else { - LOG(ERROR) << "Sender can not support receiver version " - << negotiatedProtocol; - 
threadStats.setRemoteErrorCode(VERSION_INCOMPATIBLE); - } - } - return END; -} - -Sender::SenderState Sender::processVersionMismatch(ThreadData &data) { - LOG(INFO) << "entered PROCESS_VERSION_MISMATCH state " << data.threadIndex_; - auto &threadStats = data.threadStats_; - - WDT_CHECK(threadStats.getErrorCode() == ABORT); - std::unique_lock lock(mutex_); - if (protoNegotiationStatus_ != V_MISMATCH_WAIT) { - LOG(WARNING) << "Protocol version already negotiated, but transfer still " - "aborted due to version mismatch, port " - << ports_[data.threadIndex_]; - return END; - } - numWaitingWithAbort_++; - while (protoNegotiationStatus_ == V_MISMATCH_WAIT && - numWaitingWithAbort_ != numActiveThreads_) { - WDT_CHECK(numWaitingWithAbort_ < numActiveThreads_); - conditionAllAborted_.wait(lock); - } - numWaitingWithAbort_--; - if (protoNegotiationStatus_ == V_MISMATCH_FAILED) { - return END; - } - if (protoNegotiationStatus_ == V_MISMATCH_RESOLVED) { - threadStats.setRemoteErrorCode(OK); - return CONNECT; - } - protoNegotiationStatus_ = V_MISMATCH_FAILED; - - for (const auto &stat : globalThreadStats_) { - if (stat.getErrorCode() == OK) { - // Some other thread finished successfully, should never happen in case of - // version mismatch - return END; - } - } - - const int numHistories = transferHistories_.size(); - for (int i = 0; i < numHistories; i++) { - auto &history = transferHistories_[i]; - if (history.getNumAcked() > 0) { - LOG(ERROR) << "Even though the transfer aborted due to VERSION_MISMATCH, " - "some blocks got acked by the receiver, port " << ports_[i] - << " numAcked " << history.getNumAcked(); - return END; - } - history.returnUnackedSourcesToQueue(); - } - - int negotiatedProtocol = 0; - for (int protocolVersion : negotiatedProtocolVersions_) { - if (protocolVersion > 0) { - if (negotiatedProtocol > 0 && negotiatedProtocol != protocolVersion) { - LOG(ERROR) << "Different threads negotiated different protocols " - << negotiatedProtocol << " " << 
protocolVersion; - return END; - } - negotiatedProtocol = protocolVersion; - } - } - WDT_CHECK_GT(negotiatedProtocol, 0); - - LOG_IF(INFO, negotiatedProtocol != protocolVersion_) - << "Changing protocol version to " << negotiatedProtocol - << ", previous version " << protocolVersion_; - protocolVersion_ = negotiatedProtocol; - threadStats.setRemoteErrorCode(OK); - protoNegotiationStatus_ = V_MISMATCH_RESOLVED; - clearAbort(); - conditionAllAborted_.notify_all(); - return CONNECT; -} - -void Sender::sendOne(int threadIndex) { - INIT_PERF_STAT_REPORT - std::vector &transferHistories = transferHistories_; - TransferStats &threadStats = globalThreadStats_[threadIndex]; - Clock::time_point startTime = Clock::now(); - int port = ports_[threadIndex]; - auto completionGuard = folly::makeGuard([&] { - std::unique_lock lock(mutex_); - numActiveThreads_--; - if (numActiveThreads_ == 0) { - LOG(INFO) << "Last thread finished " - << durationSeconds(Clock::now() - startTime_); - endTime_ = Clock::now(); - transferFinished_ = true; - } - if (throttler_) { - throttler_->deRegisterTransfer(); - } - conditionAllAborted_.notify_one(); - }); - if (throttler_) { - throttler_->registerTransfer(); - } - ThreadData threadData(threadIndex, threadStats, transferHistories); - SenderState state = CONNECT; - while (state != END) { - ErrorCode abortCode = getCurAbortCode(); - if (abortCode != OK) { - LOG(ERROR) << "Transfer aborted " << port << " " - << errorCodeToStr(abortCode); - threadStats.setErrorCode(ABORT); - if (abortCode == VERSION_MISMATCH) { - state = PROCESS_VERSION_MISMATCH; - } else { - break; - } - } - state = (this->*stateMap_[state])(threadData); - } - - double totalTime = durationSeconds(Clock::now() - startTime); - LOG(INFO) << "Port " << port << " done. 
" << threadStats - << " Total throughput = " - << threadStats.getEffectiveTotalBytes() / totalTime / kMbToB - << " Mbytes/sec"; - perfReports_[threadIndex] = *perfStatReport; - return; -} - TransferStats Sender::sendOneByteSource( const std::unique_ptr &socket, const std::unique_ptr &source, ErrorCode transferStatus) { @@ -1267,7 +466,6 @@ TransferStats Sender::sendOneByteSource( off += sizeof(int16_t); const int64_t expectedSize = source->getSize(); int64_t actualSize = 0; - const SourceMetaData &metadata = source->getMetaData(); BlockDetails blockDetails; blockDetails.fileName = metadata.relPath; @@ -1277,7 +475,6 @@ TransferStats Sender::sendOneByteSource( blockDetails.dataSize = expectedSize; blockDetails.allocationStatus = metadata.allocationStatus; blockDetails.prevSeqId = metadata.prevSeqId; - Protocol::encodeHeader(protocolVersion_, headerBuf, off, Protocol::kMaxHeader, blockDetails); int16_t littleEndianOff = folly::Endian::little((int16_t)off); diff --git a/Sender.h b/Sender.h index 73cb9c57..7e4536c7 100644 --- a/Sender.h +++ b/Sender.h @@ -11,9 +11,7 @@ #include "WdtBase.h" #include "ClientSocket.h" #include "WdtOptions.h" - -#include - +#include "ThreadsController.h" #include #include #include @@ -22,87 +20,13 @@ namespace facebook { namespace wdt { - -class DirectorySourceQueue; - -/// transfer history of a sender thread -class ThreadTransferHistory { - public: - /** - * @param queue directory queue - * @param threadStats stat object of the thread - */ - ThreadTransferHistory(DirectorySourceQueue &queue, - TransferStats &threadStats); - - /** - * @param index of the source - * @return if index is in bounds, returns the identifier for the - * source, else returns empty string - */ - std::string getSourceId(int64_t index); - - /** - * Adds the source to the history. If global checkpoint has already been - * received, then the source is returned to the queue. 
- * - * @param source source to be added to history - * @return true if added to history, false if not added due to a - * global checkpoint - */ - bool addSource(std::unique_ptr &source); - - /** - * Sets checkpoint. Also, returns unacked sources to queue - * - * @param checkpoint checkpoint received - * @param globalCheckpoint global or local checkpoint - * @return number of sources returned to queue - */ - ErrorCode setCheckpointAndReturnToQueue(const Checkpoint &checkpoint, - bool globalCheckpoint); - - /** - * @return stats for acked sources, must be called after all the - * unacked sources are returned to the queue - */ - std::vector popAckedSourceStats(); - - /// marks all the sources as acked - void markAllAcknowledged(); - - /// returns all unacked sources to the queue - void returnUnackedSourcesToQueue(); - - /** - * @return number of sources acked by the receiver - */ - int64_t getNumAcked() const { - return numAcknowledged_; - } - - private: - ErrorCode validateCheckpoint(const Checkpoint &checkpoint, - bool globalCheckpoint); - - void markSourceAsFailed(std::unique_ptr &source, - const Checkpoint *checkpoint); - - /// reference to global queue - DirectorySourceQueue &queue_; - /// reference to thread stats - TransferStats &threadStats_; - /// history of the thread - std::vector> history_; - /// whether a global error checkpoint has been received or not - bool globalCheckpoint_{false}; - /// number of sources acked by the receiver thread - int64_t numAcknowledged_{0}; - /// last received checkpoint - std::unique_ptr lastCheckpoint_{nullptr}; - folly::SpinLock lock_; +class SenderThread; +class TransferHistoryController; +enum ProtoNegotiationStatus { + V_MISMATCH_WAIT, // waiting for version mismatch to be processed + V_MISMATCH_RESOLVED, // version mismatch processed and was successful + V_MISMATCH_FAILED, // version mismatch processed and it failed }; - /** * The sender for the transfer. 
One instance of sender should only be * responsible for one transfer. For a second transfer you should make @@ -226,6 +150,34 @@ class Sender : public WdtBase { void setSocketCreator(const SocketCreator socketCreator); private: + friend class SenderThread; + + /// Get the sum of all the thread transfer stats + TransferStats getGlobalTransferStats() const; + + /// Verifies that if there is version mismatch then no thread stats + /// should be OK + ErrorCode verifyVersionMismatchStats() const; + + /// General utility used by sender threads to connect to receiver + std::unique_ptr connectToReceiver( + const int port, IAbortChecker const *abortChecker, ErrorCode &errCode); + + /// Method responsible for sending one source to the destination + virtual TransferStats sendOneByteSource( + const std::unique_ptr &socket, + const std::unique_ptr &source, ErrorCode transferStatus); + + /// Returns true if file chunks need to be read + bool isSendFileChunks() const; + + /// Returns true if file chunks have been received by a thread + bool isFileChunksReceived() const; + + /// Sender thread calls this method to set the file chunks info received + /// from the receiver + void setFileChunksInfo(std::vector &fileChunksInfoList); + /// Abort checker passed to DirectoryQueue.
If all the network threads finish, /// directory discovery thread is also aborted class QueueAbortChecker : public IAbortChecker { @@ -241,191 +193,9 @@ class Sender : public WdtBase { Sender *sender_; }; + /// Abort checker shared with the directory queue QueueAbortChecker queueAbortChecker_; - /// state machine states - enum SenderState { - CONNECT, - READ_LOCAL_CHECKPOINT, - SEND_SETTINGS, - SEND_BLOCKS, - SEND_DONE_CMD, - SEND_SIZE_CMD, - CHECK_FOR_ABORT, - READ_FILE_CHUNKS, - READ_RECEIVER_CMD, - PROCESS_DONE_CMD, - PROCESS_WAIT_CMD, - PROCESS_ERR_CMD, - PROCESS_ABORT_CMD, - PROCESS_VERSION_MISMATCH, - END - }; - - /// structure to share data among different states - struct ThreadData { - const int threadIndex_; - TransferStats &threadStats_; - std::vector &transferHistories_; - std::unique_ptr socket_; - char buf_[Protocol::kMinBufLength]; - /// whether total file size has been sent to the receiver - bool totalSizeSent_{false}; - /// number of consecutive reconnects without any progress - int numReconnectWithoutProgress_{0}; - ThreadData(int threadIndex, TransferStats &threadStats, - std::vector &transferHistories) - : threadIndex_(threadIndex), - threadStats_(threadStats), - transferHistories_(transferHistories) { - } - - ThreadTransferHistory &getTransferHistory() { - return transferHistories_[threadIndex_]; - } - }; - - typedef SenderState (Sender::*StateFunction)(ThreadData &data); - - /** - * tries to connect to the receiver - * Previous states : Almost all states(in case of network errors, all states - * move to this state) - * Next states : SEND_SETTINGS(if there is no previous error) - * READ_LOCAL_CHECKPOINT(if there is previous error) - * END(failed) - */ - SenderState connect(ThreadData &data); - /** - * tries to read local checkpoint and return unacked sources to queue. If the - * checkpoint value is -1, then we know previous attempt to send DONE had - * failed. So, we move to READ_RECEIVER_CMD state. 
- * Previous states : CONNECT - * Next states : CONNECT(read failure), - * END(protocol error or global checkpoint found), - * READ_RECEIVER_CMD(if checkpoint is -1), - * SEND_SETTINGS(success) - */ - SenderState readLocalCheckPoint(ThreadData &data); - /** - * sends sender settings to the receiver - * Previous states : READ_LOCAL_CHECKPOINT, - * CONNECT - * Next states : SEND_BLOCKS(success), - * CONNECT(failure) - */ - SenderState sendSettings(ThreadData &data); - /** - * sends blocks to receiver till the queue is not empty. After transferring a - * block, we add it to the history. While adding to history, if it is found - * that global checkpoint has been received for this thread, we move to END - * state. - * Previous states : SEND_SETTINGS, - * PROCESS_ERR_CMD - * Next states : SEND_BLOCKS(success), - * END(global checkpoint received), - * CHECK_FOR_ABORT(socket write failure), - * SEND_DONE_CMD(no more blocks left to transfer) - */ - SenderState sendBlocks(ThreadData &data); - /** - * sends DONE cmd to the receiver - * Previous states : SEND_BLOCKS - * Next states : CONNECT(failure), - * READ_RECEIVER_CMD(success) - */ - SenderState sendDoneCmd(ThreadData &data); - /** - * sends size cmd to the receiver - * Previous states : SEND_BLOCKS - * Next states : CHECK_FOR_ABORT(failure), - * SEND_BLOCKS(success) - */ - SenderState sendSizeCmd(ThreadData &data); - /** - * checks to see if the receiver has sent ABORT or not - * Previous states : SEND_BLOCKS, - * SEND_DONE_CMD - * Next states : CONNECT(no ABORT cmd), - * END(protocol error), - * PROCESS_ABORT_CMD(read ABORT cmd) - */ - SenderState checkForAbort(ThreadData &data); - /** - * reads previously transferred file chunks list. If it receives an ACK cmd, - * then it moves on. If wait cmd is received, it waits. Otherwise reads the - * file chunks and when done starts directory queue thread. 
- * Previous states : SEND_SETTINGS, - * Next states: READ_FILE_CHUNKS(if wait cmd is received), - * CHECK_FOR_ABORT(network error), - * END(protocol error), - * SEND_BLOCKS(success) - * - */ - SenderState readFileChunks(ThreadData &data); - /** - * reads receiver cmd - * Previous states : SEND_DONE_CMD - * Next states : PROCESS_DONE_CMD, - * PROCESS_WAIT_CMD, - * PROCESS_ERR_CMD, - * END(protocol error), - * CONNECT(failure) - */ - SenderState readReceiverCmd(ThreadData &data); - /** - * handles DONE cmd - * Previous states : READ_RECEIVER_CMD - * Next states : END - */ - SenderState processDoneCmd(ThreadData &data); - /** - * handles WAIT cmd - * Previous states : READ_RECEIVER_CMD - * Next states : READ_RECEIVER_CMD - */ - SenderState processWaitCmd(ThreadData &data); - /** - * reads list of global checkpoints and returns unacked sources to queue. - * Previous states : READ_RECEIVER_CMD - * Next states : CONNECT(socket read failure) - * END(checkpoint list decode failure), - * SEND_BLOCKS(success) - */ - SenderState processErrCmd(ThreadData &data); - /** - * processes ABORT cmd - * Previous states : CHECK_FOR_ABORT, - * READ_RECEIVER_CMD - * Next states : END - */ - SenderState processAbortCmd(ThreadData &data); - - /** - * waits for all active threads to be aborted, checks to see if the abort was - * due to version mismatch. Also performs various sanity checks. 
- * Previous states : Almost all threads, abort flags is checked between every - * state transition - * Next states : CONNECT(Abort was due to version mismatch), - * END(if abort was not due to version mismatch or some sanity - * check failed) - */ - SenderState processVersionMismatch(ThreadData &data); - - /// mapping from sender states to state functions - static const StateFunction stateMap_[]; - - /// Method responsible for sending one source to the destination - virtual TransferStats sendOneByteSource( - const std::unique_ptr &socket, - const std::unique_ptr &source, ErrorCode transferStatus); - - /// Every sender thread executes this method to send the data - void sendOne(int threadIndex); - - std::unique_ptr connectToReceiver(const int port, - ErrorCode &errCode); - /** * Internal API that triggers the directory thread, sets up the sender * threads and starts the transfer. Returns after the sender threads @@ -437,12 +207,10 @@ class Sender : public WdtBase { * @param transferredSourceStats Stats for the successfully transmitted * sources * @param failedSourceStats Stats for the failed sources - * @param threadStats Stats calculated by each sender thread */ void validateTransferStats( const std::vector &transferredSourceStats, - const std::vector &failedSourceStats, - const std::vector &threadStats); + const std::vector &failedSourceStats); /** * Responsible for doing a periodic check. 
@@ -452,16 +220,14 @@ class Sender : public WdtBase { */ void reportProgress(); + /// Address of the destination host where the files are sent + const std::string destHost_; /// Pointer to DirectorySourceQueue which reads the srcDir and the files std::unique_ptr dirQueue_; - /// List of ports where the receiver threads are running on the destination - std::vector ports_; /// Number of active threads, decremented every time a thread is finished int32_t numActiveThreads_{0}; /// The directory from where the files are read std::string srcDir_; - /// Address of the destination host where the files are sent - std::string destHost_; /// The interval at which the progress reporter should check for progress int progressReportIntervalMillis_; /// Socket creator used to optionally create different kinds of client socket @@ -473,26 +239,25 @@ class Sender : public WdtBase { /// Thread that is running the discovery of files using the dirQueue_ std::thread dirThread_; /// Threads which are responsible for transfer of the sources - std::vector senderThreads_; + std::vector> senderThreads_; /// Thread responsible for doing the progress checks. 
Uses reportProgress() std::thread progressReporterThread_; - /// Vector of per thread stats, this same instance is used in reporting - std::vector globalThreadStats_; - /// per thread perf report - std::vector perfReports_; - /// per thread negotiated protocol versions - std::vector negotiatedProtocolVersions_; - /// number of threads waiting in PROCESS_VERSION_MISMATCH state - int numWaitingWithAbort_{0}; - /// Condition variable used to co-ordinate threads waiting in - /// PROCESS_VERSION_MISMATCH state - std::condition_variable conditionAllAborted_; - - enum ProtoNegotiationStatus { - V_MISMATCH_WAIT, // waiting for version mismatch to be processed - V_MISMATCH_RESOLVED, // version mismatch processed and was successful - V_MISMATCH_FAILED, // version mismatch processed and it failed - }; + + /// Returns the protocol negotiation status of the parent sender + ProtoNegotiationStatus getNegotiationStatus(); + + /// Set the protocol negotiation status, called by sender thread + void setProtoNegotiationStatus(ProtoNegotiationStatus status); + + /// Things to do before ending the current transfer + void endCurTransfer(); + + /// Initializing the new transfer + void startNewTransfer(); + + /// Returns vector of negotiated protocols set by sender threads + std::vector getNegotiatedProtocols() const; + /// Protocol negotiation status, used to co-ordinate processing of version /// mismatch. Threads aborted due to version mismatch waits for all threads to /// abort and reach PROCESS_VERSION_MISMATCH state. 
Last thread processes @@ -503,20 +268,21 @@ class Sender : public WdtBase { std::condition_variable conditionFinished_; /// Mutex which is shared between the parent thread, sender thread and /// progress reporter thread - std::mutex mutex_; + mutable std::mutex mutex_; /// Set to false when the transfer begins and then turned on when it ends bool transferFinished_; /// Time at which the transfer was started std::chrono::time_point startTime_; /// Time at which the transfer finished std::chrono::time_point endTime_; - /// Per thread transfer history - std::vector transferHistories_; /// Has finished been called and threads joined bool areThreadsJoined_{true}; /// Mutex for the management of this instance, specifically to keep the /// instance sane for multi threaded public API calls std::mutex instanceManagementMutex_; + + /// Transfer history controller for the sender threads + std::unique_ptr transferHistoryController_; }; } } // namespace facebook::wdt diff --git a/SenderThread.cpp b/SenderThread.cpp new file mode 100644 index 00000000..0fe2d3b4 --- /dev/null +++ b/SenderThread.cpp @@ -0,0 +1,633 @@ +/** + * Copyright (c) 2014-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. 
+ */ +#include "SenderThread.h" +#include "Sender.h" +#include "ClientSocket.h" +#include +#include +#include +#include +#include +#include +#include + +namespace facebook { +namespace wdt { +std::ostream &operator<<(std::ostream &os, const SenderThread &senderThread) { + os << "Thread[" << senderThread.threadIndex_ + << ", port: " << senderThread.port_ << "] "; + return os; +} + +const SenderThread::StateFunction SenderThread::stateMap_[] = { + &SenderThread::connect, &SenderThread::readLocalCheckPoint, + &SenderThread::sendSettings, &SenderThread::sendBlocks, + &SenderThread::sendDoneCmd, &SenderThread::sendSizeCmd, + &SenderThread::checkForAbort, &SenderThread::readFileChunks, + &SenderThread::readReceiverCmd, &SenderThread::processDoneCmd, + &SenderThread::processWaitCmd, &SenderThread::processErrCmd, + &SenderThread::processAbortCmd, &SenderThread::processVersionMismatch}; + +SenderState SenderThread::connect() { + VLOG(1) << *this << " entered CONNECT state"; + if (socket_) { + socket_->close(); + } + const auto &options = WdtOptions::get(); + if (numReconnectWithoutProgress_ >= options.max_transfer_retries) { + LOG(ERROR) << "Sender thread reconnected " << numReconnectWithoutProgress_ + << " times without making any progress, giving up. 
port: " << port_;  // log port_: socket_ is null if no connection was ever made
+    threadStats_.setErrorCode(NO_PROGRESS);
+    return END;
+  }
+  ErrorCode code;
+  socket_ =
+      wdtParent_->connectToReceiver(port_, socketAbortChecker_.get(), code);
+  if (code == ABORT) {
+    threadStats_.setErrorCode(ABORT);
+    // an abort caused by a version mismatch is handled in its own state
+    if (wdtParent_->getCurAbortCode() == VERSION_MISMATCH) {
+      return PROCESS_VERSION_MISMATCH;
+    }
+    return END;
+  }
+  if (code != OK) {
+    threadStats_.setErrorCode(code);
+    return END;
+  }
+  // an error left over from the previous connection means this is a
+  // reconnect, so the local checkpoint has to be read before the settings
+  auto nextState = SEND_SETTINGS;
+  if (threadStats_.getErrorCode() != OK) {
+    nextState = READ_LOCAL_CHECKPOINT;
+  }
+  // resetting the status of thread
+  reset();
+  return nextState;
+}
+
+SenderState SenderThread::readLocalCheckPoint() {
+  LOG(INFO) << *this << " entered READ_LOCAL_CHECKPOINT state";
+  ThreadTransferHistory &transferHistory = getTransferHistory();
+  std::vector<Checkpoint> checkpoints;
+  int64_t decodeOffset = 0;
+  int checkpointLen =
+      Protocol::getMaxLocalCheckpointLength(threadProtocolVersion_);
+  int64_t numRead = socket_->read(buf_, checkpointLen);
+  if (numRead != checkpointLen) {
+    LOG(ERROR) << "read mismatch during reading local checkpoint "
+               << checkpointLen << " " << numRead << " port " << port_;
+    threadStats_.setErrorCode(SOCKET_READ_ERROR);
+    // a failed checkpoint read counts as a reconnect without progress
+    numReconnectWithoutProgress_++;
+    return CONNECT;
+  }
+  if (!Protocol::decodeCheckpoints(threadProtocolVersion_, buf_, decodeOffset,
+                                   checkpointLen, checkpoints)) {
+    LOG(ERROR) << "checkpoint decode failure "
+               << folly::humanify(std::string(buf_, checkpointLen));
+    threadStats_.setErrorCode(PROTOCOL_ERROR);
+    return END;
+  }
+  // exactly one checkpoint, for this thread's own port, is acceptable here
+  if (checkpoints.size() != 1 || checkpoints[0].port != port_) {
+    LOG(ERROR) << "illegal local checkpoint "
+               << folly::humanify(std::string(buf_, checkpointLen));
+    threadStats_.setErrorCode(PROTOCOL_ERROR);
+    return END;
+  }
+  const Checkpoint &checkpoint = checkpoints[0];
+  auto numBlocks = checkpoint.numBlocks;
+  VLOG(1) << "received local checkpoint " << checkpoint;
+
+  if (numBlocks == -1) {
+    // Receiver failed while sending DONE cmd
+
return READ_RECEIVER_CMD; + } + + ErrorCode errCode = transferHistory.setLocalCheckpoint(checkpoint); + if (errCode == INVALID_CHECKPOINT) { + threadStats_.setErrorCode(PROTOCOL_ERROR); + return END; + } + if (errCode == NO_PROGRESS) { + ++numReconnectWithoutProgress_; + } else { + numReconnectWithoutProgress_ = 0; + } + return SEND_SETTINGS; +} + +SenderState SenderThread::sendSettings() { + VLOG(1) << *this << " entered SEND_SETTINGS state"; + auto &options = WdtOptions::get(); + int64_t readTimeoutMillis = options.read_timeout_millis; + int64_t writeTimeoutMillis = options.write_timeout_millis; + int64_t off = 0; + buf_[off++] = Protocol::SETTINGS_CMD; + bool sendFileChunks = wdtParent_->isSendFileChunks(); + Settings settings; + settings.readTimeoutMillis = readTimeoutMillis; + settings.writeTimeoutMillis = writeTimeoutMillis; + settings.transferId = wdtParent_->getTransferId(); + settings.enableChecksum = options.enable_checksum; + settings.sendFileChunks = sendFileChunks; + settings.blockModeDisabled = (options.block_size_mbytes <= 0); + Protocol::encodeSettings(threadProtocolVersion_, buf_, off, + Protocol::kMaxSettings, settings); + int64_t toWrite = sendFileChunks ? Protocol::kMinBufLength : off; + int64_t written = socket_->write(buf_, toWrite); + if (written != toWrite) { + LOG(ERROR) << "Socket write failure " << written << " " << toWrite; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + return CONNECT; + } + threadStats_.addHeaderBytes(toWrite); + return sendFileChunks ? 
READ_FILE_CHUNKS : SEND_BLOCKS; +} + +SenderState SenderThread::sendBlocks() { + VLOG(1) << *this << " entered SEND_BLOCKS state"; + ThreadTransferHistory &transferHistory = getTransferHistory(); + if (threadProtocolVersion_ >= Protocol::RECEIVER_PROGRESS_REPORT_VERSION && + !totalSizeSent_ && dirQueue_->fileDiscoveryFinished()) { + return SEND_SIZE_CMD; + } + ErrorCode transferStatus; + std::unique_ptr source = dirQueue_->getNextSource(transferStatus); + if (!source) { + return SEND_DONE_CMD; + } + WDT_CHECK(!source->hasError()); + TransferStats transferStats = + wdtParent_->sendOneByteSource(socket_, source, transferStatus); + threadStats_ += transferStats; + source->addTransferStats(transferStats); + source->close(); + if (!transferHistory.addSource(source)) { + // global checkpoint received for this thread. no point in + // continuing + LOG(ERROR) << *this << " global checkpoint received. Stopping"; + threadStats_.setErrorCode(CONN_ERROR); + return END; + } + if (transferStats.getErrorCode() != OK) { + return CHECK_FOR_ABORT; + } + return SEND_BLOCKS; +} + +SenderState SenderThread::sendSizeCmd() { + VLOG(1) << *this << " entered SEND_SIZE_CMD state"; + int64_t off = 0; + buf_[off++] = Protocol::SIZE_CMD; + + Protocol::encodeSize(buf_, off, Protocol::kMaxSize, + dirQueue_->getTotalSize()); + int64_t written = socket_->write(buf_, off); + if (written != off) { + LOG(ERROR) << "Socket write error " << off << " " << written; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + return CHECK_FOR_ABORT; + } + threadStats_.addHeaderBytes(off); + totalSizeSent_ = true; + return SEND_BLOCKS; +} + +SenderState SenderThread::sendDoneCmd() { + VLOG(1) << *this << " entered SEND_DONE_CMD state"; + int64_t off = 0; + buf_[off++] = Protocol::DONE_CMD; + auto pair = dirQueue_->getNumBlocksAndStatus(); + int64_t numBlocksDiscovered = pair.first; + ErrorCode transferStatus = pair.second; + buf_[off++] = transferStatus; + Protocol::encodeDone(threadProtocolVersion_, buf_, off, 
Protocol::kMaxDone, + numBlocksDiscovered, dirQueue_->getTotalSize()); + int toWrite = Protocol::kMinBufLength; + int64_t written = socket_->write(buf_, toWrite); + if (written != toWrite) { + LOG(ERROR) << "Socket write failure " << written << " " << toWrite; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + return CHECK_FOR_ABORT; + } + threadStats_.addHeaderBytes(toWrite); + VLOG(1) << "Wrote done cmd on " << socket_->getFd() + << " waiting for reply..."; + return READ_RECEIVER_CMD; +} + +SenderState SenderThread::checkForAbort() { + LOG(INFO) << *this << " entered CHECK_FOR_ABORT state"; + auto numRead = socket_->read(buf_, 1); + if (numRead != 1) { + VLOG(1) << "No abort cmd found"; + return CONNECT; + } + Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf_[0]; + if (cmd != Protocol::ABORT_CMD) { + VLOG(1) << "Unexpected result found while reading for abort " << buf_[0]; + return CONNECT; + } + threadStats_.addHeaderBytes(1); + return PROCESS_ABORT_CMD; +} + +SenderState SenderThread::readFileChunks() { + LOG(INFO) << *this << " entered READ_FILE_CHUNKS state "; + int64_t numRead = socket_->read(buf_, 1); + if (numRead != 1) { + LOG(ERROR) << "Socket read error 1 " << numRead; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return CHECK_FOR_ABORT; + } + threadStats_.addHeaderBytes(numRead); + Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf_[0]; + if (cmd == Protocol::ABORT_CMD) { + return PROCESS_ABORT_CMD; + } + if (cmd == Protocol::WAIT_CMD) { + return READ_FILE_CHUNKS; + } + if (cmd == Protocol::ACK_CMD) { + if (!wdtParent_->isFileChunksReceived()) { + LOG(ERROR) << "Sender has not yet received file chunks, but receiver " + << "thinks it has already sent it"; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return END; + } + return SEND_BLOCKS; + } + if (cmd != Protocol::CHUNKS_CMD) { + LOG(ERROR) << "Unexpected cmd " << cmd; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return END; + } + int64_t toRead = Protocol::kChunksCmdLen; + numRead = 
socket_->read(buf_, toRead);
+  if (numRead != toRead) {
+    LOG(ERROR) << "Socket read error " << toRead << " " << numRead;
+    threadStats_.setErrorCode(SOCKET_READ_ERROR);
+    return CHECK_FOR_ABORT;
+  }
+  threadStats_.addHeaderBytes(numRead);
+  int64_t off = 0;
+  int64_t bufSize, numFiles;
+  Protocol::decodeChunksCmd(buf_, off, bufSize, numFiles);
+  LOG(INFO) << "File chunk list has " << numFiles
+            << " entries and is broken in buffers of length " << bufSize;
+  // bufSize is decoded from receiver data; a negative value must never be
+  // used as an allocation size
+  if (bufSize < 0) {
+    LOG(ERROR) << "Invalid chunk buffer size in CHUNKS_CMD " << bufSize;
+    threadStats_.setErrorCode(PROTOCOL_ERROR);
+    return END;
+  }
+  std::unique_ptr<char[]> chunkBuffer(new char[bufSize]);
+  std::vector<FileChunksInfo> fileChunksInfoList;
+  while (true) {
+    int64_t numFileChunks = fileChunksInfoList.size();
+    if (numFileChunks > numFiles) {
+      // We should never be able to read more file chunks than mentioned in the
+      // chunks cmd. Chunks cmd has buffer size used to transfer chunks and also
+      // number of chunks. These chunks are read and parsed and added to
+      // fileChunksInfoList. Number of chunks we decode should match with the
+      // number mentioned in the Chunks cmd.
+      LOG(ERROR) << "Number of file chunks received is more than the number "
+                    "mentioned in CHUNKS_CMD " << numFileChunks << " "
+                 << numFiles;
+      threadStats_.setErrorCode(PROTOCOL_ERROR);
+      return END;
+    }
+    if (numFileChunks == numFiles) {
+      break;
+    }
+    // each buffer is preceded by its length, sent as a 32 bit integer
+    toRead = sizeof(int32_t);
+    numRead = socket_->read(buf_, toRead);
+    if (numRead != toRead) {
+      LOG(ERROR) << "Socket read error " << toRead << " " << numRead;
+      threadStats_.setErrorCode(SOCKET_READ_ERROR);
+      return CHECK_FOR_ABORT;
+    }
+    toRead = folly::loadUnaligned<int32_t>(buf_);
+    toRead = folly::Endian::little(toRead);
+    // the announced length has to fit into chunkBuffer (bufSize bytes),
+    // otherwise the read below would write past the end of the buffer
+    if (toRead < 0 || toRead > bufSize) {
+      LOG(ERROR) << "Invalid file chunks buffer length " << toRead << " "
+                 << bufSize;
+      threadStats_.setErrorCode(PROTOCOL_ERROR);
+      return END;
+    }
+    numRead = socket_->read(chunkBuffer.get(), toRead);
+    if (numRead != toRead) {
+      LOG(ERROR) << "Socket read error " << toRead << " " << numRead;
+      threadStats_.setErrorCode(SOCKET_READ_ERROR);
+      return CHECK_FOR_ABORT;
+    }
+    threadStats_.addHeaderBytes(numRead);
+    off = 0;
+    // decode function below adds decoded file chunks to fileChunksInfoList
+    bool success = Protocol::decodeFileChunksInfoList(
+        chunkBuffer.get(), off, toRead,
fileChunksInfoList); + if (!success) { + LOG(ERROR) << "Unable to decode file chunks list"; + threadStats_.setErrorCode(PROTOCOL_ERROR); + return END; + } + } + wdtParent_->setFileChunksInfo(fileChunksInfoList); + // send ack for file chunks list + buf_[0] = Protocol::ACK_CMD; + int64_t toWrite = 1; + int64_t written = socket_->write(buf_, toWrite); + if (toWrite != written) { + LOG(ERROR) << "Socket write error " << toWrite << " " << written; + threadStats_.setErrorCode(SOCKET_WRITE_ERROR); + return CHECK_FOR_ABORT; + } + threadStats_.addHeaderBytes(written); + return SEND_BLOCKS; +} + +SenderState SenderThread::readReceiverCmd() { + VLOG(1) << *this << " entered READ_RECEIVER_CMD state"; + int64_t numRead = socket_->read(buf_, 1); + if (numRead != 1) { + LOG(ERROR) << "READ unexpected " << numRead; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return CONNECT; + } + Protocol::CMD_MAGIC cmd = (Protocol::CMD_MAGIC)buf_[0]; + if (cmd == Protocol::ERR_CMD) { + return PROCESS_ERR_CMD; + } + if (cmd == Protocol::WAIT_CMD) { + return PROCESS_WAIT_CMD; + } + if (cmd == Protocol::DONE_CMD) { + return PROCESS_DONE_CMD; + } + if (cmd == Protocol::ABORT_CMD) { + return PROCESS_ABORT_CMD; + } + if (cmd == Protocol::LOCAL_CHECKPOINT_CMD) { + int checkpointLen = + Protocol::getMaxLocalCheckpointLength(threadProtocolVersion_); + int64_t toRead = checkpointLen - 1; + numRead = socket_->read(buf_ + 1, toRead); + if (numRead != toRead) { + LOG(ERROR) << "Could not read possible local checkpoint " << toRead << " " + << numRead << " " << port_; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return CONNECT; + } + int64_t offset = 0; + std::vector checkpoints; + if (Protocol::decodeCheckpoints(threadProtocolVersion_, buf_, offset, + checkpointLen, checkpoints)) { + if (checkpoints.size() == 1 && checkpoints[0].port == port_ && + checkpoints[0].numBlocks == 0 && + checkpoints[0].lastBlockReceivedBytes == 0) { + // In a spurious local checkpoint, number of blocks and offset must 
both
+        // be zero
+        // Ignore the checkpoint
+        LOG(WARNING)
+            << "Received valid but unexpected local checkpoint, ignoring "
+            << port_ << " checkpoint " << checkpoints[0];
+        return READ_RECEIVER_CMD;
+      }
+    }
+    LOG(ERROR) << "Failed to verify spurious local checkpoint, port " << port_;
+    threadStats_.setErrorCode(PROTOCOL_ERROR);
+    return END;
+  }
+  LOG(ERROR) << "Read unexpected receiver cmd " << cmd << " port " << port_;
+  threadStats_.setErrorCode(PROTOCOL_ERROR);
+  return END;
+}
+
+SenderState SenderThread::processDoneCmd() {
+  VLOG(1) << *this << " entered PROCESS_DONE_CMD state";
+  ThreadTransferHistory &transferHistory = getTransferHistory();
+  transferHistory.markAllAcknowledged();
+
+  // send ack for DONE (the result of this write is not checked)
+  buf_[0] = Protocol::DONE_CMD;
+  socket_->write(buf_, 1);
+
+  socket_->shutdown();
+  // after the shutdown the receiver is expected to close: read must hit EOF
+  auto numRead = socket_->read(buf_, Protocol::kMinBufLength);
+  if (numRead != 0) {
+    LOG(WARNING) << "EOF not found when expected";
+    return END;
+  }
+  VLOG(1) << "done with transfer, port " << port_;
+  return END;
+}
+
+SenderState SenderThread::processWaitCmd() {
+  LOG(INFO) << *this << " entered PROCESS_WAIT_CMD state ";
+  ThreadTransferHistory &transferHistory = getTransferHistory();
+  VLOG(1) << "received WAIT_CMD, port " << port_;
+  transferHistory.markAllAcknowledged();
+  return READ_RECEIVER_CMD;
+}
+
+SenderState SenderThread::processErrCmd() {
+  LOG(INFO) << *this << " entered PROCESS_ERR_CMD state";
+  ThreadTransferHistory &transferHistory = getTransferHistory();
+  // the checkpoint list is preceded by its length as a 16 bit integer
+  int64_t toRead = sizeof(int16_t);
+  int64_t numRead = socket_->read(buf_, toRead);
+  if (numRead != toRead) {
+    LOG(ERROR) << "read unexpected " << toRead << " " << numRead;
+    threadStats_.setErrorCode(SOCKET_READ_ERROR);
+    return CONNECT;
+  }
+
+  int16_t checkpointsLen = folly::loadUnaligned<int16_t>(buf_);
+  checkpointsLen = folly::Endian::little(checkpointsLen);
+  // the length is a signed value read off the wire; a negative one must be
+  // rejected before it is used to size the buffer
+  if (checkpointsLen < 0) {
+    LOG(ERROR) << "Invalid checkpoint list length " << checkpointsLen;
+    threadStats_.setErrorCode(PROTOCOL_ERROR);
+    return END;
+  }
+  // vector backed storage instead of a variable length array, which is not
+  // standard C++; one spare byte keeps data() non-null for an empty list
+  std::vector<char> checkpointStorage(checkpointsLen + 1);
+  char *checkpointBuf = checkpointStorage.data();
+  numRead = socket_->read(checkpointBuf, checkpointsLen);
+  if (numRead != checkpointsLen) {
+
LOG(ERROR) << "read unexpected " << checkpointsLen << " " << numRead; + threadStats_.setErrorCode(SOCKET_READ_ERROR); + return CONNECT; + } + + std::vector checkpoints; + int64_t decodeOffset = 0; + if (!Protocol::decodeCheckpoints(threadProtocolVersion_, checkpointBuf, + decodeOffset, checkpointsLen, checkpoints)) { + LOG(ERROR) << "checkpoint decode failure " + << folly::humanify(std::string(checkpointBuf, checkpointsLen)); + threadStats_.setErrorCode(PROTOCOL_ERROR); + return END; + } + transferHistory.markAllAcknowledged(); + for (auto &checkpoint : checkpoints) { + LOG(INFO) << *this << " Received global checkpoint " << checkpoint; + transferHistoryController_->handleGlobalCheckpoint(checkpoint); + } + return SEND_BLOCKS; +} + +SenderState SenderThread::processAbortCmd() { + LOG(INFO) << *this << " entered PROCESS_ABORT_CMD state "; + ThreadTransferHistory &transferHistory = getTransferHistory(); + threadStats_.setErrorCode(ABORT); + int toRead = Protocol::kAbortLength; + auto numRead = socket_->read(buf_, toRead); + if (numRead != toRead) { + // can not read checkpoint, but still must exit because of ABORT + LOG(ERROR) << "Error while trying to read ABORT cmd " << numRead << " " + << toRead; + return END; + } + int64_t offset = 0; + int32_t negotiatedProtocol; + ErrorCode remoteError; + int64_t checkpoint; + Protocol::decodeAbort(buf_, offset, negotiatedProtocol, remoteError, + checkpoint); + threadStats_.setRemoteErrorCode(remoteError); + std::string failedFileName = transferHistory.getSourceId(checkpoint); + LOG(WARNING) << *this << "Received abort on " + << " remote protocol version " << negotiatedProtocol + << " remote error code " << errorCodeToStr(remoteError) + << " file " << failedFileName << " checkpoint " << checkpoint; + wdtParent_->abort(remoteError); + if (remoteError == VERSION_MISMATCH) { + if (Protocol::negotiateProtocol( + negotiatedProtocol, threadProtocolVersion_) == negotiatedProtocol) { + // sender can support this negotiated version + 
negotiatedProtocol_ = negotiatedProtocol; + return PROCESS_VERSION_MISMATCH; + } else { + LOG(ERROR) << "Sender can not support receiver version " + << negotiatedProtocol; + threadStats_.setRemoteErrorCode(VERSION_INCOMPATIBLE); + } + } + return END; +} + +SenderState SenderThread::processVersionMismatch() { + LOG(INFO) << *this << " entered PROCESS_VERSION_MISMATCH state "; + WDT_CHECK(threadStats_.getErrorCode() == ABORT); + auto negotiationStatus = wdtParent_->getNegotiationStatus(); + WDT_CHECK_NE(negotiationStatus, V_MISMATCH_FAILED) + << "Thread should have ended in case of version mismatch"; + if (negotiationStatus == V_MISMATCH_RESOLVED) { + LOG(WARNING) << *this << " Protocol version already negotiated, but " + "transfer still aborted due to version mismatch"; + return END; + } + WDT_CHECK_EQ(negotiationStatus, V_MISMATCH_WAIT); + // Need a barrier here to make sure all the negotiated protocol versions + // have been collected + auto barrier = controller_->getBarrier(VERSION_MISMATCH_BARRIER); + barrier->execute(); + VLOG(1) << *this << " cleared the protocol version barrier"; + auto execFunnel = controller_->getFunnel(VERSION_MISMATCH_FUNNEL); + while (true) { + auto status = execFunnel->getStatus(); + switch (status) { + case FUNNEL_START: { + LOG(INFO) << *this << " started the funnel for version mismatch"; + wdtParent_->setProtoNegotiationStatus(V_MISMATCH_FAILED); + if (wdtParent_->verifyVersionMismatchStats() != OK) { + execFunnel->notifySuccess(); + return END; + } + if (transferHistoryController_->handleVersionMismatch() != OK) { + execFunnel->notifySuccess(); + return END; + } + int negotiatedProtocol = 0; + for (int threadProtocolVersion_ : + wdtParent_->getNegotiatedProtocols()) { + if (threadProtocolVersion_ > 0) { + if (negotiatedProtocol > 0 && + negotiatedProtocol != threadProtocolVersion_) { + LOG(ERROR) << "Different threads negotiated different protocols " + << negotiatedProtocol << " " << threadProtocolVersion_; + 
execFunnel->notifySuccess(); + return END; + } + negotiatedProtocol = threadProtocolVersion_; + } + } + WDT_CHECK_GT(negotiatedProtocol, 0); + LOG_IF(INFO, negotiatedProtocol != threadProtocolVersion_) + << *this << "Changing protocol version to " << negotiatedProtocol + << ", previous version " << threadProtocolVersion_; + wdtParent_->setProtocolVersion(negotiatedProtocol); + threadProtocolVersion_ = wdtParent_->getProtocolVersion(); + threadStats_.setRemoteErrorCode(OK); + wdtParent_->setProtoNegotiationStatus(V_MISMATCH_RESOLVED); + wdtParent_->clearAbort(); + execFunnel->notifySuccess(); + return CONNECT; + } + case FUNNEL_PROGRESS: { + execFunnel->wait(); + break; + } + case FUNNEL_END: { + negotiationStatus = wdtParent_->getNegotiationStatus(); + WDT_CHECK_NE(negotiationStatus, V_MISMATCH_WAIT); + if (negotiationStatus == V_MISMATCH_FAILED) { + return END; + } + if (negotiationStatus == V_MISMATCH_RESOLVED) { + threadProtocolVersion_ = wdtParent_->getProtocolVersion(); + threadStats_.setRemoteErrorCode(OK); + return CONNECT; + } + } + } + } +} + +void SenderThread::start() { + INIT_PERF_STAT_REPORT + Clock::time_point startTime = Clock::now(); + auto completionGuard = folly::makeGuard([&] { + ThreadTransferHistory &transferHistory = getTransferHistory(); + transferHistory.markNotInUse(); + controller_->deRegisterThread(threadIndex_); + controller_->executeAtEnd([&]() { wdtParent_->endCurTransfer(); }); + }); + controller_->executeAtStart([&]() { wdtParent_->startNewTransfer(); }); + SenderState state = CONNECT; + while (state != END) { + ErrorCode abortCode = wdtParent_->getCurAbortCode(); + if (abortCode != OK) { + LOG(ERROR) << *this << "Transfer aborted " << errorCodeToStr(abortCode); + threadStats_.setErrorCode(ABORT); + if (abortCode == VERSION_MISMATCH) { + state = PROCESS_VERSION_MISMATCH; + } else { + break; + } + } + state = (this->*stateMap_[state])(); + } + + double totalTime = durationSeconds(Clock::now() - startTime); + LOG(INFO) << "Port " << 
port_ << " done. " << threadStats_ + << " Total throughput = " + << threadStats_.getEffectiveTotalBytes() / totalTime / kMbToB + << " Mbytes/sec"; + perfReport_ = *perfStatReport; + return; +} + +int SenderThread::getPort() const { + return port_; +} + +int SenderThread::getNegotiatedProtocol() const { + return negotiatedProtocol_; +} + +ErrorCode SenderThread::init() { + return OK; +} + +void SenderThread::reset() { + totalSizeSent_ = false; + threadStats_.setErrorCode(OK); +} +} +} diff --git a/SenderThread.h b/SenderThread.h new file mode 100644 index 00000000..f75684f5 --- /dev/null +++ b/SenderThread.h @@ -0,0 +1,278 @@ +/** + * Copyright (c) 2014-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ +#pragma once +#include +#include "WdtThread.h" +#include "Sender.h" +#include "ClientSocket.h" +#include "ThreadTransferHistory.h" +namespace facebook { +namespace wdt { +class DirectorySourceQueue; +/// state machine states +enum SenderState { + CONNECT, + READ_LOCAL_CHECKPOINT, + SEND_SETTINGS, + SEND_BLOCKS, + SEND_DONE_CMD, + SEND_SIZE_CMD, + CHECK_FOR_ABORT, + READ_FILE_CHUNKS, + READ_RECEIVER_CMD, + PROCESS_DONE_CMD, + PROCESS_WAIT_CMD, + PROCESS_ERR_CMD, + PROCESS_ABORT_CMD, + PROCESS_VERSION_MISMATCH, + END +}; + +/** + * This class represents one sender thread. It contains all the + * functionalities that a thread would need to send data over + * a connection to the receiver. 
+ * All the sender threads share a bunch of modules like directory queue, + * throttler, threads controller etc + */ +class SenderThread : public WdtThread { + public: + /// Identifiers for the barriers used in the thread + enum SENDER_BARRIERS { VERSION_MISMATCH_BARRIER, NUM_BARRIERS }; + + /// Identifiers for the funnels used in the thread + enum SENDER_FUNNELS { VERSION_MISMATCH_FUNNEL, NUM_FUNNELS }; + + /// Identifiers for the condition wrappers used in the thread + enum SENDER_CONDITIONS { NUM_CONDITIONS }; + + /// abort checker passed to client sockets. This checks both global sender + /// abort and whether global checkpoint has been received or not + class SocketAbortChecker : public WdtBase::AbortChecker { + public: + explicit SocketAbortChecker(WdtBase *wdtBase, + ThreadTransferHistory &transferHistory) + : AbortChecker(wdtBase), transferHistory_(transferHistory) { + } + + /// True when the base checker reports an abort or when a global + /// checkpoint has been received for the watched transfer history + bool shouldAbort() const { + return AbortChecker::shouldAbort() || + transferHistory_.isGlobalCheckpointReceived(); + } + + private: + ThreadTransferHistory &transferHistory_; + }; + + /// Constructor for the sender thread. Registers the thread with the + /// threads controller, adds a transfer history for this port and builds + /// the socket abort checker on top of that history + SenderThread(Sender *sender, int threadIndex, int32_t port, + ThreadsController *threadsController) + : WdtThread(threadIndex, sender->getProtocolVersion(), threadsController), + wdtParent_(sender), + port_(port), + dirQueue_(sender->dirQueue_.get()), + transferHistoryController_(sender->transferHistoryController_.get()) { + controller_->registerThread(threadIndex_); + transferHistoryController_->addThreadHistory(port_, threadStats_); + socketAbortChecker_ = + folly::make_unique(sender, getTransferHistory()); + threadStats_.setId(folly::to(threadIndex_)); + } + + typedef SenderState (SenderThread::*StateFunction)(); + + /// Returns the negotiated protocol + int getNegotiatedProtocol() const override; + + /// Steps to do before calling start + ErrorCode init() override; + + /// Reset the sender thread + void reset() override; + + /// Get the port sender thread is
connecting to + int getPort() const override; + + /// Destructor of the sender thread + ~SenderThread() { + } + + private: + /// Overloaded operator for printing thread info + friend std::ostream &operator<<(std::ostream &os, + const SenderThread &senderThread); + + /// Parent shared among all the threads for meta information + Sender *wdtParent_; + + /// Special abort checker for the client socket + std::unique_ptr socketAbortChecker_{nullptr}; + + /// The main entry point of the thread + void start() override; + + /// Get the transfer history kept for this thread's port + ThreadTransferHistory &getTransferHistory() { + return transferHistoryController_->getTransferHistory(port_); + } + + /** + * tries to connect to the receiver + * Previous states : Almost all states(in case of network errors, all states + * move to this state) + * Next states : SEND_SETTINGS(if there is no previous error) + * READ_LOCAL_CHECKPOINT(if there is previous error) + * PROCESS_VERSION_MISMATCH(connect aborted because of a + * version mismatch) + * END(failed, or too many reconnects without progress) + */ + SenderState connect(); + /** + * tries to read local checkpoint and return unacked sources to queue. If the + * checkpoint value is -1, then we know previous attempt to send DONE had + * failed. So, we move to READ_RECEIVER_CMD state. + * Previous states : CONNECT + * Next states : CONNECT(read failure), + * END(protocol error or global checkpoint found), + * READ_RECEIVER_CMD(if checkpoint is -1), + * SEND_SETTINGS(success) + */ + SenderState readLocalCheckPoint(); + /** + * sends sender settings to the receiver + * Previous states : READ_LOCAL_CHECKPOINT, + * CONNECT + * Next states : READ_FILE_CHUNKS(success, when the parent sender asks + * for file chunks), + * SEND_BLOCKS(success otherwise), + * CONNECT(failure) + */ + SenderState sendSettings(); + /** + * sends blocks to receiver until the queue is empty. After transferring a + * block, we add it to the history. While adding to history, if it is found + * that global checkpoint has been received for this thread, we move to END + * state.
+ * Previous states : SEND_SETTINGS, + * READ_FILE_CHUNKS, + * SEND_SIZE_CMD, + * PROCESS_ERR_CMD, + * SEND_BLOCKS + * Next states : SEND_BLOCKS(success), + * SEND_SIZE_CMD(total size not sent yet and file discovery + * has finished), + * END(global checkpoint received), + * CHECK_FOR_ABORT(sending the source failed), + * SEND_DONE_CMD(no more blocks left to transfer) + */ + SenderState sendBlocks(); + /** + * sends DONE cmd to the receiver + * Previous states : SEND_BLOCKS + * Next states : CHECK_FOR_ABORT(failure), + * READ_RECEIVER_CMD(success) + */ + SenderState sendDoneCmd(); + /** + * sends size cmd to the receiver + * Previous states : SEND_BLOCKS + * Next states : CHECK_FOR_ABORT(failure), + * SEND_BLOCKS(success) + */ + SenderState sendSizeCmd(); + /** + * checks to see if the receiver has sent ABORT or not + * Previous states : SEND_BLOCKS, + * SEND_DONE_CMD, + * SEND_SIZE_CMD, + * READ_FILE_CHUNKS + * Next states : CONNECT(no ABORT cmd, or an unexpected cmd was read), + * PROCESS_ABORT_CMD(read ABORT cmd) + */ + SenderState checkForAbort(); + /** + * reads previously transferred file chunks list. If it receives an ACK cmd, + * then it moves on. If wait cmd is received, it waits. Otherwise reads the + * file chunks and when done starts directory queue thread. + * Previous states : SEND_SETTINGS, + * Next states: READ_FILE_CHUNKS(if wait cmd is received), + * PROCESS_ABORT_CMD(if abort cmd is received), + * CHECK_FOR_ABORT(network error), + * END(protocol error), + * SEND_BLOCKS(success) + * + */ + SenderState readFileChunks(); + /** + * reads receiver cmd + * Previous states : SEND_DONE_CMD, + * READ_LOCAL_CHECKPOINT, + * PROCESS_WAIT_CMD + * Next states : PROCESS_DONE_CMD, + * PROCESS_WAIT_CMD, + * PROCESS_ERR_CMD, + * PROCESS_ABORT_CMD, + * READ_RECEIVER_CMD(spurious local checkpoint ignored), + * END(protocol error), + * CONNECT(failure) + */ + SenderState readReceiverCmd(); + /** + * handles DONE cmd + * Previous states : READ_RECEIVER_CMD + * Next states : END + */ + SenderState processDoneCmd(); + /** + * handles WAIT cmd + * Previous states : READ_RECEIVER_CMD + * Next states : READ_RECEIVER_CMD + */ + SenderState processWaitCmd(); + /** + * reads list of global checkpoints and returns unacked sources to queue.
+ * Previous states : READ_RECEIVER_CMD + * Next states : CONNECT(socket read failure), + * END(checkpoint list decode failure), + * SEND_BLOCKS(success) + */ + SenderState processErrCmd(); + /** + * processes ABORT cmd + * Previous states : CHECK_FOR_ABORT, + * READ_RECEIVER_CMD + * Next states : END + */ + SenderState processAbortCmd(); + + /** + * waits for all active threads to be aborted, checks to see if the abort was + * due to version mismatch. Also performs various sanity checks. + * Previous states : Almost all states, abort flag is checked between every + * state transition + * Next states : CONNECT(Abort was due to version mismatch), + * END(if abort was not due to version mismatch or some sanity + * check failed) + */ + SenderState processVersionMismatch(); + + /// mapping from sender states to state functions + static const StateFunction stateMap_[]; + + /// Port number of this sender thread + const int32_t port_; + + /// Negotiated protocol of the sender thread + int negotiatedProtocol_{-1}; + + /// Pointer to client socket to maintain connection to the receiver + std::unique_ptr socket_; + + /// Buffer used by the sender thread to read/write data + char buf_[Protocol::kMinBufLength]; + + /// whether total file size has been sent to the receiver + bool totalSizeSent_{false}; + + /// number of consecutive reconnects without any progress + int numReconnectWithoutProgress_{0}; + + /// Pointer to the directory queue of parent sender + DirectorySourceQueue *dirQueue_; + + /// Thread history controller shared across all threads + TransferHistoryController *transferHistoryController_; +}; +} +} diff --git a/ThreadTransferHistory.cpp b/ThreadTransferHistory.cpp new file mode 100644 index 00000000..a69f7fd4 --- /dev/null +++ b/ThreadTransferHistory.cpp @@ -0,0 +1,282 @@ +/** + * Copyright (c) 2014-present, Facebook, Inc. + * All rights reserved.
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ +#include "ThreadTransferHistory.h" +#include "Sender.h" +namespace facebook { +namespace wdt { +ThreadTransferHistory::ThreadTransferHistory(DirectorySourceQueue &queue, + TransferStats &threadStats, + int32_t port) + : queue_(queue), threadStats_(threadStats), port_(port) { + VLOG(1) << "Making thread history for port " << port_; +} + +std::string ThreadTransferHistory::getSourceId(int64_t index) { + std::lock_guard lock(mutex_); + std::string sourceId; + const int64_t historySize = history_.size(); + if (index >= 0 && index < historySize) { + sourceId = history_[index]->getIdentifier(); + } else { + LOG(WARNING) << "Trying to read out of bounds data " << index << " " + << history_.size(); + } + return sourceId; +} + +bool ThreadTransferHistory::addSource(std::unique_ptr &source) { + std::lock_guard lock(mutex_); + if (globalCheckpoint_) { + // already received an error for this thread + VLOG(1) << "adding source after global checkpoint is received. 
returning " + "the source to the queue"; + markSourceAsFailed(source, lastCheckpoint_.get()); + lastCheckpoint_.reset(); + queue_.returnToQueue(source); + return false; + } + history_.emplace_back(std::move(source)); + return true; +} + +ErrorCode ThreadTransferHistory::setLocalCheckpoint( + const Checkpoint &checkpoint) { + std::lock_guard lock(mutex_); + return setCheckpointAndReturnToQueue(checkpoint, false); +} + +ErrorCode ThreadTransferHistory::setGlobalCheckpoint( + const Checkpoint &checkpoint) { + std::unique_lock lock(mutex_); + ErrorCode status = setCheckpointAndReturnToQueue(checkpoint, true); + while (inUse_) { + // have to wait, error thread signalled through globalCheckpoint_ flag + LOG(INFO) << "Transfer history still in use, waiting, checkpoint " + << checkpoint; + conditionInUse_.wait(lock); + } + return status; +} +ErrorCode ThreadTransferHistory::setCheckpointAndReturnToQueue( + const Checkpoint &checkpoint, bool globalCheckpoint) { + const int64_t historySize = history_.size(); + int64_t numReceivedSources = checkpoint.numBlocks; + int64_t lastBlockReceivedBytes = checkpoint.lastBlockReceivedBytes; + if (numReceivedSources > historySize) { + LOG(ERROR) + << "checkpoint is greater than total number of sources transferred " + << history_.size() << " " << numReceivedSources; + return INVALID_CHECKPOINT; + } + ErrorCode errCode = validateCheckpoint(checkpoint, globalCheckpoint); + if (errCode == INVALID_CHECKPOINT) { + return INVALID_CHECKPOINT; + } + globalCheckpoint_ |= globalCheckpoint; + lastCheckpoint_ = folly::make_unique(checkpoint); + int64_t numFailedSources = historySize - numReceivedSources; + if (numFailedSources == 0 && lastBlockReceivedBytes > 0) { + if (!globalCheckpoint) { + // no block to apply checkpoint offset. This can happen if we receive same + // local checkpoint without adding anything to the history + LOG(WARNING) << "Local checkpoint has received bytes for last block, but " + "there are no unacked blocks in the history. 
Ignoring."; + } + } + numAcknowledged_ = numReceivedSources; + std::vector> sourcesToReturn; + for (int64_t i = 0; i < numFailedSources; i++) { + std::unique_ptr source = std::move(history_.back()); + history_.pop_back(); + const Checkpoint *checkpointPtr = + (i == numFailedSources - 1 ? &checkpoint : nullptr); + markSourceAsFailed(source, checkpointPtr); + sourcesToReturn.emplace_back(std::move(source)); + } + queue_.returnToQueue(sourcesToReturn); + LOG(INFO) << numFailedSources + << " number of sources returned to queue, checkpoint: " + << checkpoint; + return errCode; +} + +std::vector ThreadTransferHistory::popAckedSourceStats() { + std::unique_lock lock(mutex_); + const int64_t historySize = history_.size(); + WDT_CHECK(numAcknowledged_ == historySize); + // no locking needed, as this should be called after transfer has finished + std::vector sourceStats; + while (!history_.empty()) { + sourceStats.emplace_back(std::move(history_.back()->getTransferStats())); + history_.pop_back(); + } + return sourceStats; +} + +void ThreadTransferHistory::markAllAcknowledged() { + std::unique_lock lock(mutex_); + numAcknowledged_ = history_.size(); +} + +void ThreadTransferHistory::returnUnackedSourcesToQueue() { + std::unique_lock lock(mutex_); + Checkpoint checkpoint; + checkpoint.numBlocks = numAcknowledged_; + setCheckpointAndReturnToQueue(checkpoint, false); +} + +ErrorCode ThreadTransferHistory::validateCheckpoint( + const Checkpoint &checkpoint, bool globalCheckpoint) { + if (lastCheckpoint_ == nullptr) { + return OK; + } + if (checkpoint.numBlocks < lastCheckpoint_->numBlocks) { + LOG(ERROR) << "Current checkpoint must be higher than previous checkpoint, " + "Last checkpoint: " << *lastCheckpoint_ + << ", Current checkpoint: " << checkpoint; + return INVALID_CHECKPOINT; + } + if (checkpoint.numBlocks > lastCheckpoint_->numBlocks) { + return OK; + } + bool noProgress = false; + // numBlocks same + if (checkpoint.lastBlockSeqId == lastCheckpoint_->lastBlockSeqId && + 
checkpoint.lastBlockOffset == lastCheckpoint_->lastBlockOffset) { + // same block + if (checkpoint.lastBlockReceivedBytes != + lastCheckpoint_->lastBlockReceivedBytes) { + LOG(ERROR) << "Current checkpoint has different received bytes, but all " + "other fields are same, Last checkpoint " + << *lastCheckpoint_ << ", Current checkpoint: " << checkpoint; + return INVALID_CHECKPOINT; + } + noProgress = true; + } else { + // different block + WDT_CHECK(checkpoint.lastBlockReceivedBytes >= 0); + if (checkpoint.lastBlockReceivedBytes == 0) { + noProgress = true; + } + } + if (noProgress && !globalCheckpoint) { + // we can get same global checkpoint multiple times, so no need to check for + // progress + LOG(WARNING) << "No progress since last checkpoint, Last checkpoint: " + << *lastCheckpoint_ << ", Current checkpoint: " << checkpoint; + return NO_PROGRESS; + } + return OK; +} + +void ThreadTransferHistory::markSourceAsFailed( + std::unique_ptr &source, const Checkpoint *checkpoint) { + auto metadata = source->getMetaData(); + bool validCheckpoint = false; + if (checkpoint != nullptr) { + if (checkpoint->hasSeqId) { + if ((checkpoint->lastBlockSeqId == metadata.seqId) && + (checkpoint->lastBlockOffset == source->getOffset())) { + validCheckpoint = true; + } else { + LOG(WARNING) + << "Checkpoint block does not match history block. Checkpoint: " + << checkpoint->lastBlockSeqId << ", " << checkpoint->lastBlockOffset + << " History: " << metadata.seqId << ", " << source->getOffset(); + } + } else { + // Receiver at lower version! + // checkpoint does not have seq-id. We have to blindly trust + // lastBlockReceivedBytes. If we do not, transfer will fail because of + // number of bytes mismatch. Even if an error happens because of this, + // Receiver will fail. + validCheckpoint = true; + } + } + int64_t receivedBytes = + (validCheckpoint ? 
checkpoint->lastBlockReceivedBytes : 0); + TransferStats &sourceStats = source->getTransferStats(); + if (sourceStats.getErrorCode() != OK) { + // already marked as failed + sourceStats.addEffectiveBytes(0, receivedBytes); + threadStats_.addEffectiveBytes(0, receivedBytes); + } else { + auto dataBytes = source->getSize(); + auto headerBytes = sourceStats.getEffectiveHeaderBytes(); + int64_t wastedBytes = dataBytes - receivedBytes; + sourceStats.subtractEffectiveBytes(headerBytes, wastedBytes); + sourceStats.decrNumBlocks(); + sourceStats.setErrorCode(SOCKET_WRITE_ERROR); + sourceStats.incrFailedAttempts(); + + threadStats_.subtractEffectiveBytes(headerBytes, wastedBytes); + threadStats_.decrNumBlocks(); + threadStats_.incrFailedAttempts(); + } + source->advanceOffset(receivedBytes); +} + +bool ThreadTransferHistory::isGlobalCheckpointReceived() { + std::lock_guard lock(mutex_); + return globalCheckpoint_; +} + +void ThreadTransferHistory::markNotInUse() { + std::lock_guard lock(mutex_); + inUse_ = false; + conditionInUse_.notify_all(); +} + +TransferHistoryController::TransferHistoryController( + DirectorySourceQueue &dirQueue) + : dirQueue_(dirQueue) { +} + +ThreadTransferHistory &TransferHistoryController::getTransferHistory( + int32_t port) { + auto it = threadHistoriesMap_.find(port); + WDT_CHECK(it != threadHistoriesMap_.end()) << "port not found" << port; + return *(it->second.get()); +} + +void TransferHistoryController::addThreadHistory(int32_t port, + TransferStats &threadStats) { + VLOG(1) << "Adding the history for " << port; + threadHistoriesMap_.emplace(port, folly::make_unique( + dirQueue_, threadStats, port)); +} + +ErrorCode TransferHistoryController::handleVersionMismatch() { + for (auto &historyPair : threadHistoriesMap_) { + auto &history = historyPair.second; + if (history->getNumAcked() > 0) { + LOG(ERROR) << "Even though the transfer aborted due to VERSION_MISMATCH, " + "some blocks got acked by the receiver, port " + << historyPair.first << " 
numAcked " << history->getNumAcked(); + return ERROR; + } + history->returnUnackedSourcesToQueue(); + } + return OK; +} + +void TransferHistoryController::handleGlobalCheckpoint( + const Checkpoint &checkpoint) { + auto errPort = checkpoint.port; + auto it = threadHistoriesMap_.find(errPort); + if (it == threadHistoriesMap_.end()) { + LOG(ERROR) << "Invalid checkpoint " << checkpoint + << ". No sender thread running on port " << errPort; + return; + } + VLOG(1) << "received global checkpoint " << checkpoint; + it->second->setGlobalCheckpoint(checkpoint); +} +} +} diff --git a/ThreadTransferHistory.h b/ThreadTransferHistory.h new file mode 100644 index 00000000..dfc26559 --- /dev/null +++ b/ThreadTransferHistory.h @@ -0,0 +1,168 @@ +/** + * Copyright (c) 2014-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ +#pragma once +#include "DirectorySourceQueue.h" +#include "Reporting.h" +#include "Protocol.h" +#include +#include +namespace facebook { +namespace wdt { +/// Transfer history of a sender thread +class ThreadTransferHistory { + public: + /** + * @param queue directory queue + * @param threadStats stat object of the thread + */ + ThreadTransferHistory(DirectorySourceQueue &queue, TransferStats &threadStats, + int32_t port); + + /** + * @param index of the source + * @return if index is in bounds, returns the identifier for the + * source, else returns empty string + */ + std::string getSourceId(int64_t index); + + /** + * Adds the source to the history. If global checkpoint has already been + * received, then the source is returned to the queue. 
+ * + * @param source source to be added to history + * @return true if added to history, false if not added due to a + * global checkpoint + */ + bool addSource(std::unique_ptr &source); + + /** + * Sets local checkpoint. Also, returns unacked sources to queue + * @param checkpoint checkpoint received + * @return status of the operation: OK, NO_PROGRESS or + * INVALID_CHECKPOINT + */ + ErrorCode setLocalCheckpoint(const Checkpoint &checkpoint); + + /** + * @return stats for acked sources, must be called after all the + * unacked sources are returned to the queue + */ + std::vector popAckedSourceStats(); + + /// marks all the sources as acked + void markAllAcknowledged(); + + /** + * returns all unacked sources to the queue + */ + void returnUnackedSourcesToQueue(); + + /** + * @return number of sources acked by the receiver + */ + int64_t getNumAcked() const { + return numAcknowledged_; + } + + /// @return whether global checkpoint has been received or not + bool isGlobalCheckpointReceived(); + + /// Clears the inUse_ flag and notifies other waiting threads + void markNotInUse(); + + /// Copy constructor deleted + ThreadTransferHistory(const ThreadTransferHistory &that) = delete; + + /// Copy assignment operator deleted + ThreadTransferHistory &operator=(const ThreadTransferHistory &that) = delete; + + private: + friend class TransferHistoryController; + /// Validates a checkpoint and returns the status + ErrorCode validateCheckpoint(const Checkpoint &checkpoint, + bool globalCheckpoint); + /** + * Sets global checkpoint. If the history is still in use, waits for the + * error thread to add current block to the history. + * + * @param checkpoint checkpoint received + * + * @return status of the operation + */ + ErrorCode setGlobalCheckpoint(const Checkpoint &checkpoint); + + void markSourceAsFailed(std::unique_ptr &source, + const Checkpoint *checkpoint); + /** + * Sets checkpoint.
Also, returns unacked sources to queue + * + * @param checkpoint checkpoint received + * @param globalCheckpoint true for a global checkpoint, false for local + * @return status of the operation + */ + ErrorCode setCheckpointAndReturnToQueue(const Checkpoint &checkpoint, + bool globalCheckpoint); + + /// reference to global queue + DirectorySourceQueue &queue_; + /// reference to thread stats + TransferStats &threadStats_; + /// history of the thread + std::vector> history_; + /// whether a global error checkpoint has been received or not + bool globalCheckpoint_{false}; + /// number of sources acked by the receiver thread + int64_t numAcknowledged_{0}; + /// last received checkpoint + std::unique_ptr lastCheckpoint_{nullptr}; + /// Port associated with the history + int32_t port_; + /// whether the owner thread is still using this + bool inUse_{true}; + /// Mutex used by history internally for synchronization + std::mutex mutex_; + /// Condition variable to signify the history being in use + std::condition_variable conditionInUse_; +}; + +/// Controller for history across the sender threads +class TransferHistoryController { + public: + /** + * Constructor for the history controller + * @param dirQueue Directory queue used by the sender + */ + explicit TransferHistoryController(DirectorySourceQueue &dirQueue); + + /** + * Add transfer history for a thread + * @param port Port being used by the sender thread + * @param threadStats Thread stats for the sender thread + */ + void addThreadHistory(int32_t port, TransferStats &threadStats); + + /// Get transfer history for a thread using its port number + ThreadTransferHistory &getTransferHistory(int32_t port); + + /// Handle version mismatch across the histories of all threads + ErrorCode handleVersionMismatch(); + + /// Handle checkpoint that was sent by the receiver + void handleGlobalCheckpoint(const Checkpoint &checkpoint); + + private: + /// Reference to the directory queue being used by the sender + DirectorySourceQueue
&dirQueue_; + + /// Map of port (used by sender threads) and transfer history + std::unordered_map> + threadHistoriesMap_; +}; +} +} diff --git a/ThreadsController.cpp b/ThreadsController.cpp new file mode 100644 index 00000000..cca8b81c --- /dev/null +++ b/ThreadsController.cpp @@ -0,0 +1,241 @@ +#include "ThreadsController.h" +#include "WdtOptions.h" +using namespace std; +namespace facebook { +namespace wdt { +void ConditionGuardImpl::wait(int timeoutMillis) { + if (timeoutMillis <= 0) { + cv_.wait(*lock_); + return; + } + auto waitingTime = chrono::milliseconds(timeoutMillis); + cv_.wait_for(*lock_, waitingTime); +} + +void ConditionGuardImpl::notifyAll() { + cv_.notify_all(); +} + +void ConditionGuardImpl::notifyOne() { + cv_.notify_one(); +} + +ConditionGuardImpl::~ConditionGuardImpl() { + if (lock_ != nullptr) { + delete lock_; + } + cv_.notify_one(); +} + +ConditionGuardImpl::ConditionGuardImpl(mutex &guardMutex, + condition_variable &cv) + : cv_(cv) { + lock_ = new unique_lock(guardMutex); +} + +ConditionGuardImpl::ConditionGuardImpl(ConditionGuardImpl &&that) noexcept + : cv_(that.cv_) { + swap(lock_, that.lock_); +} + +ConditionGuardImpl ConditionGuard::acquire() { + return ConditionGuardImpl(mutex_, cv_); +} + +FunnelStatus Funnel::getStatus() { + unique_lock lock(mutex_); + if (status_ == FUNNEL_START) { + status_ = FUNNEL_PROGRESS; + return FUNNEL_START; + } + return status_; +} + +void Funnel::wait() { + unique_lock lock(mutex_); + if (status_ != FUNNEL_PROGRESS) { + return; + } + cv_.wait(lock); +} + +void Funnel::wait(int32_t waitingTime) { + auto waitMillis = chrono::milliseconds(waitingTime); + unique_lock lock(mutex_); + if (status_ != FUNNEL_PROGRESS) { + return; + } + cv_.wait_for(lock, waitMillis); +} + +void Funnel::notifySuccess() { + unique_lock lock(mutex_); + status_ = FUNNEL_END; + cv_.notify_all(); +} + +void Funnel::notifyFail() { + unique_lock lock(mutex_); + status_ = FUNNEL_START; + cv_.notify_one(); +} + +bool 
Barrier::checkForFinish() { + // lock should be held while calling this method + WDT_CHECK_GE(numThreads_, numHits_); + if (numHits_ == numThreads_) { + isComplete_ = true; + cv_.notify_all(); + } + return isComplete_; +} + +void Barrier::execute() { + unique_lock lock(mutex_); + WDT_CHECK(!isComplete_) << "Hitting the barrier after completion"; + ++numHits_; + if (checkForFinish()) { + return; + } + while (!isComplete_) { + cv_.wait(lock); + } +} + +void Barrier::deRegister() { + unique_lock lock(mutex_); + if (isComplete_) { + return; + } + --numThreads_; + checkForFinish(); +} + +ThreadsController::ThreadsController(int totalThreads) { + totalThreads_ = totalThreads; + for (int threadNum = 0; threadNum < totalThreads; ++threadNum) { + threadStateMap_[threadNum] = INIT; + } + execAtStart_.reset(new ExecuteOnceFunc(totalThreads_, true)); + execAtEnd_.reset(new ExecuteOnceFunc(totalThreads_, false)); +} + +void ThreadsController::registerThread(int threadIndex) { + GuardLock lock(controllerMutex_); + auto it = threadStateMap_.find(threadIndex); + WDT_CHECK(it != threadStateMap_.end()); + threadStateMap_[threadIndex] = RUNNING; +} + +void ThreadsController::deRegisterThread(int threadIndex) { + GuardLock lock(controllerMutex_); + auto it = threadStateMap_.find(threadIndex); + WDT_CHECK(it != threadStateMap_.end()); + threadStateMap_[threadIndex] = FINISHED; + // Notify all the barriers + for (auto barrier : barriers_) { + WDT_CHECK(barrier != nullptr); + barrier->deRegister(); + } +} + +ThreadStatus ThreadsController::getState(int threadIndex) { + GuardLock lock(controllerMutex_); + auto it = threadStateMap_.find(threadIndex); + WDT_CHECK(it != threadStateMap_.end()); + return it->second; +} + +void ThreadsController::markState(int threadIndex, ThreadStatus threadState) { + GuardLock lock(controllerMutex_); + threadStateMap_[threadIndex] = threadState; +} + +unordered_map ThreadsController::getThreadStates() const { + GuardLock lock(controllerMutex_); + return 
threadStateMap_; +} + +int ThreadsController::getTotalThreads() { + return totalThreads_; +} + +bool ThreadsController::hasThreads(int threadIndex, ThreadStatus threadState) { + GuardLock lock(controllerMutex_); + for (auto &threadPair : threadStateMap_) { + if (threadPair.first == threadIndex) { + continue; + } + if (threadPair.second == threadState) { + return true; + } + } + return false; +} + +shared_ptr ThreadsController::getCondition( + const uint64_t conditionIndex) { + bool isExists = (conditionGuards_.size() > conditionIndex) && + (conditionGuards_[conditionIndex] != nullptr); + WDT_CHECK(isExists) + << "Requesting for a condition wrapper that doesn't exist." + << " Request Index : " << conditionIndex + << ", num condition wrappers : " << conditionGuards_.size(); + return conditionGuards_[conditionIndex]; +} + +shared_ptr ThreadsController::getBarrier(const uint64_t barrierIndex) { + bool isExists = + (barriers_.size() > barrierIndex) && (barriers_[barrierIndex] != nullptr); + WDT_CHECK(isExists) + << "Requesting for a barrier that doesn't exist. Request index : " + << barrierIndex << ", num barriers " << barriers_.size(); + return barriers_[barrierIndex]; +} + +shared_ptr ThreadsController::getFunnel(const uint64_t funnelIndex) { + bool isExists = (funnelExecutors_.size() > funnelIndex) && + (funnelExecutors_[funnelIndex] != nullptr); + WDT_CHECK(isExists) + << "Requesting for a funnel that doesn't exist. 
Request index : " + << funnelIndex << ", num funnels " << funnelExecutors_.size(); + return funnelExecutors_[funnelIndex]; +} + +void ThreadsController::reset() { + // Only used in the case of long running mode + setNumBarriers(barriers_.size()); + setNumConditions(conditionGuards_.size()); + setNumFunnels(funnelExecutors_.size()); + execAtStart_->reset(); + execAtEnd_->reset(); + GuardLock lock(controllerMutex_); + // Restore threads back to initial state + for (auto &threadPair : threadStateMap_) { + threadPair.second = RUNNING; + } +} + +void ThreadsController::setNumBarriers(int numBarriers) { + // Meant to be called outside of threads + barriers_.clear(); + for (int i = 0; i < numBarriers; i++) { + barriers_.push_back(make_shared(getTotalThreads())); + } +} + +void ThreadsController::setNumConditions(int numConditions) { + conditionGuards_.clear(); + for (int i = 0; i < numConditions; i++) { + conditionGuards_.push_back(make_shared()); + } +} + +void ThreadsController::setNumFunnels(int numFunnels) { + funnelExecutors_.clear(); + for (int i = 0; i < numFunnels; i++) { + funnelExecutors_.push_back(make_shared()); + } +} +} +} diff --git a/ThreadsController.h b/ThreadsController.h new file mode 100644 index 00000000..927dd8dc --- /dev/null +++ b/ThreadsController.h @@ -0,0 +1,371 @@ +/** + * Copyright (c) 2014-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ +#pragma once +#include +#include +#include +#include +#include +#include "WdtThread.h" +#include "ErrorCodes.h" + +namespace facebook { +namespace wdt { + +class WdtThread; +/** + * Thread states that represent what kind of functionality + * are they executing on a higher level. 
+ * INIT - State before running at the time of construction + * RUNNING - Thread is running without any errors + * WAITING - Thread is not doing anything meaningful but + * rather waiting on other threads for something + * FINISHED - Threads have finished with/without error + */ +enum ThreadStatus { INIT, RUNNING, WAITING, FINISHED }; + +/** + * Primitive that takes a function and executes + * it only once either on the first thread + * or the last thread entrance + */ +class ExecuteOnceFunc { + public: + /// Constructor for the once only executor + ExecuteOnceFunc(int numThreads, bool execFirst) { + execFirst_ = execFirst; + numThreads_ = numThreads; + } + + /// Deleted copy constructor + ExecuteOnceFunc(const ExecuteOnceFunc &that) = delete; + + /// Deleted assignment operator + ExecuteOnceFunc &operator=(const ExecuteOnceFunc &that) = delete; + + /// Implements the main functionality of the executor + template + void execute(Func &&execFunc) { + std::unique_lock lock(mutex_); + ++numHits_; + WDT_CHECK(numHits_ <= numThreads_); + int64_t numExpected = (execFirst_) ? 1 : numThreads_; + if (numHits_ == numExpected) { + execFunc(); + } + } + + /// Reset the number of hits + void reset() { + numHits_ = 0; + } + + private: + /// Mutex for thread synchronization + std::mutex mutex_; + /// Number of times execute has been called + int numHits_{0}; + /// Function can be executed on the first + /// thread or the last thread + bool execFirst_{true}; + /// Number of total threads + int numThreads_; +}; + +/** + * A scoped locking primitive. When you get this object + * it means that you already have the lock.
You can also + * wait, notify etc using this primitive + */ +class ConditionGuardImpl { + public: + /// Release the lock and wait for the timeout + /// After the wait is over, lock is reacquired + void wait(int timeoutMillis = -1); + /// Notify all the threads waiting on the lock + void notifyAll(); + /// Notify one thread waiting on the lock + void notifyOne(); + /// Delete the copy constructor + ConditionGuardImpl(const ConditionGuardImpl &that) = delete; + /// Delete the copy assignment operator + ConditionGuardImpl &operator=(const ConditionGuardImpl &that) = delete; + /// Move constructor for the guard + ConditionGuardImpl(ConditionGuardImpl &&that) noexcept; + /// Move assignment operator deleted + ConditionGuardImpl &operator=(ConditionGuardImpl &&that) = delete; + /// Destructor that releases the lock, you would explicitly + /// need to notify any other threads waiting in the wait() + ~ConditionGuardImpl(); + + protected: + friend class ConditionGuard; + /// Constructor that takes the shared mutex and condition + /// variable + ConditionGuardImpl(std::mutex &mutex, std::condition_variable &cv); + /// Instance of lock is made on construction with the specified mutex + std::unique_lock *lock_{nullptr}; + /// Shared condition variable + std::condition_variable &cv_; +}; + +/** + * Class for simplifying the primitive to take a lock + * in conjunction with the ability to do things + * on a condition variable based on the lock. 
+ * Use the condition guard like this + * ConditionGuard condition; + * auto guard = condition.acquire(); + * guard.wait(); + */ +class ConditionGuard { + public: + /// Caller has to call acquire before doing anything + ConditionGuardImpl acquire(); + + /// Default constructor + ConditionGuard() { + } + + /// Deleted copy constructor + ConditionGuard(const ConditionGuard &that) = delete; + + /// Deleted assignment operator + ConditionGuard &operator=(const ConditionGuard &that) = delete; + + private: + /// Mutex for the condition variable + std::mutex mutex_; + /// std condition variable to support the functionality + std::condition_variable cv_; +}; + +/** + * A barrier primitive. When called for executing + * will block the threads till all the threads registered + * call execute() + */ +class Barrier { + public: + /// Deleted copy constructor + Barrier(const Barrier &that) = delete; + + /// Deleted assignment operator + Barrier &operator=(const Barrier &that) = delete; + + /// Constructor which takes total number of threads + /// to be hit in order for the barrier to clear + explicit Barrier(int numThreads) { + numThreads_ = numThreads; + VLOG(1) << "making barrier with " << numThreads; + } + + /// Executes the main functionality of the barrier + void execute(); + + /** + * Thread controller should call this method when one thread + * has been finished, since that thread will no longer be + * participating in the barrier + */ + void deRegister(); + + private: + /// Checks for finish, need to hold a lock to call this method + bool checkForFinish(); + /// Condition variable that threads wait on + std::condition_variable cv_; + + /// Number of threads entered the execute + int64_t numHits_{0}; + + /// Total number of threads that are supposed + /// to hit the barrier + int numThreads_{0}; + + /// Thread synchronization mutex + std::mutex mutex_; + + /// Represents the completion of barrier + bool isComplete_{false}; +}; + +/** + * Different stages of the simple 
funnel + * FUNNEL_START the state of funnel at the beginning + * FUNNEL_PROGRESS is set by the first thread to enter the funnel + * and it means that funnel functionality is in progress + * FUNNEL_END means that funnel functionality has been executed + */ +enum FunnelStatus { FUNNEL_START, FUNNEL_PROGRESS, FUNNEL_END }; + +/** + * Primitive that makes the threads execute in a funnel + * manner. Only one thread gets to execute the main functionality + * while other entering threads wait (while executing a function) + */ +class Funnel { + public: + /// Deleted copy constructor + Funnel(const Funnel &that) = delete; + + /// Default constructor for funnel + Funnel() { + status_ = FUNNEL_START; + } + + /// Deleted assignment operator + Funnel &operator=(const Funnel &that) = delete; + + /** + * Get the current status of funnel. + * If the status is FUNNEL_START it gets set + * to FUNNEL_PROGRESS else it is just a get + */ + FunnelStatus getStatus(); + + /// Threads in progress can wait indefinitely + void wait(); + + /// Threads that get status as progress execute this function + void wait(int32_t waitingTime); + + /** + * The first thread that was able to start the funnel + * calls this method on successful execution + */ + void notifySuccess(); + + /// The first thread that was able to start the funnel + /// calls this method on failure in execution + void notifyFail(); + + private: + /// Status of the funnel + FunnelStatus status_; + /// Mutex for the simple funnel executor + std::mutex mutex_; + /// Condition variable on which progressing threads wait + std::condition_variable cv_; +}; + +/** + * Controller class responsible for the receiver + * and sender threads. 
Manages the states of threads and + * session information + */ +class ThreadsController { + public: + /// Constructor that takes in the total number of threads + /// to be run + explicit ThreadsController(int totalThreads); + + /// Make threads of a type Sender/Receiver + template + std::vector> makeThreads( + WdtBaseType *wdtParent, int numThreads, + const std::vector &ports) { + std::vector> threads; + for (int threadIndex = 0; threadIndex < numThreads; ++threadIndex) { + threads.emplace_back(folly::make_unique( + wdtParent, threadIndex, ports[threadIndex], this)); + } + return threads; + } + /// Mark the state of a thread + void markState(int threadIndex, ThreadStatus state); + + /// Get the status of the thread by index + ThreadStatus getState(int threadIndex); + + /// Execute a function func once, by the first thread + template + void executeAtStart(FunctionType &&fn) const { + execAtStart_->execute(fn); + } + + /// Execute a function once by the last thread + template + void executeAtEnd(FunctionType &&fn) const { + execAtEnd_->execute(fn); + } + + /// Returns a funnel executor shared between the threads + /// WDT_CHECK fails if no funnel exists at the given index + std::shared_ptr getFunnel(const uint64_t funnelIndex); + + /// Returns a barrier shared between the threads + /// WDT_CHECK fails if no barrier exists at the given index + std::shared_ptr getBarrier(const uint64_t barrierIndex); + + /// Get the condition variable wrapper + std::shared_ptr getCondition(const uint64_t conditionIndex); + + /* + * Returns back states of all the threads + */ + std::unordered_map getThreadStates() const; + + /// Register a thread, a thread registers with the state RUNNING + void registerThread(int threadIndex); + + /// De-register a thread, marks it ended + void deRegisterThread(int threadIndex); + + /// Returns true if any thread apart from the calling is in the state + bool hasThreads(int threadIndex, ThreadStatus threadState); + + /// Get the total number of threads managed by the controller +
int getTotalThreads(); + + /// Reset the thread controller so that same instance can be used again + void reset(); + + /// Set the total number of barriers + void setNumBarriers(int numBarriers); + + /// Set the number of condition wrappers + void setNumConditions(int numConditions); + + /// Set total number of funnel executors + void setNumFunnels(int numFunnels); + + /// Destructor for the threads controller + ~ThreadsController() { + } + + private: + /// Total number of threads managed by the thread controller + int totalThreads_; + + typedef std::unique_lock GuardLock; + + /// Mutex used in all of the thread controller methods + mutable std::mutex controllerMutex_; + + /// States of the threads + std::unordered_map threadStateMap_; + + /// Executor to execute things at the start of transfer + std::unique_ptr execAtStart_; + + /// Executor to execute things at the end of transfer + std::unique_ptr execAtEnd_; + + /// Vector of funnel executors, read/modified by get/set funnel methods + std::vector> funnelExecutors_; + + /// Vector of condition wrappers, read/modified by get/set condition methods + std::vector> conditionGuards_; + + /// Vector of barriers, can be read/modified by get/set barrier methods + std::vector> barriers_; +}; +} +} diff --git a/ThreadsControllerTest.cpp b/ThreadsControllerTest.cpp new file mode 100644 index 00000000..c9ad55e4 --- /dev/null +++ b/ThreadsControllerTest.cpp @@ -0,0 +1,122 @@ +/** + * Copyright (c) 2014-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. 
+ */ +#include "ThreadsController.h" +#include +#include +#include +#include +#include +#include +#include +using namespace std; +namespace facebook { +namespace wdt { +class ThreadUtil { + public: + template + void addThread(Fn&& func) { + threads_.emplace_back(func); + } + void joinThreads() { + for (auto& t : threads_) { + t.join(); + } + threads_.clear(); + } + + void threadSleep(int64_t millis) { + auto waitingTime = chrono::milliseconds(millis); + LOG(INFO) << "Sleeping for " << millis << " ms"; + unique_lock lock(mutex_); + cv_.wait_for(lock, waitingTime); + } + + void notifyThreads() { + cv_.notify_all(); + } + + ~ThreadUtil() { + joinThreads(); + } + + private: + vector threads_; + mutex mutex_; + condition_variable cv_; +}; + +TEST(ThreadsController, Barrier) { + int numThreads = 15; + { + Barrier barrier(numThreads); + ThreadUtil threadUtil; + for (int i = 0; i < numThreads; i++) { + threadUtil.addThread([&]() { barrier.execute(); }); + } + } + { + Barrier barrier(numThreads); + srand(time(nullptr)); + ThreadUtil threadUtil; + for (int i = 0; i < numThreads; i++) { + threadUtil.addThread([&barrier, &threadUtil, i]() { + if (i % 2 == 0) { + int seconds = folly::Random::rand32() % 5; + threadUtil.threadSleep(seconds * 100); + barrier.deRegister(); + return; + } + barrier.execute(); + }); + } + } + { + Barrier barrier(numThreads); + ThreadUtil threadUtil; + for (int i = 0; i < numThreads; i++) { + threadUtil.addThread([&barrier, &threadUtil, i]() { + switch (i) { + case 1: + threadUtil.threadSleep(1 * 100); + barrier.execute(); + return; + case 2: + threadUtil.threadSleep(2 * 100); + barrier.deRegister(); + return; + case 3: + threadUtil.threadSleep(5 * 100); + barrier.execute(); + return; + default: + barrier.execute(); + return; + }; + }); + } + } +} + +TEST(ThreadsController, ExecutOnceFunc) { + int numThreads = 8; + ExecuteOnceFunc execAtStart(numThreads, true); + ExecuteOnceFunc execAtEnd(numThreads, false); + ThreadUtil threadUtil; + int result = 0; + 
for (int i = 0; i < numThreads; i++) { + threadUtil.addThread([&]() { + execAtStart.execute([&]() { ++result; }); + execAtEnd.execute([&]() { ++result; }); + }); + } + threadUtil.joinThreads(); + EXPECT_EQ(result, 2); +} +} +} diff --git a/TransferLogManager.cpp b/TransferLogManager.cpp index f61dd068..dec58236 100644 --- a/TransferLogManager.cpp +++ b/TransferLogManager.cpp @@ -256,7 +256,6 @@ void TransferLogManager::openLog() { std::move(std::thread(&TransferLogManager::writeEntriesToDisk, this)); LOG(INFO) << "Log writer thread started"; } - return; } void TransferLogManager::closeLog() { @@ -326,7 +325,8 @@ bool TransferLogManager::verifySenderIp(const std::string &curSenderIp) { bool verifySuccessful = true; if (!options.disable_sender_verification_during_resumption) { if (senderIp_.empty()) { - LOG(INFO) << "Sender-ip empty, not verifying sender-ip"; + LOG(INFO) << "Sender-ip empty, not verifying sender-ip, new-ip: " + << curSenderIp; } else if (senderIp_ != curSenderIp) { LOG(ERROR) << "Current sender ip does not match ip in the " "transfer log " << curSenderIp << " " << senderIp_ @@ -579,6 +579,7 @@ ErrorCode LogParser::processHeaderEntry(char *buf, int size, int64_t logConfig; if (!encoderDecoder_.decodeLogHeader(buf, size, timestamp, logVersion, logRecoveryId, senderIp, logConfig)) { + LOG(ERROR) << "Couldn't decode the log header"; return INVALID_LOG; } if (logVersion != TransferLogManager::LOG_VERSION) { @@ -894,6 +895,7 @@ ErrorCode LogParser::parseLog(int fd, std::string &senderIp, return INVALID_LOG; } if (status == INVALID_LOG) { + LOG(ERROR) << "Invalid transfer log"; return status; } if (status == INCONSISTENT_DIRECTORY) { diff --git a/WdtBase.cpp b/WdtBase.cpp index 6a953502..6e595de9 100644 --- a/WdtBase.cpp +++ b/WdtBase.cpp @@ -351,6 +351,7 @@ WdtBase::WdtBase() : abortCheckerCallback_(this) { WdtBase::~WdtBase() { abortChecker_ = nullptr; + delete threadsController_; } std::vector WdtBase::genPortsVector(int32_t startPort, @@ -407,6 +408,10 
@@ void WdtBase::setThrottler(std::shared_ptr throttler) { throttler_ = throttler; } +std::shared_ptr WdtBase::getThrottler() const { + return throttler_; +} + void WdtBase::setTransferId(const std::string& transferId) { transferId_ = transferId; LOG(INFO) << "Setting transfer id " << transferId_; @@ -423,6 +428,10 @@ void WdtBase::setProtocolVersion(int64_t protocol) { LOG(INFO) << "using wdt protocol version " << protocolVersion_; } +int WdtBase::getProtocolVersion() const { + return protocolVersion_; +} + std::string WdtBase::getTransferId() { return transferId_; } diff --git a/WdtBase.h b/WdtBase.h index a1916c19..3b114c9b 100644 --- a/WdtBase.h +++ b/WdtBase.h @@ -15,6 +15,7 @@ #include "Reporting.h" #include "Throttler.h" #include "Protocol.h" +#include "WdtThread.h" #include #include #include @@ -227,6 +228,9 @@ class WdtBase { /// Sets the protocol version for the transfer void setProtocolVersion(int64_t protocolVersion); + /// Get the protocol version of the transfer + int getProtocolVersion() const; + /// Get the transfer id of the object std::string getTransferId(); @@ -243,19 +247,8 @@ class WdtBase { /// Utility to generate a random transfer id static std::string generateTransferId(); - protected: - /// Global throttler across all threads - std::shared_ptr throttler_; - - /// Holds the instance of the progress reporter default or customized - std::unique_ptr progressReporter_; - - /// Unique id for the transfer - std::string transferId_; - - /// protocol version to use, this is determined by negotiating protocol - /// version with the other side - int protocolVersion_{Protocol::protocol_version}; + /// Get the throttler + std::shared_ptr getThrottler() const; /// abort checker class passed to socket functions class AbortChecker : public IAbortChecker { @@ -271,9 +264,29 @@ class WdtBase { WdtBase* wdtBase_; }; + protected: + /// Ports that the sender/receiver is running on + std::vector ports_; + + /// Global throttler across all threads + 
std::shared_ptr throttler_; + + /// Holds the instance of the progress reporter default or customized + std::unique_ptr progressReporter_; + + /// Unique id for the transfer + std::string transferId_; + + /// protocol version to use, this is determined by negotiating protocol + /// version with the other side + int protocolVersion_{Protocol::protocol_version}; + /// abort checker passed to socket functions AbortChecker abortCheckerCallback_; + /// Controller for wdt threads shared between base and threads + ThreadsController* threadsController_{nullptr}; + private: folly::RWSpinLock abortCodeLock_; /// Internal and default abort code diff --git a/WdtConfig.h b/WdtConfig.h index 281ef12a..c0610382 100644 --- a/WdtConfig.h +++ b/WdtConfig.h @@ -7,10 +7,10 @@ #include #define WDT_VERSION_MAJOR 1 -#define WDT_VERSION_MINOR 21 -#define WDT_VERSION_BUILD 1510120 +#define WDT_VERSION_MINOR 22 +#define WDT_VERSION_BUILD 1510210 // Add -fbcode to version str -#define WDT_VERSION_STR "1.21.1510120-fbcode" +#define WDT_VERSION_STR "1.22.1510210-fbcode" // Tie minor and proto version #define WDT_PROTOCOL_VERSION WDT_VERSION_MINOR diff --git a/WdtThread.cpp b/WdtThread.cpp new file mode 100644 index 00000000..c1f0a523 --- /dev/null +++ b/WdtThread.cpp @@ -0,0 +1,48 @@ +#include "WdtThread.h" +using namespace std; +namespace facebook { +namespace wdt { +TransferStats WdtThread::moveStats() { + return std::move(threadStats_); +} + +const PerfStatReport& WdtThread::getPerfReport() const { + return perfReport_; +} + +const TransferStats& WdtThread::getTransferStats() const { + return threadStats_; +} + +void WdtThread::startThread() { + if (threadPtr_) { + WDT_CHECK(false) << "There is a already a thread running " << threadIndex_ + << " " << getPort(); + } + auto state = controller_->getState(threadIndex_); + // Check the state should be running here + WDT_CHECK_EQ(state, RUNNING); + threadPtr_.reset(new std::thread(&WdtThread::start, this)); +} + +ErrorCode WdtThread::finish() { + 
if (!threadPtr_) { + LOG(ERROR) << "Finish called on an instance while no thread has been " + << " created to do any work"; + return ERROR; + } + threadPtr_->join(); + threadPtr_.reset(); + return OK; +} + +WdtThread::~WdtThread() { + if (threadPtr_) { + LOG(INFO) << threadIndex_ + << " has an alive thread while the instance is being " + << "destructed"; + finish(); + } +} +} +} diff --git a/WdtThread.h b/WdtThread.h new file mode 100644 index 00000000..4b1b3774 --- /dev/null +++ b/WdtThread.h @@ -0,0 +1,79 @@ +/** + * Copyright (c) 2014-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ +#pragma once +#include "Reporting.h" +#include "ErrorCodes.h" +#include "ThreadsController.h" +#include +#include +namespace facebook { +namespace wdt { +class ThreadsController; +class WdtThread { + public: + /// Constructor for wdt thread + WdtThread(int threadIndex, int protocolVersion, ThreadsController *controller) + : threadIndex_(threadIndex), threadProtocolVersion_(protocolVersion) { + controller_ = controller; + } + /// Starts a thread which runs the wdt functionality + void startThread(); + + /// Get the perf stats of the transfer for this thread + const PerfStatReport &getPerfReport() const; + + /// Initializes the wdt thread before starting + virtual ErrorCode init() = 0; + + /// Conclude the thread transfer + virtual ErrorCode finish(); + + /// Moves the local stats into a new instance + TransferStats moveStats(); + + /// Get the transfer stats recorded by this thread + const TransferStats &getTransferStats() const; + + /// Reset the wdt thread + virtual void reset() = 0; + + /// Get the port this thread is running on + virtual int getPort() const = 0; + + // TODO remove this function + virtual int getNegotiatedProtocol() const { + 
return threadProtocolVersion_; + } + + virtual ~WdtThread(); + + protected: + /// The main entry point of the thread + virtual void start() = 0; + + /// Index of this thread with respect to other threads + const int threadIndex_; + + /// Copy of the protocol version that might be changed + int threadProtocolVersion_; + + /// Transfer stats for this thread + TransferStats threadStats_{true}; + + /// Perf stats report for this thread + PerfStatReport perfReport_; + + /// Thread controller shared by all the sender/receiver threads + ThreadsController *controller_{nullptr}; + + /// Pointer to the std::thread executing the transfer + std::unique_ptr threadPtr_{nullptr}; +}; +} +} diff --git a/wdt_global_checkpoint_test.sh b/wdt_global_checkpoint_test.sh new file mode 100644 index 00000000..36b191f8 --- /dev/null +++ b/wdt_global_checkpoint_test.sh @@ -0,0 +1,38 @@ +#! /bin/bash + +set -o pipefail + +source `dirname "$0"`/common_functions.sh + +BASEDIR=/dev/shm/wdtTest_$USER +mkdir -p "$BASEDIR" +DIR=`mktemp -d --tmpdir=$BASEDIR` +echo "Testing in $DIR" + +mkdir "$DIR/src" +# create a big file +dd if=/dev/zero of="$DIR/src/file" bs=536870912 count=1 + +TEST_COUNT=0 +# start the server +_bin/wdt/wdt -skip_writes -num_ports=2 -transfer_id=wdt \ +-connect_timeout_millis 100 -read_timeout_millis=200 > "$DIR/server${TEST_COUNT}.log" 2>&1 & +pidofreceiver=$! + +# block 22356 for small duration so that file is transferred through 22357 +blockDportByDropping 22356 +# start client +_bin/wdt/wdt -destination localhost -directory "$DIR/src" -num_ports=2 \ +-block_size_mbytes=-1 -avg_mbytes_per_sec=100 -transfer_id=wdt 2>&1 | \ +tee "$DIR/client${TEST_COUNT}.log" & +pidofsender=$! +sleep 0.1 +undoLastIpTableChange +sleep 5 +# block 22357 in the middle +blockDportByDropping 22357 +waitForTransferEnd + +echo "Test passed, deleting directory $DIR" +rm -rf "$DIR" +wdtExit 0