Skip to content

Commit

Permalink
[Optim] - Avoid copying arguments in parallelTransform of threadpool
Browse files Browse the repository at this point in the history
  • Loading branch information
sjanel committed Mar 10, 2024
1 parent b478810 commit 1d1bddb
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 47 deletions.
3 changes: 2 additions & 1 deletion src/api-objects/include/withdrawinfo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ class DeliveredWithdrawInfo {
template <>
struct fmt::formatter<cct::DeliveredWithdrawInfo> {
constexpr auto parse(format_parse_context &ctx) -> decltype(ctx.begin()) {
auto it = ctx.begin(), end = ctx.end();
const auto it = ctx.begin();
const auto end = ctx.end();
if (it != end && *it != '}') {
throw format_error("invalid format");
}
Expand Down
2 changes: 1 addition & 1 deletion src/engine/include/parseoptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,6 @@ auto ParseOptions(ParserType &parser, int argc, const char *argv[]) {
groupParsedOptions.mergeGlobalWith(globalOptions);
}

return parsedOptions;
return std::make_pair(std::move(programName), parsedOptions);
}
} // namespace cct
108 changes: 72 additions & 36 deletions src/engine/src/exchangesorchestrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "cct_log.hpp"
#include "cct_smallvector.hpp"
#include "cct_string.hpp"
#include "cct_type_traits.hpp"
#include "commonapi.hpp"
#include "currencycode.hpp"
#include "currencyexchangeflatset.hpp"
Expand All @@ -31,7 +32,6 @@
#include "exchangeretriever.hpp"
#include "exchangeretrieverbase.hpp"
#include "market.hpp"
#include "marketorderbook.hpp"
#include "monetaryamount.hpp"
#include "monetaryamountbycurrencyset.hpp"
#include "ordersconstraints.hpp"
Expand All @@ -56,16 +56,43 @@ template <class MainVec>
void FilterVector(MainVec &main, std::span<const bool> considerSpan) {
const auto begIt = main.begin();
const auto endIt = main.end();

main.erase(std::remove_if(begIt, endIt, [=](const auto &val) { return !considerSpan[&val - &*begIt]; }), endIt);
}

using ExchangeAmountPairVector = SmallVector<std::pair<Exchange *, MonetaryAmount>, kTypicalNbPrivateAccounts>;
using ExchangeAmountMarketsPathVector =
SmallVector<std::tuple<Exchange *, MonetaryAmount, MarketsPath>, kTypicalNbPrivateAccounts>;
using ExchangeAmountToCurrency = std::tuple<Exchange *, MonetaryAmount, CurrencyCode, MarketsPath>;
using ExchangeAmountToCurrencyToAmount =
std::tuple<Exchange *, MonetaryAmount, CurrencyCode, MarketsPath, MonetaryAmount>;

struct ExchangeAmountMarkets {
Exchange *exchange;
MonetaryAmount amount;
MarketsPath marketsPath;

using trivially_relocatable = is_trivially_relocatable<MarketsPath>::type;
};

using ExchangeAmountMarketsPathVector = SmallVector<ExchangeAmountMarkets, kTypicalNbPrivateAccounts>;

struct ExchangeAmountToCurrency {
Exchange *exchange;
MonetaryAmount amount;
CurrencyCode currency;
MarketsPath marketsPath;

using trivially_relocatable = is_trivially_relocatable<MarketsPath>::type;
};

using ExchangeAmountToCurrencyVector = SmallVector<ExchangeAmountToCurrency, kTypicalNbPrivateAccounts>;

struct ExchangeAmountToCurrencyToAmount {
Exchange *exchange;
MonetaryAmount amount;
CurrencyCode currency;
MarketsPath marketsPath;
MonetaryAmount endAmount;

using trivially_relocatable = is_trivially_relocatable<MarketsPath>::type;
};

using ExchangeAmountToCurrencyToAmountVector = SmallVector<ExchangeAmountToCurrencyToAmount, kTypicalNbPrivateAccounts>;

template <class VecWithExchangeFirstPos>
Expand Down Expand Up @@ -138,12 +165,12 @@ MarketOrderBookConversionRates ExchangesOrchestrator::getMarketOrderBooks(Market
equiCurrencyCode.isNeutral()
? std::nullopt
: exchange->apiPublic().estimatedConvert(MonetaryAmount(1, mk.quote()), equiCurrencyCode);
MarketOrderBook marketOrderBook(depth ? exchange->queryOrderBook(mk, *depth) : exchange->queryOrderBook(mk));
if (!optConversionRate && !equiCurrencyCode.isNeutral()) {
log::warn("Unable to convert {} into {} on {}", marketOrderBook.market().quote(), equiCurrencyCode,
exchange->name());
log::warn("Unable to convert {} into {} on {}", mk.quote(), equiCurrencyCode, exchange->name());
}
return std::make_tuple(exchange->name(), std::move(marketOrderBook), optConversionRate);
return std::make_tuple(exchange->name(),
depth ? exchange->queryOrderBook(mk, *depth) : exchange->queryOrderBook(mk),
optConversionRate);
};
_threadPool.parallelTransform(selectedExchanges.begin(), selectedExchanges.end(), ret.begin(), marketOrderBooksFunc);
return ret;
Expand All @@ -157,17 +184,17 @@ BalancePerExchange ExchangesOrchestrator::getBalance(std::span<const ExchangeNam
log::info("Query balance from {}{}{} with{} balance in use", ConstructAccumulatedExchangeNames(privateExchangeNames),
equiCurrency.isNeutral() ? "" : " with equi currency ", equiCurrency, withBalanceInUse ? "" : "out");

ExchangeRetriever::SelectedExchanges balanceExchanges =
ExchangeRetriever::SelectedExchanges selectedExchanges =
_exchangeRetriever.select(ExchangeRetriever::Order::kInitial, privateExchangeNames);

SmallVector<BalancePortfolio, kTypicalNbPrivateAccounts> balancePortfolios(balanceExchanges.size());
SmallVector<BalancePortfolio, kTypicalNbPrivateAccounts> balancePortfolios(selectedExchanges.size());
_threadPool.parallelTransform(
balanceExchanges.begin(), balanceExchanges.end(), balancePortfolios.begin(),
selectedExchanges.begin(), selectedExchanges.end(), balancePortfolios.begin(),
[&balanceOptions](Exchange *exchange) { return exchange->apiPrivate().getAccountBalance(balanceOptions); });

BalancePerExchange ret;
ret.reserve(balanceExchanges.size());
std::transform(balanceExchanges.begin(), balanceExchanges.end(), std::make_move_iterator(balancePortfolios.begin()),
ret.reserve(selectedExchanges.size());
std::transform(selectedExchanges.begin(), selectedExchanges.end(), std::make_move_iterator(balancePortfolios.begin()),
std::back_inserter(ret), [](Exchange *exchange, BalancePortfolio &&balancePortfolio) {
return std::make_pair(exchange, std::move(balancePortfolio));
});
Expand Down Expand Up @@ -483,26 +510,35 @@ TradeResultPerExchange LaunchAndCollectTrades(ThreadPool &threadPool, ExchangeAm
ExchangeAmountMarketsPathVector::iterator last, CurrencyCode toCurrency,
const TradeOptions &tradeOptions) {
TradeResultPerExchange tradeResultPerExchange(std::distance(first, last));
threadPool.parallelTransform(first, last, tradeResultPerExchange.begin(), [toCurrency, &tradeOptions](auto &tuple) {
Exchange *exchange = std::get<0>(tuple);
const MonetaryAmount from = std::get<1>(tuple);
TradedAmounts tradedAmounts = exchange->apiPrivate().trade(from, toCurrency, tradeOptions, std::get<2>(tuple));
return std::make_pair(exchange, TradeResult(std::move(tradedAmounts), from));
});
threadPool.parallelTransform(first, last, tradeResultPerExchange.begin(),
[toCurrency, &tradeOptions](ExchangeAmountMarkets &exchangeAmountMarketsPath) {
Exchange *exchange = exchangeAmountMarketsPath.exchange;
const MonetaryAmount from = exchangeAmountMarketsPath.amount;
const auto &marketsPath = exchangeAmountMarketsPath.marketsPath;

TradedAmounts tradedAmounts =
exchange->apiPrivate().trade(from, toCurrency, tradeOptions, marketsPath);

return std::make_pair(exchange, TradeResult(std::move(tradedAmounts), from));
});
return tradeResultPerExchange;
}

template <class Iterator>
TradeResultPerExchange LaunchAndCollectTrades(ThreadPool &threadPool, Iterator first, Iterator last,
const TradeOptions &tradeOptions) {
TradeResultPerExchange tradeResultPerExchange(std::distance(first, last));
threadPool.parallelTransform(first, last, tradeResultPerExchange.begin(), [&tradeOptions](auto &tuple) {
Exchange *exchange = std::get<0>(tuple);
const MonetaryAmount from = std::get<1>(tuple);
const CurrencyCode toCurrency = std::get<2>(tuple);
TradedAmounts tradedAmounts = exchange->apiPrivate().trade(from, toCurrency, tradeOptions, std::get<3>(tuple));
return std::make_pair(exchange, TradeResult(std::move(tradedAmounts), from));
});
using ObjType = decltype(*first);
threadPool.parallelTransform(
first, last, tradeResultPerExchange.begin(), [&tradeOptions](ObjType &exchangeAmountMarketsPath) {
Exchange *exchange = exchangeAmountMarketsPath.exchange;
const MonetaryAmount from = exchangeAmountMarketsPath.amount;
const CurrencyCode toCurrency = exchangeAmountMarketsPath.currency;
const auto &marketsPath = exchangeAmountMarketsPath.marketsPath;

TradedAmounts tradedAmounts = exchange->apiPrivate().trade(from, toCurrency, tradeOptions, marketsPath);
return std::make_pair(exchange, TradeResult(std::move(tradedAmounts), from));
});
return tradeResultPerExchange;
}

Expand Down Expand Up @@ -549,17 +585,17 @@ TradeResultPerExchange ExchangesOrchestrator::trade(MonetaryAmount from, bool is
if (!exchangeAmountMarketsPathVector.empty()) {
// Sort exchanges from largest to lowest available amount (should be after filter on markets and conversion paths)
std::ranges::stable_sort(exchangeAmountMarketsPathVector,
[](const auto &lhs, const auto &rhs) { return std::get<1>(lhs) > std::get<1>(rhs); });
[](const auto &lhs, const auto &rhs) { return lhs.amount > rhs.amount; });

// Locate the point where there is enough available amount to trade for this currency
if (isPercentageTrade) {
MonetaryAmount totalAvailableAmount = std::accumulate(
exchangeAmountMarketsPathVector.begin(), exchangeAmountMarketsPathVector.end(), currentTotalAmount,
[](MonetaryAmount tot, const auto &tuple) { return tot + std::get<1>(tuple); });
MonetaryAmount totalAvailableAmount =
std::accumulate(exchangeAmountMarketsPathVector.begin(), exchangeAmountMarketsPathVector.end(),
currentTotalAmount, [](MonetaryAmount tot, const auto &tuple) { return tot + tuple.amount; });
from = (totalAvailableAmount * from.toNeutral()) / 100;
}
for (auto endIt = exchangeAmountMarketsPathVector.end(); it != endIt && currentTotalAmount < from; ++it) {
MonetaryAmount &amount = std::get<1>(*it);
MonetaryAmount &amount = it->amount;
if (currentTotalAmount + amount > from) {
// Cap last amount such that total start trade on all exchanges reaches exactly 'startAmount'
amount = from - currentTotalAmount;
Expand Down Expand Up @@ -628,8 +664,8 @@ TradeResultPerExchange ExchangesOrchestrator::smartBuy(MonetaryAmount endAmount,
}
MonetaryAmount avAmount = balance.get(fromCurrency);
if (avAmount > 0 &&
std::none_of(trades.begin(), trades.begin() + nbTrades, [pExchange, fromCurrency](const auto &tuple) {
return std::get<0>(tuple) == pExchange && std::get<1>(tuple).currencyCode() == fromCurrency;
std::none_of(trades.begin(), trades.begin() + nbTrades, [pExchange, fromCurrency](const auto &obj) {
return obj.exchange == pExchange && obj.amount.currencyCode() == fromCurrency;
})) {
auto conversionPath = exchangePublic.findMarketsPath(fromCurrency, toCurrency, markets, fiats,
api::ExchangePublic::MarketPathMode::kStrict);
Expand All @@ -649,7 +685,7 @@ TradeResultPerExchange ExchangesOrchestrator::smartBuy(MonetaryAmount endAmount,
}
// Sort exchanges from largest to lowest end amount
std::stable_sort(trades.begin() + nbTrades, trades.end(),
[](const auto &lhs, const auto &rhs) { return std::get<4>(lhs) > std::get<4>(rhs); });
[](const auto &lhs, const auto &rhs) { return lhs.endAmount > rhs.endAmount; });
int nbTradesToKeep = 0;
for (auto &[pExchange, startAmount, tradeToCurrency, conversionPath, tradeEndAmount] : trades) {
if (tradeEndAmount > remEndAmount) {
Expand Down
8 changes: 3 additions & 5 deletions src/main/src/main.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include <cstdlib>
#include <exception>
#include <filesystem>
#include <iostream>

#include "cct_invalid_argument_exception.hpp"
Expand All @@ -13,19 +12,18 @@
#include "runmodes.hpp"

int main(int argc, const char* argv[]) {
using namespace cct;
try {
using namespace cct;
auto parser =
CommandLineOptionsParser<CoincenterCmdLineOptions>(CoincenterAllowedOptions<CoincenterCmdLineOptions>::value);
const auto cmdLineOptionsVector = ParseOptions(parser, argc, argv);
const auto [programName, cmdLineOptionsVector] = ParseOptions(parser, argc, argv);

if (!cmdLineOptionsVector.empty()) {
const CoincenterCommands coincenterCommands(cmdLineOptionsVector);
const auto programName = std::filesystem::path(argv[0]).filename().string();

ProcessCommandsFromCLI(programName, coincenterCommands, cmdLineOptionsVector.front(), settings::RunMode::kProd);
}
} catch (const cct::invalid_argument& e) {
} catch (const invalid_argument& e) {
std::cerr << "Invalid argument: " << e.what() << '\n';
return EXIT_FAILURE;
} catch (const std::exception& e) {
Expand Down
2 changes: 2 additions & 0 deletions src/objects/include/coincenterinfo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class CoincenterInfo {

AbstractMetricGateway *metricGatewayPtr() const { return _metricGatewayPtr.get(); }

const GeneralConfig &generalConfig() const { return _generalConfig; }

const LoggingInfo &loggingInfo() const { return _generalConfig.loggingInfo(); }

const RequestsConfig &requestsConfig() const { return _generalConfig.requestsConfig(); }
Expand Down
14 changes: 11 additions & 3 deletions src/tech/include/threadpool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,24 @@ class ThreadPool {

auto nbWorkers() const noexcept { return _workers.size(); }

// add new work item to the pool
// Add new work item to the pool
// By default, arguments will be copied for safety. If you want to pass arguments by reference,
// make sure that the reference lifetime is valid through the whole execution time of the future,
// and wrap the argument you want to pass by reference with 'std::ref'.
template <class Func, class... Args>
std::future<std::invoke_result_t<Func, Args...>> enqueue(Func&& func, Args&&... args);

// Parallel version of std::transform with unary operation.
// This function will first enqueue all the tasks at one, using waiting threads of the thread pool,
// and then retrieves and moves the results to 'out', as for std::transform.
// Note: the objects passed in argument from InputIt are not copied and passed by reference (through
// std::reference_wrapper)
template <class InputIt, class OutputIt, class UnaryOperation>
OutputIt parallelTransform(InputIt first, InputIt last, OutputIt out, UnaryOperation unary_op);

// Parallel version of std::transform with binary operation.
// Note: the objects passed in argument from InputIt are not copied and passed by reference (through
// std::reference_wrapper)
template <class InputIt1, class InputIt2, class OutputIt, class BinaryOperation>
OutputIt parallelTransform(InputIt1 first1, InputIt1 last1, InputIt2 first2, OutputIt out, BinaryOperation binary_op);

Expand Down Expand Up @@ -105,6 +112,7 @@ inline ThreadPool::~ThreadPool() {

template <class Func, class... Args>
inline std::future<std::invoke_result_t<Func, Args...>> ThreadPool::enqueue(Func&& func, Args&&... args) {
// std::bind copies the arguments. To avoid copies, you can use std::ref to copy reference instead.
using return_type = std::invoke_result_t<Func, Args...>;

auto task = std::make_shared<std::packaged_task<return_type()>>(
Expand All @@ -130,7 +138,7 @@ inline OutputIt ThreadPool::parallelTransform(InputIt first, InputIt last, Outpu
using FutureT = std::future<std::invoke_result_t<UnaryOperation, decltype(*first)>>;
SmallVector<FutureT, kTypicalNbPrivateAccounts> futures;
for (; first != last; ++first) {
futures.emplace_back(enqueue(unary_op, *first));
futures.emplace_back(enqueue(unary_op, std::ref(*first)));
}
return retrieveAllResults(futures, out);
}
Expand All @@ -141,7 +149,7 @@ inline OutputIt ThreadPool::parallelTransform(InputIt1 first1, InputIt1 last1, I
using FutureT = std::future<std::invoke_result_t<BinaryOperation, decltype(*first1), decltype(*first2)>>;
SmallVector<FutureT, kTypicalNbPrivateAccounts> futures;
for (; first1 != last1; ++first1, ++first2) {
futures.emplace_back(enqueue(binary_op, *first1, *first2));
futures.emplace_back(enqueue(binary_op, std::ref(*first1), std::ref(*first2)));
}
return retrieveAllResults(futures, out);
}
Expand Down
30 changes: 29 additions & 1 deletion src/tech/test/threadpool_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include <gtest/gtest.h>

#include <chrono>
#include <forward_list>
#include <future>
#include <numeric>
Expand All @@ -28,6 +27,19 @@ int SlowAdd(const int &lhs, const int &rhs) {
std::this_thread::sleep_for(10ms);
return lhs + rhs;
}

struct NonCopyable {
NonCopyable(int i = 0) : i(i) {}

NonCopyable(const NonCopyable &) = delete;

int i;
};

int SlowDoubleNonCopyable(const NonCopyable &val) {
std::this_thread::sleep_for(10ms);
return val.i * 2;
}
} // namespace

TEST(ThreadPoolTest, Enqueue) {
Expand All @@ -44,6 +56,22 @@ TEST(ThreadPoolTest, Enqueue) {
}
}

TEST(ThreadPoolTest, EnqueueNonCopyable) {
ThreadPool threadPool(2);
vector<std::future<int>> results;

constexpr int kNbElems = 4;
vector<NonCopyable> inputData(kNbElems);
for (int elem = 0; elem < kNbElems; ++elem) {
inputData[elem] = NonCopyable(elem);
results.push_back(threadPool.enqueue(SlowDoubleNonCopyable, std::ref(inputData[elem])));
}

for (int elem = 0; elem < kNbElems; ++elem) {
EXPECT_EQ(results[elem].get(), elem * 2);
}
}

TEST(ThreadPoolTest, ParallelTransformRandomInputIt) {
ThreadPool threadPool(4);
constexpr int kNbElems = 22;
Expand Down

0 comments on commit 1d1bddb

Please sign in to comment.