From 2f10f3be0fa8b99bb78926e5261a9eb45bb3f9cc Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Tue, 30 Apr 2024 17:24:51 -0700 Subject: [PATCH] try catch for protocols --- libOTe/Base/MasnyRindal.cpp | 14 +- libOTe/Base/MasnyRindalKyber.cpp | 14 +- libOTe/Base/McRosRoy.h | 14 +- libOTe/Base/McRosRoyTwist.h | 14 +- libOTe/Base/SimplestOT.cpp | 28 +- libOTe/NChooseOne/Kkrt/KkrtNcoOtReceiver.cpp | 14 +- libOTe/NChooseOne/Kkrt/KkrtNcoOtSender.cpp | 14 +- libOTe/NChooseOne/NcoOtExt.cpp | 28 +- libOTe/NChooseOne/Oos/OosNcoOtReceiver.cpp | 50 +- libOTe/NChooseOne/Oos/OosNcoOtSender.cpp | 35 +- libOTe/Tools/Pprf/RegularPprf.h | 2267 +++++++++-------- libOTe/TwoChooseOne/Kos/KosOtExtReceiver.cpp | 7 +- libOTe/TwoChooseOne/Kos/KosOtExtSender.cpp | 8 +- .../TwoChooseOne/KosDot/KosDotExtReceiver.cpp | 7 +- .../TwoChooseOne/KosDot/KosDotExtSender.cpp | 7 +- libOTe/TwoChooseOne/OTExtInterface.cpp | 43 +- .../Silent/SilentOtExtReceiver.cpp | 31 +- .../TwoChooseOne/Silent/SilentOtExtSender.cpp | 28 +- .../SoftSpokenOT/SoftSpokenMalOtExt.cpp | 14 +- .../SoftSpokenOT/SoftSpokenShOtExt.cpp | 15 +- libOTe/Vole/Noisy/NoisyVoleReceiver.h | 14 +- libOTe/Vole/Noisy/NoisyVoleSender.h | 16 +- libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp | 14 +- 23 files changed, 1503 insertions(+), 1193 deletions(-) diff --git a/libOTe/Base/MasnyRindal.cpp b/libOTe/Base/MasnyRindal.cpp index 24e543b..d66e7b1 100644 --- a/libOTe/Base/MasnyRindal.cpp +++ b/libOTe/Base/MasnyRindal.cpp @@ -26,7 +26,7 @@ namespace osuCrypto span messages, PRNG& prng, Socket& chl) - { + try { using namespace DefaultCurve; //MC_BEGIN(task<>, &choices, messages, &prng, &chl, @@ -91,10 +91,15 @@ namespace osuCrypto } } + catch (...) + { + chl.close(); + throw; + } task<> MasnyRindal::send(span> messages, PRNG& prng, Socket& chl) - { + try { using namespace DefaultCurve; Curve{}; // required to init relic @@ -143,5 +148,10 @@ namespace osuCrypto } } } + catch (...) + { + chl.close(); + throw; + } } #endif diff --git a/libOTe/Base/MasnyRindalKyber.cpp b/libOTe/Base/MasnyRindalKyber.cpp index 0a85dc7..7ef261e 100644 --- a/libOTe/Base/MasnyRindalKyber.cpp +++ b/libOTe/Base/MasnyRindalKyber.cpp @@ -12,7 +12,7 @@ namespace osuCrypto span messages, PRNG& prng, Socket& chl) - { + try { static_assert(std::is_trivial::value, ""); static_assert(std::is_trivial::value, ""); @@ -41,12 +41,17 @@ namespace osuCrypto memcpy(&messages[i], ot[i].rot, sizeof(block)); } } + catch (...) + { + chl.close(); + throw; + } task<> MasnyRindalKyber::send( span> messages, PRNG& prng, Socket& chl) - { + try { auto pkBuff = std::vector{}; auto ctxts = std::vector{}; auto ptxt = KyberOTPtxt{ }; @@ -72,5 +77,10 @@ namespace osuCrypto co_await chl.send(std::move(ctxts)); } + catch (...) + { + chl.close(); + throw; + } } #endif \ No newline at end of file diff --git a/libOTe/Base/McRosRoy.h b/libOTe/Base/McRosRoy.h index f578975..fbb674b 100644 --- a/libOTe/Base/McRosRoy.h +++ b/libOTe/Base/McRosRoy.h @@ -116,7 +116,7 @@ namespace osuCrypto span messages, PRNG& prng, Socket& chl) - { + try { auto A = Point{}; auto sk = std::vector{}; @@ -159,13 +159,18 @@ namespace osuCrypto } } + catch (...) + { + chl.close(); + throw; + } template task<> McRosRoy::send( span> msg, PRNG& prng, Socket& chl) - { + try { Curve{}; // init relic auto A = Point{}; @@ -211,6 +216,11 @@ namespace osuCrypto ro.Final(msg[i][1]); } } + catch (...) + { + chl.close(); + throw; + } } } diff --git a/libOTe/Base/McRosRoyTwist.h b/libOTe/Base/McRosRoyTwist.h index 780d0ed..a7da08b 100644 --- a/libOTe/Base/McRosRoyTwist.h +++ b/libOTe/Base/McRosRoyTwist.h @@ -138,7 +138,7 @@ namespace osuCrypto template inline task<> McRosRoyTwist::receive(const BitVector& choices, span messages, PRNG& prng, Socket& chl) - { + try { auto n = choices.size(); auto sk = std::vector{}; auto curveChoice = std::vector{}; @@ -179,10 +179,15 @@ namespace osuCrypto ro.Final(messages[i]); } } + catch (...) + { + chl.close(); + throw; + } template inline task<> McRosRoyTwist::send(span> msg, PRNG& prng, Socket& chl) - { + try { auto n = static_cast(msg.size()); auto sk = Scalar25519(prng); @@ -225,6 +230,11 @@ namespace osuCrypto ro.Final(msg[i][1]); } } + catch (...) + { + chl.close(); + throw; + } template inline typename McRosRoyTwist::Monty25519 McRosRoyTwist::blockToCurve(Block256 b) diff --git a/libOTe/Base/SimplestOT.cpp b/libOTe/Base/SimplestOT.cpp index 05e200e..5b9b8e1 100644 --- a/libOTe/Base/SimplestOT.cpp +++ b/libOTe/Base/SimplestOT.cpp @@ -15,7 +15,7 @@ namespace osuCrypto span msg, PRNG& prng, Socket& chl) - { + try { using namespace DefaultCurve; Curve{};// init relic auto buff = std::vector{}; @@ -75,12 +75,17 @@ namespace osuCrypto ro.Final(msg[i]); } } + catch (...) + { + chl.close(); + throw; + } task<> SimplestOT::send( span> msg, PRNG& prng, Socket& chl) - { + try { using namespace DefaultCurve; Curve{}; // init relic @@ -142,6 +147,11 @@ namespace osuCrypto ro.Final(msg[i][1]); } } + catch (...) + { + chl.close(); + throw; + } } #endif @@ -397,7 +407,7 @@ namespace osuCrypto span msg, PRNG& prng, Socket& chl) - { + try { auto rs = AlginedState(); auto rd = rs->recvData(); co_await chl.recv(rd); @@ -410,12 +420,17 @@ namespace osuCrypto rs->gen4(i, msg); } } + catch (...) + { + chl.close(); + throw; + } task<> AsmSimplestOT::send( span> msg, PRNG& prng, Socket& chl) - { + try { auto ss = AlginedState(); auto sd = ss->init(prng); @@ -428,5 +443,10 @@ namespace osuCrypto ss->gen4(i, msg); } } + catch (...) + { + chl.close(); + throw; + } } #endif diff --git a/libOTe/NChooseOne/Kkrt/KkrtNcoOtReceiver.cpp b/libOTe/NChooseOne/Kkrt/KkrtNcoOtReceiver.cpp index b6d244b..e34c809 100644 --- a/libOTe/NChooseOne/Kkrt/KkrtNcoOtReceiver.cpp +++ b/libOTe/NChooseOne/Kkrt/KkrtNcoOtReceiver.cpp @@ -33,7 +33,7 @@ namespace osuCrypto task<> KkrtNcoOtReceiver::init(u64 numOtExt, PRNG& prng, Socket& chl) - { + try { if (hasBaseOts() == false) co_await(genBaseOts(prng, chl)); @@ -140,6 +140,11 @@ namespace osuCrypto PRNG(theirSeed).get(keys.data(), keys.size()); mMultiKeyAES.setKeys(keys); } + catch (...) + { + chl.close(); + throw; + } u64 KkrtNcoOtReceiver::getBaseOTCount() const @@ -311,7 +316,7 @@ namespace osuCrypto } task<> KkrtNcoOtReceiver::sendCorrection(Socket& chl, u64 sendCount) - { + try { #ifndef NDEBUG // make sure these OTs all contain valid correction values, aka encode has been called. @@ -334,6 +339,11 @@ namespace osuCrypto co_await chl.send(std::move(sub)); } + catch (...) + { + chl.close(); + throw; + } } #endif \ No newline at end of file diff --git a/libOTe/NChooseOne/Kkrt/KkrtNcoOtSender.cpp b/libOTe/NChooseOne/Kkrt/KkrtNcoOtSender.cpp index 262d7bb..0fe722d 100644 --- a/libOTe/NChooseOne/Kkrt/KkrtNcoOtSender.cpp +++ b/libOTe/NChooseOne/Kkrt/KkrtNcoOtSender.cpp @@ -67,7 +67,7 @@ namespace osuCrypto task<> KkrtNcoOtSender::init( u64 numOTExt, PRNG& prng, Socket& chl) - { + try { @@ -162,6 +162,11 @@ namespace osuCrypto doneIdx = stopIdx; } } + catch (...) + { + chl.close(); + throw; + } void KkrtNcoOtSender::encode(u64 otIdx, const void* input, void* dest, u64 destSize) { @@ -258,7 +263,7 @@ namespace osuCrypto } task<> KkrtNcoOtSender::recvCorrection(Socket& chl, u64 recvCount) - { + try { #ifndef NDEBUG if (recvCount > mCorrectionVals.bounds()[0] - mCorrectionIdx) @@ -272,5 +277,10 @@ namespace osuCrypto mCorrectionIdx += recvCount; co_await chl.recv(span(&*dest, recvCount * mCorrectionVals.stride())); } + catch (...) + { + chl.close(); + throw; + } } #endif \ No newline at end of file diff --git a/libOTe/NChooseOne/NcoOtExt.cpp b/libOTe/NChooseOne/NcoOtExt.cpp index c6107ed..4c543bc 100644 --- a/libOTe/NChooseOne/NcoOtExt.cpp +++ b/libOTe/NChooseOne/NcoOtExt.cpp @@ -15,7 +15,7 @@ namespace osuCrypto { task<> NcoOtExtReceiver::genBaseOts(PRNG& prng, Socket& chl) - { + try { struct TT { #ifdef ENABLE_IKNP @@ -59,10 +59,15 @@ namespace osuCrypto co_await(setBaseOts(msgs, prng, chl)); } + catch (...) + { + chl.close(); + throw; + } task<> NcoOtExtSender::genBaseOts(PRNG& prng, Socket& chl) - { + try { struct TT { #ifdef ENABLE_IKNP @@ -108,9 +113,14 @@ namespace osuCrypto co_await(setBaseOts(msgs, bv, chl)); } + catch (...) + { + chl.close(); + throw; + } task<> NcoOtExtSender::sendChosen(MatrixView messages, PRNG& prng, Socket& chl) - { + try { auto temp = Matrix{}; if (hasBaseOts() == false) @@ -143,12 +153,17 @@ namespace osuCrypto co_await(chl.send(std::move(temp))); } + catch (...) + { + chl.close(); + throw; + } task<> NcoOtExtReceiver::receiveChosen( u64 numMsgsPerOT, span messages, span choices, PRNG& prng, Socket& chl) - { + try { auto temp = Matrix{}; if (hasBaseOts() == false) @@ -181,5 +196,10 @@ namespace osuCrypto } } + catch (...) + { + chl.close(); + throw; + } } #endif diff --git a/libOTe/NChooseOne/Oos/OosNcoOtReceiver.cpp b/libOTe/NChooseOne/Oos/OosNcoOtReceiver.cpp index 312bb45..58c003f 100644 --- a/libOTe/NChooseOne/Oos/OosNcoOtReceiver.cpp +++ b/libOTe/NChooseOne/Oos/OosNcoOtReceiver.cpp @@ -13,7 +13,7 @@ using namespace std; namespace osuCrypto { task<> OosNcoOtReceiver::setBaseOts(span> baseRecvOts, PRNG& prng, Socket& chl) - { + try { if (u64(baseRecvOts.size()) != u64(mGens.size())) throw std::runtime_error("rt error at " LOCATION); @@ -31,6 +31,11 @@ namespace osuCrypto mHasBase = true; co_await chl.send(std::move(delta)); } + catch (...) + { + chl.close(); + throw; + } void OosNcoOtReceiver::setUniformBaseOts(span> baseRecvOts) { @@ -47,7 +52,7 @@ namespace osuCrypto } task<> OosNcoOtReceiver::init(u64 numOtExt, PRNG& prng, Socket& chl) - { + try { if (mInputByteCount == 0) throw std::runtime_error("configure must be called first" LOCATION); @@ -167,6 +172,11 @@ namespace osuCrypto doneIdx = stopIdx; } } + catch (...) + { + chl.close(); + throw; + } OosNcoOtReceiver OosNcoOtReceiver::splitBase() @@ -355,7 +365,7 @@ namespace osuCrypto } task<> OosNcoOtReceiver::sendCorrection(Socket& chl, u64 sendCount) - { + try { auto sub = T1Sub{}; #ifndef NDEBUG for (u64 i = mCorrectionIdx; i < sendCount + mCorrectionIdx; ++i) @@ -376,9 +386,14 @@ namespace osuCrypto co_await chl.send(std::move(sub)); } + catch (...) + { + chl.close(); + throw; + } task<> OosNcoOtReceiver::check(Socket& chl, block wordSeed) - { + try { if (mMalicious) { co_await(sendFinalization(chl, wordSeed)); @@ -387,9 +402,14 @@ namespace osuCrypto co_await(sendProof(chl)); } } + catch (...) + { + chl.close(); + throw; + } task<> OosNcoOtReceiver::sendFinalization(Socket& chl, block seed) - { + try { #ifndef NDEBUG for (u64 i = 0; i < mCorrectionIdx; ++i) @@ -435,13 +455,23 @@ namespace osuCrypto // now send the internally stored correction values. return sendCorrection(chl, mStatSecParam); } + catch (...) + { + chl.close(); + throw; + } task<> OosNcoOtReceiver::recvChallenge(Socket& chl) - { + try { // the sender will now tell us the random challenge seed. return macoro::make_task(chl.recv(mChallengeSeed)); } + catch (...) + { + chl.close(); + throw; + } void OosNcoOtReceiver::computeProof() { @@ -678,11 +708,17 @@ namespace osuCrypto } } + task<> OosNcoOtReceiver::sendProof(Socket& chl) - { + try { // send over our summations. co_await(chl.send(std::move(mTBuff))); co_await(chl.send(std::move(mWBuff))); } + catch (...) + { + chl.close(); + throw; + } } #endif \ No newline at end of file diff --git a/libOTe/NChooseOne/Oos/OosNcoOtSender.cpp b/libOTe/NChooseOne/Oos/OosNcoOtSender.cpp index b3775d2..6c050c6 100644 --- a/libOTe/NChooseOne/Oos/OosNcoOtSender.cpp +++ b/libOTe/NChooseOne/Oos/OosNcoOtSender.cpp @@ -85,7 +85,7 @@ namespace osuCrypto task<> OosNcoOtSender::init( u64 numOTExt, PRNG& prng, Socket& chl) - { + try { if (mInputByteCount == 0) throw std::runtime_error("configure must be called first" LOCATION); @@ -172,6 +172,11 @@ namespace osuCrypto doneIdx = stopIdx; } } + catch (...) + { + chl.close(); + throw; + } void OosNcoOtSender::encode( @@ -266,7 +271,7 @@ namespace osuCrypto } task<> OosNcoOtSender::recvCorrection(Socket& chl, u64 recvCount) - { + try { #ifndef NDEBUG if (recvCount > mCorrectionVals.bounds()[0] - mCorrectionIdx) @@ -283,9 +288,14 @@ namespace osuCrypto co_await(chl.recv(span(dest, recvCount * mCorrectionVals.stride()))); } + catch (...) + { + chl.close(); + throw; + } task<> OosNcoOtSender::check(Socket& chl, block seed) - { + try { if (mMalicious) { if (mStatSecParam % 8) @@ -303,21 +313,36 @@ namespace osuCrypto //std::cout << "pass" << std::endl; } } + catch (...) + { + chl.close(); + throw; + } task<> OosNcoOtSender::recvFinalization(Socket& chl) - { + try { // first we need to receive the extra mStatSecParam number of correction // values. This will just be for random inputs and are used to mask // their true choices that were used in the remaining correction values. return recvCorrection(chl, mStatSecParam); } + catch (...) + { + chl.close(); + throw; + } task<> OosNcoOtSender::sendChallenge(Socket& chl, block seed) - { + try { mChallengeSeed = seed; return macoro::make_task(chl.send(std::move(mChallengeSeed))); } + catch (...) + { + chl.close(); + throw; + } void OosNcoOtSender::computeProof() { diff --git a/libOTe/Tools/Pprf/RegularPprf.h b/libOTe/Tools/Pprf/RegularPprf.h index c5cfb53..b7751c0 100644 --- a/libOTe/Tools/Pprf/RegularPprf.h +++ b/libOTe/Tools/Pprf/RegularPprf.h @@ -17,1141 +17,1152 @@ namespace osuCrypto { - extern const std::array gGgmAes; + extern const std::array gGgmAes; - template< - typename F, - typename G = F, - typename CoeffCtx = DefaultCoeffCtx - > - class RegularPprfSender : public TimerAdapter { - public: + template< + typename F, + typename G = F, + typename CoeffCtx = DefaultCoeffCtx + > + class RegularPprfSender : public TimerAdapter { + public: - // the number of leaves in a single tree. - u64 mDomain = 0; + // the number of leaves in a single tree. + u64 mDomain = 0; - // the depth of each tree. - u64 mDepth = 0; + // the depth of each tree. + u64 mDepth = 0; - // the number of trees, must be a multiple of 8. - u64 mPntCount = 0; + // the number of trees, must be a multiple of 8. + u64 mPntCount = 0; - // the values that should be programmed at the punctured points. - std::vector mValue; + // the values that should be programmed at the punctured points. + std::vector mValue; - // the base OTs that should be set. - Matrix> mBaseOTs; - - // if true, tree OT messages are eagerly sent in batches of 8. - // otherwise, the OT messages are sent in a single batch. - bool mEagerSend = true; - - using VecF = typename CoeffCtx::template Vec; - using VecG = typename CoeffCtx::template Vec; - - // a function that can be used to output the result of the PPRF. - std::function mOutputFn; - - // an internal buffer that is used to expand the tree. - AlignedUnVector mTempBuffer; - - RegularPprfSender() = default; - - RegularPprfSender(const RegularPprfSender&) = delete; - - RegularPprfSender(RegularPprfSender&&) = delete; - - RegularPprfSender(u64 domainSize, u64 pointCount) { - configure(domainSize, pointCount); - } - - void configure(u64 domainSize, u64 pointCount) - { - if (domainSize & 1) - throw std::runtime_error("Pprf domain must be even. " LOCATION); - if (domainSize < 2) - throw std::runtime_error("Pprf domain must must be at least 2. " LOCATION); - - mDomain = domainSize; - mDepth = log2ceil(mDomain); - mPntCount = pointCount; - - mBaseOTs.resize(0, 0); - } - - - // the number of base OTs that should be set. - u64 baseOtCount() const { - return mDepth * mPntCount; - } - - // returns true if the base OTs are currently set. - bool hasBaseOts() const { - return mBaseOTs.size(); - } - - - void setBase(span> baseMessages) { - if (baseOtCount() != static_cast(baseMessages.size())) - throw RTE_LOC; - - mBaseOTs.resize(mPntCount, mDepth); - for (u64 i = 0; i < static_cast(mBaseOTs.size()); ++i) - mBaseOTs(i) = baseMessages[i]; - } - - task<> expand( - Socket& chl, - const VecF& value, - block seed, - VecF& output, - PprfOutputFormat oFormat, - bool programPuncturedPoint, - u64 numThreads, - CoeffCtx ctx = {}) - { - if (programPuncturedPoint) - setValue(value); - - setTimePoint("SilentMultiPprfSender.start"); - - pprf::validateExpandFormat(oFormat, output, mDomain, mPntCount); - - //auto tree = span>{}; - auto levels = std::vector> >{}; - auto leafIndex = u64{}; - auto leafLevelPtr = (VecF*)nullptr; - auto leafLevel = VecF{}; - auto buff = std::vector{}; - auto encSums = span>{}; - auto leafMsgs = span{}; - auto encStepSize = u64{}; - auto leafStepSize = u64{}; - auto encOffset = u64{}; - auto leafOffset = u64{}; - - auto dd = mDomain > 2 ? roundUpTo((mDomain + 1) / 2, 2) : 1; - pprf::allocateExpandTree(dd, mTempBuffer, levels); - assert(levels.size() == mDepth); - - if (!mEagerSend) - { - // we need to allocate one large buffer that will store all OT messages. - pprf::allocateExpandBuffer( - mDepth - 1, mPntCount, programPuncturedPoint, buff, encSums, leafMsgs, ctx); - encStepSize = encSums.size() / mPntCount; - leafStepSize = leafMsgs.size() / mPntCount; - encOffset = 0; - leafOffset = 0; - } - - for (auto treeIndex = 0ull; treeIndex < mPntCount; treeIndex += 8) - { - // for interleaved format, the leaf level of the tree - // is simply the output. - if (oFormat == PprfOutputFormat::Interleaved) - { - leafIndex = treeIndex * mDomain; - leafLevelPtr = &output; - } - else - { - // we will use leaf level as a buffer before - // copying the result to the output. - leafIndex = 0; - ctx.resize(leafLevel, mDomain * 8); - leafLevelPtr = &leafLevel; - } - - auto min = std::min(8, mPntCount - treeIndex); - if (mEagerSend) - { - // allocate a send buffer for the next 8 trees. - pprf::allocateExpandBuffer( - mDepth - 1, min, programPuncturedPoint, buff, encSums, leafMsgs, ctx); - encStepSize = encSums.size() / min; - leafStepSize = leafMsgs.size() / min; - encOffset = 0; - leafOffset = 0; - } - - // exapnd the tree - expandOne( - seed, - treeIndex, - programPuncturedPoint, - levels, - *leafLevelPtr, - leafIndex, - encSums.subspan(encOffset, encStepSize * min), - leafMsgs.subspan(leafOffset, leafStepSize * min), - ctx); - - encOffset += encStepSize * min; - leafOffset += leafStepSize * min; - - if (mEagerSend) - { - // send the buffer for the current set of trees. - co_await (chl.send(std::move(buff))); - } - - // if we aren't interleaved, we need to copy the - // leaf layer to the output. - if (oFormat != PprfOutputFormat::Interleaved) - pprf::copyOut(leafLevel, output, mPntCount, treeIndex, oFormat, mOutputFn); - - } - - - if (!mEagerSend) - { - // send the buffer for all of the trees. - co_await (chl.send(std::move(buff))); - } - - mBaseOTs = {}; - - setTimePoint("SilentMultiPprfSender.de-alloc"); - } - - void setValue(span value) { - - mValue.resize(mPntCount); - - if (value.size() == 1) { - std::fill(mValue.begin(), mValue.end(), value[0]); - } - else { - if ((u64)value.size() != mPntCount) - throw RTE_LOC; - - std::copy(value.begin(), value.end(), mValue.begin()); - } - } - - void clear() { - mBaseOTs.resize(0, 0); - mDomain = 0; - mDepth = 0; - mPntCount = 0; - } - - void expandOne( - block aesSeed, - u64 treeIdx, - bool programPuncturedPoint, - span>> levels, - VecF& leafLevel, - const u64 leafOffset, - span> encSums, - span leafMsgs, - CoeffCtx ctx) - { - auto remTrees = std::min(8, mPntCount - treeIdx); - - // the first level should be size 1, the root of the tree. - // we will populate it with random seeds using aesSeed in counter mode - // based on the tree index. - assert(levels[0].size() == 1); - mAesFixedKey.ecbEncCounterMode(aesSeed ^ block(treeIdx), levels[0][0]); - - assert(encSums.size() == (mDepth - 1) * remTrees); - auto encSumIter = encSums.begin(); - - // space for our sums of each level. Should always be less then - // 24 levels... If not increase the limit or make it a vector. - std::array, 2> sums; - - // use the optimized approach for intern nodes of the tree - // For each level perform the following. - for (u64 d = 0; d < mDepth - 1; ++d) - { - // clear the sums - memset(&sums, 0, sizeof(sums)); - - // The total number of parents in this level. - auto width = divCeil(mDomain, 1ull << (mDepth - d)); - - // The previous level of the GGM tree. - auto parents = levels[d]; - - // The next level of theGGM tree that we are populating. - auto children = levels[d + 1]; - - // For each child, populate the child by expanding the parent. - for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx, childIdx += 2) - { - // The value of the parent. - auto& parent = parents.data()[parentIdx]; - - auto& child0 = children.data()[childIdx]; - auto& child1 = children.data()[childIdx + 1]; - mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); - - // inspired by the Expand Accumualte idea to - // use - // - // child0 = AES(parent) ^ parent - // child1 = AES(parent) + parent - // - // but instead we are a bit more conservative and - // compute - // - // child0 = AES:Round(AES(parent), parent) - // = AES:Round(AES(parent), 0) ^ parent - // child1 = AES(parent) + parent - // - // That is, we applies an additional AES round function - // to the first child before XORing it with parent. - child0[0] = AES::roundEnc(child1[0], parent[0]); - child0[1] = AES::roundEnc(child1[1], parent[1]); - child0[2] = AES::roundEnc(child1[2], parent[2]); - child0[3] = AES::roundEnc(child1[3], parent[3]); - child0[4] = AES::roundEnc(child1[4], parent[4]); - child0[5] = AES::roundEnc(child1[5], parent[5]); - child0[6] = AES::roundEnc(child1[6], parent[6]); - child0[7] = AES::roundEnc(child1[7], parent[7]); - - // Update the running sums for this level. We keep - // a left and right totals for each level. - sums[0][0] = sums[0][0] ^ child0[0]; - sums[0][1] = sums[0][1] ^ child0[1]; - sums[0][2] = sums[0][2] ^ child0[2]; - sums[0][3] = sums[0][3] ^ child0[3]; - sums[0][4] = sums[0][4] ^ child0[4]; - sums[0][5] = sums[0][5] ^ child0[5]; - sums[0][6] = sums[0][6] ^ child0[6]; - sums[0][7] = sums[0][7] ^ child0[7]; - - // child1 = AES(parent) + parent - child1[0] = child1[0] + parent[0]; - child1[1] = child1[1] + parent[1]; - child1[2] = child1[2] + parent[2]; - child1[3] = child1[3] + parent[3]; - child1[4] = child1[4] + parent[4]; - child1[5] = child1[5] + parent[5]; - child1[6] = child1[6] + parent[6]; - child1[7] = child1[7] + parent[7]; - - sums[1][0] = sums[1][0] ^ child1[0]; - sums[1][1] = sums[1][1] ^ child1[1]; - sums[1][2] = sums[1][2] ^ child1[2]; - sums[1][3] = sums[1][3] ^ child1[3]; - sums[1][4] = sums[1][4] ^ child1[4]; - sums[1][5] = sums[1][5] ^ child1[5]; - sums[1][6] = sums[1][6] ^ child1[6]; - sums[1][7] = sums[1][7] ^ child1[7]; - - } - - // encrypt the sums and write them to the output. - for (u64 j = 0; j < remTrees; ++j) - { - (*encSumIter)[0] = sums[0][j] ^ mBaseOTs(treeIdx + j, mDepth - 1 - d)[1]; - (*encSumIter)[1] = sums[1][j] ^ mBaseOTs(treeIdx + j, mDepth - 1 - d)[0]; - ++encSumIter; - } - } - assert(encSumIter == encSums.end()); - - auto d = mDepth - 1; - - // The previous level of the GGM tree. - auto level0 = levels[d]; - - // The total number of parents in this level. - auto width = divCeil(mDomain, 1ull << (mDepth - d)); - - // The next level of theGGM tree that we are populating. - std::array child; - - // clear the sums - std::array leafSums; - ctx.resize(leafSums[0], 8); - ctx.resize(leafSums[1], 8); - ctx.zero(leafSums[0].begin(), leafSums[0].end()); - ctx.zero(leafSums[1].begin(), leafSums[1].end()); - - auto outIter = leafLevel.data() + leafOffset; - - // for the leaf nodes we need to hash both children. - for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) - { - // The value of the parent. - auto& parent = level0.data()[parentIdx]; - - // The bit that indicates if we are on the left child (0) - // or on the right child (1). - for (u64 keep = 0; keep < 2; ++keep, ++childIdx) - { - // The child that we will write in this iteration. - - if constexpr (std::is_same_v) - { - gGgmAes.data()[keep].hashBlocks<8>(parent.data(), outIter); - } - else - { - // Each parent is expanded into the left and right children - // using a different AES fixed-key. Therefore our OWF is: - // - // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); - // - // where each half defines one of the children. - gGgmAes.data()[keep].hashBlocks<8>(parent.data(), child.data()); - - ctx.fromBlock(*(outIter + 0), child.data()[0]); - ctx.fromBlock(*(outIter + 1), child.data()[1]); - ctx.fromBlock(*(outIter + 2), child.data()[2]); - ctx.fromBlock(*(outIter + 3), child.data()[3]); - ctx.fromBlock(*(outIter + 4), child.data()[4]); - ctx.fromBlock(*(outIter + 5), child.data()[5]); - ctx.fromBlock(*(outIter + 6), child.data()[6]); - ctx.fromBlock(*(outIter + 7), child.data()[7]); - } - - // leafSum += child - auto& leafSum = leafSums[keep]; - ctx.plus(leafSum.data()[0], leafSum.data()[0], *(outIter + 0)); - ctx.plus(leafSum.data()[1], leafSum.data()[1], *(outIter + 1)); - ctx.plus(leafSum.data()[2], leafSum.data()[2], *(outIter + 2)); - ctx.plus(leafSum.data()[3], leafSum.data()[3], *(outIter + 3)); - ctx.plus(leafSum.data()[4], leafSum.data()[4], *(outIter + 4)); - ctx.plus(leafSum.data()[5], leafSum.data()[5], *(outIter + 5)); - ctx.plus(leafSum.data()[6], leafSum.data()[6], *(outIter + 6)); - ctx.plus(leafSum.data()[7], leafSum.data()[7], *(outIter + 7)); - - outIter += 8; - assert(outIter <= leafLevel.data() + leafLevel.size()); - } - - } - - if (programPuncturedPoint) - { - // For the leaf level, we are going to do something special. - // The other party is currently missing both leaf children of - // the active parent. Since this is the leaf level, we want - // the inactive child to just be the normal value but the - // active child should be the correct value XOR the delta. - // This will be done by sending the sums and the sums plus - // delta and ensure that they can only decrypt the correct ones. - VecF leafOts; - ctx.resize(leafOts, 2); - PRNG otMasker; - - for (u64 j = 0; j < remTrees; ++j) - { - // we will construct two OT strings. Let - // s0, s1 be the left and right child sums. - // - // m0 = (s0 , s1 + val) - // m1 = (s0 + val, s1 ) - // - // these will be encrypted by the OT keys - for (u64 k = 0; k < 2; ++k) - { - if (k == 0) - { - // m0 = (s0, s1 + val) - ctx.copy(leafOts[0], leafSums[0][j]); - ctx.plus(leafOts[1], leafSums[1][j], mValue[treeIdx + j]); - } - else - { - // m1 = (s0+val, s1) - ctx.plus(leafOts[0], leafSums[0][j], mValue[treeIdx + j]); - ctx.copy(leafOts[1], leafSums[1][j]); - } - - // copy m0 into the output buffer. - span buff = leafMsgs.subspan(0, 2 * ctx.template byteSize()); - leafMsgs = leafMsgs.subspan(buff.size()); - ctx.serialize(leafOts.begin(), leafOts.end(), buff.begin()); - - // encrypt the output buffer. - otMasker.SetSeed(mBaseOTs[treeIdx + j][0][1 ^ k], divCeil(buff.size(), sizeof(block))); - for (u64 i = 0; i < buff.size(); ++i) - buff[i] ^= otMasker.get(); - - } - } - } - else - { - VecF leafOts; - ctx.resize(leafOts, 1); - PRNG otMasker; - - for (u64 j = 0; j < remTrees; ++j) - { - for (u64 k = 0; k < 2; ++k) - { - // copy the sum k into the output buffer. - ctx.copy(leafOts[0], leafSums[k][j]); - span buff = leafMsgs.subspan(0, ctx.template byteSize()); - leafMsgs = leafMsgs.subspan(buff.size()); - ctx.serialize(leafOts.begin(), leafOts.end(), buff.begin()); - - // encrypt the output buffer. - otMasker.SetSeed(mBaseOTs[treeIdx + j][0][1 ^ k], divCeil(buff.size(), sizeof(block))); - for (u64 i = 0; i < buff.size(); ++i) - buff[i] ^= otMasker.get(); - - } - } - } - - assert(leafMsgs.size() == 0); - } - - - }; - - - template< - typename F, - typename G = F, - typename CoeffCtx = DefaultCoeffCtx - > - class RegularPprfReceiver : public TimerAdapter - { - public: - - // the number of leaves in a single tree. - u64 mDomain = 0; - - // the depth of each tree. - u64 mDepth = 0; - - // the number of trees, must be a multiple of 8. - u64 mPntCount = 0; - - using VecF = typename CoeffCtx::template Vec; - using VecG = typename CoeffCtx::template Vec; - - // base ots that will be used to expand the tree. - Matrix mBaseOTs; - - // the choice bits, each row should be the bit decomposition of the active path. - Matrix mBaseChoices; - - // if true, tree OT messages are eagerly sent in batches of 8. - // otherwise, the OT messages are sent in a single batch. - bool mEagerSend = true; - - // a function that can be used to output the result of the PPRF. - std::function mOutputFn; - - // an internal buffer that is used to expand the tree. - AlignedUnVector mTempBuffer; - - RegularPprfReceiver() = default; - RegularPprfReceiver(const RegularPprfReceiver&) = delete; - RegularPprfReceiver(RegularPprfReceiver&&) = delete; - - void configure(u64 domainSize, u64 pointCount) - { - if (domainSize & 1) - throw std::runtime_error("Pprf domain must be even. " LOCATION); - if (domainSize < 2) - throw std::runtime_error("Pprf domain must must be at least 2. " LOCATION); - - mDomain = domainSize; - mDepth = log2ceil(mDomain); - mPntCount = pointCount; - - mBaseOTs.resize(0, 0); - } - - - // this function sample mPntCount integers in the range - // [0,domain) and returns these as the choice bits. - BitVector sampleChoiceBits(PRNG& prng) - { - BitVector choices(mPntCount * mDepth); - - // The points are read in blocks of 8, so make sure that there is a - // whole number of blocks. - mBaseChoices.resize(mPntCount, mDepth); - for (u64 i = 0; i < mPntCount; ++i) - { - u64 idx = prng.get() % mDomain; - for (u64 j = 0; j < mDepth; ++j) - mBaseChoices(i, j) = *BitIterator((u8*)&idx, j); - } - - for (u64 i = 0; i < mBaseChoices.size(); ++i) - { - choices[i] = mBaseChoices(i); - } - - return choices; - } - - // choices is in the same format as the output from sampleChoiceBits. - void setChoiceBits(const BitVector& choices) - { - // Make sure we're given the right number of OTs. - if (choices.size() != baseOtCount()) - throw RTE_LOC; - - mBaseChoices.resize(mPntCount, mDepth); - for (u64 i = 0; i < mPntCount; ++i) - { - u64 idx = 0; - for (u64 j = 0; j < mDepth; ++j) - { - mBaseChoices(i, j) = choices[mDepth * i + j]; - idx |= u64(choices[mDepth * i + j]) << j; - } - - if (idx >= mDomain) - throw std::runtime_error("provided choice bits index outside of the domain." LOCATION); - } - } - - - // the number of base OTs that should be set. - u64 baseOtCount() const - { - return mDepth * mPntCount; - } - - // returns true if the base OTs are currently set. - bool hasBaseOts() const - { - return mBaseOTs.size(); - } - - - void setBase(span baseMessages) - { - if (baseOtCount() != static_cast(baseMessages.size())) - throw RTE_LOC; - - // The OTs are used in blocks of 8, so make sure that there is a whole - // number of blocks. - mBaseOTs.resize(roundUpTo(mPntCount, 8), mDepth); - memcpy(mBaseOTs.data(), baseMessages.data(), baseMessages.size() * sizeof(block)); - } - - std::vector getPoints(PprfOutputFormat format) - { - std::vector pnts(mPntCount); - getPoints(pnts, format); - return pnts; - } - void getPoints(span points, PprfOutputFormat format) - { - if ((u64)points.size() != mPntCount) - throw RTE_LOC; - - switch (format) - { - case PprfOutputFormat::ByLeafIndex: - case PprfOutputFormat::ByTreeIndex: - - memset(points.data(), 0, points.size() * sizeof(u64)); - for (u64 j = 0; j < mPntCount; ++j) - { - for (u64 k = 0; k < mDepth; ++k) - points[j] |= u64(mBaseChoices(j, k)) << k; - - assert(points[j] < mDomain); - } - - - break; - case PprfOutputFormat::Interleaved: - case PprfOutputFormat::Callback: - - getPoints(points, PprfOutputFormat::ByLeafIndex); - - // in interleaved mode we generate 8 trees in a batch. - // the i'th leaf of these 8 trees are next to eachother. - for (u64 j = 0; j < points.size(); ++j) - { - auto subTree = j % 8; - auto batch = j / 8; - points[j] = (batch * mDomain + points[j]) * 8 + subTree; - } - - //interleavedPoints(points, mDomain, format); - - break; - default: - throw RTE_LOC; - break; - } - } - - // programPuncturedPoint says whether the sender is trying to program the - // active child to be its correct value XOR delta. If it is not, the - // active child will just take a random value. - task<> expand( - Socket& chl, - VecF& output, - PprfOutputFormat oFormat, - bool programPuncturedPoint, - u64 numThreads, - CoeffCtx ctx = {}) - { - pprf::validateExpandFormat(oFormat, output, mDomain, mPntCount); - - auto treeIndex = u64{}; - auto levels = std::vector>>{}; - auto leafIndex = u64{}; - auto leafLevelPtr = (VecF*)nullptr; - auto leafLevel = VecF{}; - auto buff = std::vector{}; - auto encSums = span>{}; - auto leafMsgs = span{}; - auto points = std::vector{}; - auto encStepSize = u64{}; - auto leafStepSize = u64{}; - auto encOffset = u64{}; - auto leafOffset = u64{}; - - setTimePoint("SilentMultiPprfReceiver.start"); - points.resize(mPntCount); - getPoints(points, PprfOutputFormat::ByLeafIndex); - - //setTimePoint("SilentMultiPprfSender.reserve"); - - auto dd = mDomain > 2 ? roundUpTo((mDomain + 1) / 2, 2) : 1; - pprf::allocateExpandTree(dd, mTempBuffer, levels); - assert(levels.size() == mDepth); - - - if (!mEagerSend) - { - // we need to allocate one large buffer that will store all OT messages. - pprf::allocateExpandBuffer( - mDepth - 1, mPntCount, programPuncturedPoint, buff, encSums, leafMsgs, ctx); - encStepSize = encSums.size() / mPntCount; - leafStepSize = leafMsgs.size() / mPntCount; - encOffset = 0; - leafOffset = 0; - - co_await (chl.recv(buff)); - } - - for (treeIndex = 0ull; treeIndex < mPntCount; treeIndex += 8) - { - // for interleaved format, the leaf level of the tree - // is simply the output. - if (oFormat == PprfOutputFormat::Interleaved) - { - leafIndex = treeIndex * mDomain; - leafLevelPtr = &output; - } - else - { - // we will use leaf level as a buffer before - // copying the result to the output. - leafIndex = 0; - ctx.resize(leafLevel, mDomain * 8); - leafLevelPtr = &leafLevel; - } - - auto min = std::min(8, mPntCount - treeIndex); - if (mEagerSend) - { - - // allocate the send buffer and partition it. - pprf::allocateExpandBuffer(mDepth - 1, min, - programPuncturedPoint, buff, encSums, leafMsgs, ctx); - encStepSize = encSums.size() / min; - leafStepSize = leafMsgs.size() / min; - encOffset = 0; - leafOffset = 0; - co_await (chl.recv(buff)); - } - - // exapnd the tree - expandOne( - treeIndex, - programPuncturedPoint, - levels, - *leafLevelPtr, - leafIndex, - encSums.subspan(encOffset, encStepSize * min), - leafMsgs.subspan(leafOffset, leafStepSize * min), - points, - ctx); - - encOffset += encStepSize * min; - leafOffset += leafStepSize * min; - - // if we aren't interleaved, we need to copy the - // leaf layer to the output. - if (oFormat != PprfOutputFormat::Interleaved) - pprf::copyOut(leafLevel, output, mPntCount, treeIndex, oFormat, mOutputFn); - } - - setTimePoint("SilentMultiPprfReceiver.join"); - - mBaseOTs = {}; - - setTimePoint("SilentMultiPprfReceiver.de-alloc"); - } - - void clear() - { - mBaseOTs.resize(0, 0); - mBaseChoices.resize(0, 0); - mDomain = 0; - mDepth = 0; - mPntCount = 0; - } - - void expandOne( - u64 treeIdx, - bool programPuncturedPoint, - span>> levels, - VecF& leafLevel, - const u64 outputOffset, - span> theirSums, - span leafMsg, - span points, - CoeffCtx& ctx) - { - auto remTrees = std::min(8, mPntCount - treeIdx); - assert(theirSums.size() == remTrees * (mDepth - 1)); - - // We change the hash function for the leaf so lets update - // inactiveChildValues to use the new hash and subtract - // these from the leafSums - std::array leafSums; - if (mDepth > 1) - { - auto theirSumsIter = theirSums.begin(); - - // special case for the first level. - auto l1 = levels[1]; - for (u64 i = 0; i < remTrees; ++i) - { - // For the non-active path, set the child of the root node - // as the OT message XOR'ed with the correction sum. - - int active = mBaseChoices[i + treeIdx].back(); - l1[active ^ 1][i] = mBaseOTs[i + treeIdx].back() ^ (*theirSumsIter)[active ^ 1]; - l1[active][i] = ZeroBlock; - ++theirSumsIter; - //if (!i) - // std::cout << " unmask " - // << mBaseOTs[i + treeIdx].back() << " ^ " - // << theirSums[0][active ^ 1][i] << " = " - // << l1[active ^ 1][i] << std::endl; - - } - - // space for our sums of each level. - std::array, 2> mySums; - - // this will be the value of both children of active an parent - // before the active child is updated. We will need to subtract - // this value as the main loop does not distinguish active parents. - std::array inactiveChildValues; - inactiveChildValues[0] = AES::roundEnc(mAesFixedKey.ecbEncBlock(ZeroBlock), ZeroBlock); - inactiveChildValues[1] = mAesFixedKey.ecbEncBlock(ZeroBlock); - - // For all other levels, expand the GGM tree and add in - // the correction along the active path. - for (u64 d = 1; d < mDepth - 1; ++d) - { - // initialized the sums with inactiveChildValue so that - // it will cancel when we expand the actual inactive child. - std::fill(mySums[0].begin(), mySums[0].end(), inactiveChildValues[0]); - std::fill(mySums[1].begin(), mySums[1].end(), inactiveChildValues[1]); - - // We will iterate over each node on this level and - // expand it into it's two children. Note that the - // active node will also be expanded. Later we will just - // overwrite whatever the value was. This is an optimization. - auto width = divCeil(mDomain, 1ull << (mDepth - d)); - - // The already constructed level. Only missing the - // GGM tree node value along the active path. - auto level0 = levels[d]; - assert(level0.size() == width || level0.size() == width + 1); - - // The next level that we want to construct. - auto level1 = levels[d + 1]; - assert(level1.size() == width * 2); - - for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx, childIdx += 2) - { - // The value of the parent. - auto parent = level0[parentIdx]; - - auto& child0 = level1.data()[childIdx]; - auto& child1 = level1.data()[childIdx + 1]; - mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); - - // inspired by the Expand Accumualte idea to - // use - // - // child0 = AES(parent) ^ parent - // child1 = AES(parent) + parent - // - // but instead we are a bit more conservative and - // compute - // - // child0 = AES:Round(AES(parent), parent) - // = AES:Round(AES(parent), 0) ^ parent - // child1 = AES(parent) + parent - // - // That is, we applies an additional AES round function - // to the first child before XORing it with parent. - child0[0] = AES::roundEnc(child1[0], parent[0]); - child0[1] = AES::roundEnc(child1[1], parent[1]); - child0[2] = AES::roundEnc(child1[2], parent[2]); - child0[3] = AES::roundEnc(child1[3], parent[3]); - child0[4] = AES::roundEnc(child1[4], parent[4]); - child0[5] = AES::roundEnc(child1[5], parent[5]); - child0[6] = AES::roundEnc(child1[6], parent[6]); - child0[7] = AES::roundEnc(child1[7], parent[7]); - - // Update the running sums for this level. We keep - // a left and right totals for each level. Note that - // we are actually XOR in the incorrect value of the - // children of the active parent but this will cancel - // with inactiveChildValue thats already there. - mySums[0][0] = mySums[0][0] ^ child0[0]; - mySums[0][1] = mySums[0][1] ^ child0[1]; - mySums[0][2] = mySums[0][2] ^ child0[2]; - mySums[0][3] = mySums[0][3] ^ child0[3]; - mySums[0][4] = mySums[0][4] ^ child0[4]; - mySums[0][5] = mySums[0][5] ^ child0[5]; - mySums[0][6] = mySums[0][6] ^ child0[6]; - mySums[0][7] = mySums[0][7] ^ child0[7]; - - // child1 = AES(parent) + parent - child1[0] = child1[0] + parent[0]; - child1[1] = child1[1] + parent[1]; - child1[2] = child1[2] + parent[2]; - child1[3] = child1[3] + parent[3]; - child1[4] = child1[4] + parent[4]; - child1[5] = child1[5] + parent[5]; - child1[6] = child1[6] + parent[6]; - child1[7] = child1[7] + parent[7]; - - mySums[1][0] = mySums[1][0] ^ child1[0]; - mySums[1][1] = mySums[1][1] ^ child1[1]; - mySums[1][2] = mySums[1][2] ^ child1[2]; - mySums[1][3] = mySums[1][3] ^ child1[3]; - mySums[1][4] = mySums[1][4] ^ child1[4]; - mySums[1][5] = mySums[1][5] ^ child1[5]; - mySums[1][6] = mySums[1][6] ^ child1[6]; - mySums[1][7] = mySums[1][7] ^ child1[7]; - - } - - - // we have to update the non-active child of the active parent. - for (u64 i = 0; i < remTrees; ++i) - { - // the index of the leaf node that is active. - auto leafIdx = points[i + treeIdx]; - - // The index of the active (missing) child node. - auto missingChildIdx = leafIdx >> (mDepth - 1 - d); - - // The index of the active child node sibling. - auto siblingIdx = missingChildIdx ^ 1; - - // The indicator as to the left or right child is inactive - auto notAi = siblingIdx & 1; - - // our sums & OTs cancel and we are leaf with the - // correct value for the inactive child. - level1[siblingIdx][i] = - (*theirSumsIter)[notAi] ^ - mySums[notAi][i] ^ - mBaseOTs(i + treeIdx, mDepth - 1 - d); - - ++theirSumsIter; - - // we have to set the active child to zero so - // the next children are predictable. - level1[missingChildIdx][i] = ZeroBlock; - } - } - - auto d = mDepth - 1; - // The already constructed level. Only missing the - // GGM tree node value along the active path. - auto level0 = levels[d]; - - // The next level of theGGM tree that we are populating. - std::array child; - - // We will iterate over each node on this level and - // expand it into it's two children. Note that the - // active node will also be expanded. Later we will just - // overwrite whatever the value was. This is an optimization. - auto width = divCeil(mDomain, 1ull << (mDepth - d)); - - VecF temp; - ctx.resize(temp, 2); - for (u64 k = 0; k < 2; ++k) - { - ctx.resize(leafSums[k], 8); - ctx.zero(leafSums[k].begin(), leafSums[k].end()); - ctx.fromBlock(temp[k], gGgmAes[k].hashBlock(ZeroBlock)); - ctx.minus(leafSums[k][0], leafSums[k][0], temp[k]); - for (u64 i = 1; i < 8; ++i) - ctx.copy(leafSums[k][i], leafSums[k][0]); - } - - auto outIter = leafLevel.data() + outputOffset; - // for leaf nodes both children should be hashed. - for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) - { - // The value of the parent. - auto parent = level0.data()[parentIdx]; - - for (u64 keep = 0; keep < 2; ++keep, ++childIdx) - { - if constexpr (std::is_same_v) - { - gGgmAes.data()[keep].hashBlocks<8>(parent.data(), outIter); - } - else - { - // Each parent is expanded into the left and right children - // using a different AES fixed-key. Therefore our OWF is: - // - // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); - // - // where each half defines one of the children. - gGgmAes.data()[keep].hashBlocks<8>(parent.data(), child.data()); - - ctx.fromBlock(*(outIter + 0), child.data()[0]); - ctx.fromBlock(*(outIter + 1), child.data()[1]); - ctx.fromBlock(*(outIter + 2), child.data()[2]); - ctx.fromBlock(*(outIter + 3), child.data()[3]); - ctx.fromBlock(*(outIter + 4), child.data()[4]); - ctx.fromBlock(*(outIter + 5), child.data()[5]); - ctx.fromBlock(*(outIter + 6), child.data()[6]); - ctx.fromBlock(*(outIter + 7), child.data()[7]); - } - auto& leafSum = leafSums[keep]; - ctx.plus(leafSum.data()[0], leafSum.data()[0], *(outIter + 0)); - ctx.plus(leafSum.data()[1], leafSum.data()[1], *(outIter + 1)); - ctx.plus(leafSum.data()[2], leafSum.data()[2], *(outIter + 2)); - ctx.plus(leafSum.data()[3], leafSum.data()[3], *(outIter + 3)); - ctx.plus(leafSum.data()[4], leafSum.data()[4], *(outIter + 4)); - ctx.plus(leafSum.data()[5], leafSum.data()[5], *(outIter + 5)); - ctx.plus(leafSum.data()[6], leafSum.data()[6], *(outIter + 6)); - ctx.plus(leafSum.data()[7], leafSum.data()[7], *(outIter + 7)); - - outIter += 8; - assert(outIter <= leafLevel.data() + leafLevel.size()); - } - } - } - else - { - for (u64 k = 0; k < 2; ++k) - { - ctx.resize(leafSums[k], 8); - ctx.zero(leafSums[k].begin(), leafSums[k].end()); - } - } - - // leaf level. - if (programPuncturedPoint) - { - // Now processes the leaf level. This one is special - // because we must XOR in the correction value as - // before but we must also fixed the child value for - // the active child. To do this, we will receive 4 - // values. Two for each case (left active or right active). - //timer.setTimePoint("recv.recvleaf"); - VecF leafOts; - ctx.resize(leafOts, 2); - PRNG otMasker; - - for (u64 j = 0; j < remTrees; ++j) - { - - // The index of the child on the active path. - auto activeChildIdx = points[j + treeIdx]; - - // The index of the other (inactive) child. - auto inactiveChildIdx = activeChildIdx ^ 1; - - // The indicator as to the left or right child is inactive - auto notAi = inactiveChildIdx & 1; - - // offset to the first or second ot message, based on the one we want - auto offset = ctx.template byteSize() * 2 * notAi; - - - // decrypt the ot string - span buff = leafMsg.subspan(offset, ctx.template byteSize() * 2); - leafMsg = leafMsg.subspan(buff.size() * 2); - otMasker.SetSeed(mBaseOTs[j + treeIdx][0], divCeil(buff.size(), sizeof(block))); - for (u64 i = 0; i < buff.size(); ++i) - buff[i] ^= otMasker.get(); - - ctx.deserialize(buff.begin(), buff.end(), leafOts.begin()); - - auto out0 = (activeChildIdx & ~1ull) * 8 + j + outputOffset; - auto out1 = (activeChildIdx | 1ull) * 8 + j + outputOffset; - - ctx.minus(leafLevel[out0], leafOts[0], leafSums[0][j]); - ctx.minus(leafLevel[out1], leafOts[1], leafSums[1][j]); - } - } - else - { - VecF leafOts; - ctx.resize(leafOts, 1); - PRNG otMasker; - - for (u64 j = 0; j < remTrees; ++j) - { - // The index of the child on the active path. - auto activeChildIdx = points[j + treeIdx]; - - // The index of the other (inactive) child. - auto inactiveChildIdx = activeChildIdx ^ 1; - - // The indicator as to the left or right child is inactive - auto notAi = inactiveChildIdx & 1; - - // offset to the first or second ot message, based on the one we want - auto offset = ctx.template byteSize() * notAi; - - // decrypt the ot string - span buff = leafMsg.subspan(offset, ctx.template byteSize()); - leafMsg = leafMsg.subspan(buff.size() * 2); - otMasker.SetSeed(mBaseOTs[j + treeIdx][0], divCeil(buff.size(), sizeof(block))); - for (u64 i = 0; i < buff.size(); ++i) - buff[i] ^= otMasker.get(); - - ctx.deserialize(buff.begin(), buff.end(), leafOts.begin()); - - std::array out{ - (activeChildIdx & ~1ull) * 8 + j + outputOffset, - (activeChildIdx | 1ull) * 8 + j + outputOffset - }; - - auto keep = leafLevel.begin() + out[notAi]; - auto zero = leafLevel.begin() + out[notAi ^ 1]; - - ctx.minus(*keep, leafOts[0], leafSums[notAi][j]); - ctx.zero(zero, zero + 1); - } - } - } - }; + // the base OTs that should be set. + Matrix> mBaseOTs; + + // if true, tree OT messages are eagerly sent in batches of 8. + // otherwise, the OT messages are sent in a single batch. + bool mEagerSend = true; + + using VecF = typename CoeffCtx::template Vec; + using VecG = typename CoeffCtx::template Vec; + + // a function that can be used to output the result of the PPRF. + std::function mOutputFn; + + // an internal buffer that is used to expand the tree. + AlignedUnVector mTempBuffer; + + RegularPprfSender() = default; + + RegularPprfSender(const RegularPprfSender&) = delete; + + RegularPprfSender(RegularPprfSender&&) = delete; + + RegularPprfSender(u64 domainSize, u64 pointCount) { + configure(domainSize, pointCount); + } + + void configure(u64 domainSize, u64 pointCount) + { + if (domainSize & 1) + throw std::runtime_error("Pprf domain must be even. " LOCATION); + if (domainSize < 2) + throw std::runtime_error("Pprf domain must must be at least 2. " LOCATION); + + mDomain = domainSize; + mDepth = log2ceil(mDomain); + mPntCount = pointCount; + + mBaseOTs.resize(0, 0); + } + + + // the number of base OTs that should be set. + u64 baseOtCount() const { + return mDepth * mPntCount; + } + + // returns true if the base OTs are currently set. + bool hasBaseOts() const { + return mBaseOTs.size(); + } + + + void setBase(span> baseMessages) { + if (baseOtCount() != static_cast(baseMessages.size())) + throw RTE_LOC; + + mBaseOTs.resize(mPntCount, mDepth); + for (u64 i = 0; i < static_cast(mBaseOTs.size()); ++i) + mBaseOTs(i) = baseMessages[i]; + } + + task<> expand( + Socket& chl, + const VecF& value, + block seed, + VecF& output, + PprfOutputFormat oFormat, + bool programPuncturedPoint, + u64 numThreads, + CoeffCtx ctx = {}) + try { + if (programPuncturedPoint) + setValue(value); + + setTimePoint("SilentMultiPprfSender.start"); + + pprf::validateExpandFormat(oFormat, output, mDomain, mPntCount); + + //auto tree = span>{}; + auto levels = std::vector> >{}; + auto leafIndex = u64{}; + auto leafLevelPtr = (VecF*)nullptr; + auto leafLevel = VecF{}; + auto buff = std::vector{}; + auto encSums = span>{}; + auto leafMsgs = span{}; + auto encStepSize = u64{}; + auto leafStepSize = u64{}; + auto encOffset = u64{}; + auto leafOffset = u64{}; + + auto dd = mDomain > 2 ? roundUpTo((mDomain + 1) / 2, 2) : 1; + pprf::allocateExpandTree(dd, mTempBuffer, levels); + assert(levels.size() == mDepth); + + if (!mEagerSend) + { + // we need to allocate one large buffer that will store all OT messages. + pprf::allocateExpandBuffer( + mDepth - 1, mPntCount, programPuncturedPoint, buff, encSums, leafMsgs, ctx); + encStepSize = encSums.size() / mPntCount; + leafStepSize = leafMsgs.size() / mPntCount; + encOffset = 0; + leafOffset = 0; + } + + for (auto treeIndex = 0ull; treeIndex < mPntCount; treeIndex += 8) + { + // for interleaved format, the leaf level of the tree + // is simply the output. + if (oFormat == PprfOutputFormat::Interleaved) + { + leafIndex = treeIndex * mDomain; + leafLevelPtr = &output; + } + else + { + // we will use leaf level as a buffer before + // copying the result to the output. + leafIndex = 0; + ctx.resize(leafLevel, mDomain * 8); + leafLevelPtr = &leafLevel; + } + + auto min = std::min(8, mPntCount - treeIndex); + if (mEagerSend) + { + // allocate a send buffer for the next 8 trees. + pprf::allocateExpandBuffer( + mDepth - 1, min, programPuncturedPoint, buff, encSums, leafMsgs, ctx); + encStepSize = encSums.size() / min; + leafStepSize = leafMsgs.size() / min; + encOffset = 0; + leafOffset = 0; + } + + // exapnd the tree + expandOne( + seed, + treeIndex, + programPuncturedPoint, + levels, + *leafLevelPtr, + leafIndex, + encSums.subspan(encOffset, encStepSize * min), + leafMsgs.subspan(leafOffset, leafStepSize * min), + ctx); + + encOffset += encStepSize * min; + leafOffset += leafStepSize * min; + + if (mEagerSend) + { + // send the buffer for the current set of trees. + co_await(chl.send(std::move(buff))); + } + + // if we aren't interleaved, we need to copy the + // leaf layer to the output. + if (oFormat != PprfOutputFormat::Interleaved) + pprf::copyOut(leafLevel, output, mPntCount, treeIndex, oFormat, mOutputFn); + + } + + + if (!mEagerSend) + { + // send the buffer for all of the trees. + co_await(chl.send(std::move(buff))); + } + + mBaseOTs = {}; + + setTimePoint("SilentMultiPprfSender.de-alloc"); + } + catch (...) + { + chl.close(); + throw; + } + + void setValue(span value) { + + mValue.resize(mPntCount); + + if (value.size() == 1) { + std::fill(mValue.begin(), mValue.end(), value[0]); + } + else { + if ((u64)value.size() != mPntCount) + throw RTE_LOC; + + std::copy(value.begin(), value.end(), mValue.begin()); + } + } + + void clear() { + mBaseOTs.resize(0, 0); + mDomain = 0; + mDepth = 0; + mPntCount = 0; + } + + void expandOne( + block aesSeed, + u64 treeIdx, + bool programPuncturedPoint, + span>> levels, + VecF& leafLevel, + const u64 leafOffset, + span> encSums, + span leafMsgs, + CoeffCtx ctx) + { + auto remTrees = std::min(8, mPntCount - treeIdx); + + // the first level should be size 1, the root of the tree. + // we will populate it with random seeds using aesSeed in counter mode + // based on the tree index. + assert(levels[0].size() == 1); + mAesFixedKey.ecbEncCounterMode(aesSeed ^ block(treeIdx), levels[0][0]); + + assert(encSums.size() == (mDepth - 1) * remTrees); + auto encSumIter = encSums.begin(); + + // space for our sums of each level. Should always be less then + // 24 levels... If not increase the limit or make it a vector. + std::array, 2> sums; + + // use the optimized approach for intern nodes of the tree + // For each level perform the following. + for (u64 d = 0; d < mDepth - 1; ++d) + { + // clear the sums + memset(&sums, 0, sizeof(sums)); + + // The total number of parents in this level. + auto width = divCeil(mDomain, 1ull << (mDepth - d)); + + // The previous level of the GGM tree. + auto parents = levels[d]; + + // The next level of theGGM tree that we are populating. + auto children = levels[d + 1]; + + // For each child, populate the child by expanding the parent. + for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx, childIdx += 2) + { + // The value of the parent. + auto& parent = parents.data()[parentIdx]; + + auto& child0 = children.data()[childIdx]; + auto& child1 = children.data()[childIdx + 1]; + mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); + + // inspired by the Expand Accumualte idea to + // use + // + // child0 = AES(parent) ^ parent + // child1 = AES(parent) + parent + // + // but instead we are a bit more conservative and + // compute + // + // child0 = AES:Round(AES(parent), parent) + // = AES:Round(AES(parent), 0) ^ parent + // child1 = AES(parent) + parent + // + // That is, we applies an additional AES round function + // to the first child before XORing it with parent. + child0[0] = AES::roundEnc(child1[0], parent[0]); + child0[1] = AES::roundEnc(child1[1], parent[1]); + child0[2] = AES::roundEnc(child1[2], parent[2]); + child0[3] = AES::roundEnc(child1[3], parent[3]); + child0[4] = AES::roundEnc(child1[4], parent[4]); + child0[5] = AES::roundEnc(child1[5], parent[5]); + child0[6] = AES::roundEnc(child1[6], parent[6]); + child0[7] = AES::roundEnc(child1[7], parent[7]); + + // Update the running sums for this level. We keep + // a left and right totals for each level. + sums[0][0] = sums[0][0] ^ child0[0]; + sums[0][1] = sums[0][1] ^ child0[1]; + sums[0][2] = sums[0][2] ^ child0[2]; + sums[0][3] = sums[0][3] ^ child0[3]; + sums[0][4] = sums[0][4] ^ child0[4]; + sums[0][5] = sums[0][5] ^ child0[5]; + sums[0][6] = sums[0][6] ^ child0[6]; + sums[0][7] = sums[0][7] ^ child0[7]; + + // child1 = AES(parent) + parent + child1[0] = child1[0] + parent[0]; + child1[1] = child1[1] + parent[1]; + child1[2] = child1[2] + parent[2]; + child1[3] = child1[3] + parent[3]; + child1[4] = child1[4] + parent[4]; + child1[5] = child1[5] + parent[5]; + child1[6] = child1[6] + parent[6]; + child1[7] = child1[7] + parent[7]; + + sums[1][0] = sums[1][0] ^ child1[0]; + sums[1][1] = sums[1][1] ^ child1[1]; + sums[1][2] = sums[1][2] ^ child1[2]; + sums[1][3] = sums[1][3] ^ child1[3]; + sums[1][4] = sums[1][4] ^ child1[4]; + sums[1][5] = sums[1][5] ^ child1[5]; + sums[1][6] = sums[1][6] ^ child1[6]; + sums[1][7] = sums[1][7] ^ child1[7]; + + } + + // encrypt the sums and write them to the output. + for (u64 j = 0; j < remTrees; ++j) + { + (*encSumIter)[0] = sums[0][j] ^ mBaseOTs(treeIdx + j, mDepth - 1 - d)[1]; + (*encSumIter)[1] = sums[1][j] ^ mBaseOTs(treeIdx + j, mDepth - 1 - d)[0]; + ++encSumIter; + } + } + assert(encSumIter == encSums.end()); + + auto d = mDepth - 1; + + // The previous level of the GGM tree. + auto level0 = levels[d]; + + // The total number of parents in this level. + auto width = divCeil(mDomain, 1ull << (mDepth - d)); + + // The next level of theGGM tree that we are populating. + std::array child; + + // clear the sums + std::array leafSums; + ctx.resize(leafSums[0], 8); + ctx.resize(leafSums[1], 8); + ctx.zero(leafSums[0].begin(), leafSums[0].end()); + ctx.zero(leafSums[1].begin(), leafSums[1].end()); + + auto outIter = leafLevel.data() + leafOffset; + + // for the leaf nodes we need to hash both children. + for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) + { + // The value of the parent. + auto& parent = level0.data()[parentIdx]; + + // The bit that indicates if we are on the left child (0) + // or on the right child (1). + for (u64 keep = 0; keep < 2; ++keep, ++childIdx) + { + // The child that we will write in this iteration. + + if constexpr (std::is_same_v) + { + gGgmAes.data()[keep].hashBlocks<8>(parent.data(), outIter); + } + else + { + // Each parent is expanded into the left and right children + // using a different AES fixed-key. Therefore our OWF is: + // + // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); + // + // where each half defines one of the children. + gGgmAes.data()[keep].hashBlocks<8>(parent.data(), child.data()); + + ctx.fromBlock(*(outIter + 0), child.data()[0]); + ctx.fromBlock(*(outIter + 1), child.data()[1]); + ctx.fromBlock(*(outIter + 2), child.data()[2]); + ctx.fromBlock(*(outIter + 3), child.data()[3]); + ctx.fromBlock(*(outIter + 4), child.data()[4]); + ctx.fromBlock(*(outIter + 5), child.data()[5]); + ctx.fromBlock(*(outIter + 6), child.data()[6]); + ctx.fromBlock(*(outIter + 7), child.data()[7]); + } + + // leafSum += child + auto& leafSum = leafSums[keep]; + ctx.plus(leafSum.data()[0], leafSum.data()[0], *(outIter + 0)); + ctx.plus(leafSum.data()[1], leafSum.data()[1], *(outIter + 1)); + ctx.plus(leafSum.data()[2], leafSum.data()[2], *(outIter + 2)); + ctx.plus(leafSum.data()[3], leafSum.data()[3], *(outIter + 3)); + ctx.plus(leafSum.data()[4], leafSum.data()[4], *(outIter + 4)); + ctx.plus(leafSum.data()[5], leafSum.data()[5], *(outIter + 5)); + ctx.plus(leafSum.data()[6], leafSum.data()[6], *(outIter + 6)); + ctx.plus(leafSum.data()[7], leafSum.data()[7], *(outIter + 7)); + + outIter += 8; + assert(outIter <= leafLevel.data() + leafLevel.size()); + } + + } + + if (programPuncturedPoint) + { + // For the leaf level, we are going to do something special. + // The other party is currently missing both leaf children of + // the active parent. Since this is the leaf level, we want + // the inactive child to just be the normal value but the + // active child should be the correct value XOR the delta. + // This will be done by sending the sums and the sums plus + // delta and ensure that they can only decrypt the correct ones. + VecF leafOts; + ctx.resize(leafOts, 2); + PRNG otMasker; + + for (u64 j = 0; j < remTrees; ++j) + { + // we will construct two OT strings. Let + // s0, s1 be the left and right child sums. + // + // m0 = (s0 , s1 + val) + // m1 = (s0 + val, s1 ) + // + // these will be encrypted by the OT keys + for (u64 k = 0; k < 2; ++k) + { + if (k == 0) + { + // m0 = (s0, s1 + val) + ctx.copy(leafOts[0], leafSums[0][j]); + ctx.plus(leafOts[1], leafSums[1][j], mValue[treeIdx + j]); + } + else + { + // m1 = (s0+val, s1) + ctx.plus(leafOts[0], leafSums[0][j], mValue[treeIdx + j]); + ctx.copy(leafOts[1], leafSums[1][j]); + } + + // copy m0 into the output buffer. + span buff = leafMsgs.subspan(0, 2 * ctx.template byteSize()); + leafMsgs = leafMsgs.subspan(buff.size()); + ctx.serialize(leafOts.begin(), leafOts.end(), buff.begin()); + + // encrypt the output buffer. + otMasker.SetSeed(mBaseOTs[treeIdx + j][0][1 ^ k], divCeil(buff.size(), sizeof(block))); + for (u64 i = 0; i < buff.size(); ++i) + buff[i] ^= otMasker.get(); + + } + } + } + else + { + VecF leafOts; + ctx.resize(leafOts, 1); + PRNG otMasker; + + for (u64 j = 0; j < remTrees; ++j) + { + for (u64 k = 0; k < 2; ++k) + { + // copy the sum k into the output buffer. + ctx.copy(leafOts[0], leafSums[k][j]); + span buff = leafMsgs.subspan(0, ctx.template byteSize()); + leafMsgs = leafMsgs.subspan(buff.size()); + ctx.serialize(leafOts.begin(), leafOts.end(), buff.begin()); + + // encrypt the output buffer. + otMasker.SetSeed(mBaseOTs[treeIdx + j][0][1 ^ k], divCeil(buff.size(), sizeof(block))); + for (u64 i = 0; i < buff.size(); ++i) + buff[i] ^= otMasker.get(); + + } + } + } + + assert(leafMsgs.size() == 0); + } + + + }; + + + template< + typename F, + typename G = F, + typename CoeffCtx = DefaultCoeffCtx + > + class RegularPprfReceiver : public TimerAdapter + { + public: + + // the number of leaves in a single tree. + u64 mDomain = 0; + + // the depth of each tree. + u64 mDepth = 0; + + // the number of trees, must be a multiple of 8. + u64 mPntCount = 0; + + using VecF = typename CoeffCtx::template Vec; + using VecG = typename CoeffCtx::template Vec; + + // base ots that will be used to expand the tree. + Matrix mBaseOTs; + + // the choice bits, each row should be the bit decomposition of the active path. + Matrix mBaseChoices; + + // if true, tree OT messages are eagerly sent in batches of 8. + // otherwise, the OT messages are sent in a single batch. + bool mEagerSend = true; + + // a function that can be used to output the result of the PPRF. + std::function mOutputFn; + + // an internal buffer that is used to expand the tree. + AlignedUnVector mTempBuffer; + + RegularPprfReceiver() = default; + RegularPprfReceiver(const RegularPprfReceiver&) = delete; + RegularPprfReceiver(RegularPprfReceiver&&) = delete; + + void configure(u64 domainSize, u64 pointCount) + { + if (domainSize & 1) + throw std::runtime_error("Pprf domain must be even. " LOCATION); + if (domainSize < 2) + throw std::runtime_error("Pprf domain must must be at least 2. " LOCATION); + + mDomain = domainSize; + mDepth = log2ceil(mDomain); + mPntCount = pointCount; + + mBaseOTs.resize(0, 0); + } + + + // this function sample mPntCount integers in the range + // [0,domain) and returns these as the choice bits. + BitVector sampleChoiceBits(PRNG& prng) + { + BitVector choices(mPntCount * mDepth); + + // The points are read in blocks of 8, so make sure that there is a + // whole number of blocks. + mBaseChoices.resize(mPntCount, mDepth); + for (u64 i = 0; i < mPntCount; ++i) + { + u64 idx = prng.get() % mDomain; + for (u64 j = 0; j < mDepth; ++j) + mBaseChoices(i, j) = *BitIterator((u8*)&idx, j); + } + + for (u64 i = 0; i < mBaseChoices.size(); ++i) + { + choices[i] = mBaseChoices(i); + } + + return choices; + } + + // choices is in the same format as the output from sampleChoiceBits. + void setChoiceBits(const BitVector& choices) + { + // Make sure we're given the right number of OTs. + if (choices.size() != baseOtCount()) + throw RTE_LOC; + + mBaseChoices.resize(mPntCount, mDepth); + for (u64 i = 0; i < mPntCount; ++i) + { + u64 idx = 0; + for (u64 j = 0; j < mDepth; ++j) + { + mBaseChoices(i, j) = choices[mDepth * i + j]; + idx |= u64(choices[mDepth * i + j]) << j; + } + + if (idx >= mDomain) + throw std::runtime_error("provided choice bits index outside of the domain." LOCATION); + } + } + + + // the number of base OTs that should be set. + u64 baseOtCount() const + { + return mDepth * mPntCount; + } + + // returns true if the base OTs are currently set. + bool hasBaseOts() const + { + return mBaseOTs.size(); + } + + + void setBase(span baseMessages) + { + if (baseOtCount() != static_cast(baseMessages.size())) + throw RTE_LOC; + + // The OTs are used in blocks of 8, so make sure that there is a whole + // number of blocks. + mBaseOTs.resize(roundUpTo(mPntCount, 8), mDepth); + memcpy(mBaseOTs.data(), baseMessages.data(), baseMessages.size() * sizeof(block)); + } + + std::vector getPoints(PprfOutputFormat format) + { + std::vector pnts(mPntCount); + getPoints(pnts, format); + return pnts; + } + void getPoints(span points, PprfOutputFormat format) + { + if ((u64)points.size() != mPntCount) + throw RTE_LOC; + + switch (format) + { + case PprfOutputFormat::ByLeafIndex: + case PprfOutputFormat::ByTreeIndex: + + memset(points.data(), 0, points.size() * sizeof(u64)); + for (u64 j = 0; j < mPntCount; ++j) + { + for (u64 k = 0; k < mDepth; ++k) + points[j] |= u64(mBaseChoices(j, k)) << k; + + assert(points[j] < mDomain); + } + + + break; + case PprfOutputFormat::Interleaved: + case PprfOutputFormat::Callback: + + getPoints(points, PprfOutputFormat::ByLeafIndex); + + // in interleaved mode we generate 8 trees in a batch. + // the i'th leaf of these 8 trees are next to eachother. + for (u64 j = 0; j < points.size(); ++j) + { + auto subTree = j % 8; + auto batch = j / 8; + points[j] = (batch * mDomain + points[j]) * 8 + subTree; + } + + //interleavedPoints(points, mDomain, format); + + break; + default: + throw RTE_LOC; + break; + } + } + + // programPuncturedPoint says whether the sender is trying to program the + // active child to be its correct value XOR delta. If it is not, the + // active child will just take a random value. + task<> expand( + Socket& chl, + VecF& output, + PprfOutputFormat oFormat, + bool programPuncturedPoint, + u64 numThreads, + CoeffCtx ctx = {}) + try + { + pprf::validateExpandFormat(oFormat, output, mDomain, mPntCount); + + auto treeIndex = u64{}; + auto levels = std::vector>>{}; + auto leafIndex = u64{}; + auto leafLevelPtr = (VecF*)nullptr; + auto leafLevel = VecF{}; + auto buff = std::vector{}; + auto encSums = span>{}; + auto leafMsgs = span{}; + auto points = std::vector{}; + auto encStepSize = u64{}; + auto leafStepSize = u64{}; + auto encOffset = u64{}; + auto leafOffset = u64{}; + + setTimePoint("SilentMultiPprfReceiver.start"); + points.resize(mPntCount); + getPoints(points, PprfOutputFormat::ByLeafIndex); + + //setTimePoint("SilentMultiPprfSender.reserve"); + + auto dd = mDomain > 2 ? roundUpTo((mDomain + 1) / 2, 2) : 1; + pprf::allocateExpandTree(dd, mTempBuffer, levels); + assert(levels.size() == mDepth); + + + if (!mEagerSend) + { + // we need to allocate one large buffer that will store all OT messages. + pprf::allocateExpandBuffer( + mDepth - 1, mPntCount, programPuncturedPoint, buff, encSums, leafMsgs, ctx); + encStepSize = encSums.size() / mPntCount; + leafStepSize = leafMsgs.size() / mPntCount; + encOffset = 0; + leafOffset = 0; + + co_await(chl.recv(buff)); + } + + for (treeIndex = 0ull; treeIndex < mPntCount; treeIndex += 8) + { + // for interleaved format, the leaf level of the tree + // is simply the output. + if (oFormat == PprfOutputFormat::Interleaved) + { + leafIndex = treeIndex * mDomain; + leafLevelPtr = &output; + } + else + { + // we will use leaf level as a buffer before + // copying the result to the output. + leafIndex = 0; + ctx.resize(leafLevel, mDomain * 8); + leafLevelPtr = &leafLevel; + } + + auto min = std::min(8, mPntCount - treeIndex); + if (mEagerSend) + { + + // allocate the send buffer and partition it. + pprf::allocateExpandBuffer(mDepth - 1, min, + programPuncturedPoint, buff, encSums, leafMsgs, ctx); + encStepSize = encSums.size() / min; + leafStepSize = leafMsgs.size() / min; + encOffset = 0; + leafOffset = 0; + co_await(chl.recv(buff)); + } + + // exapnd the tree + expandOne( + treeIndex, + programPuncturedPoint, + levels, + *leafLevelPtr, + leafIndex, + encSums.subspan(encOffset, encStepSize * min), + leafMsgs.subspan(leafOffset, leafStepSize * min), + points, + ctx); + + encOffset += encStepSize * min; + leafOffset += leafStepSize * min; + + // if we aren't interleaved, we need to copy the + // leaf layer to the output. + if (oFormat != PprfOutputFormat::Interleaved) + pprf::copyOut(leafLevel, output, mPntCount, treeIndex, oFormat, mOutputFn); + } + + setTimePoint("SilentMultiPprfReceiver.join"); + + mBaseOTs = {}; + + setTimePoint("SilentMultiPprfReceiver.de-alloc"); + } + catch (...) + { + chl.close(); + throw; + } + + void clear() + { + mBaseOTs.resize(0, 0); + mBaseChoices.resize(0, 0); + mDomain = 0; + mDepth = 0; + mPntCount = 0; + } + + void expandOne( + u64 treeIdx, + bool programPuncturedPoint, + span>> levels, + VecF& leafLevel, + const u64 outputOffset, + span> theirSums, + span leafMsg, + span points, + CoeffCtx& ctx) + { + auto remTrees = std::min(8, mPntCount - treeIdx); + assert(theirSums.size() == remTrees * (mDepth - 1)); + + // We change the hash function for the leaf so lets update + // inactiveChildValues to use the new hash and subtract + // these from the leafSums + std::array leafSums; + if (mDepth > 1) + { + auto theirSumsIter = theirSums.begin(); + + // special case for the first level. + auto l1 = levels[1]; + for (u64 i = 0; i < remTrees; ++i) + { + // For the non-active path, set the child of the root node + // as the OT message XOR'ed with the correction sum. + + int active = mBaseChoices[i + treeIdx].back(); + l1[active ^ 1][i] = mBaseOTs[i + treeIdx].back() ^ (*theirSumsIter)[active ^ 1]; + l1[active][i] = ZeroBlock; + ++theirSumsIter; + //if (!i) + // std::cout << " unmask " + // << mBaseOTs[i + treeIdx].back() << " ^ " + // << theirSums[0][active ^ 1][i] << " = " + // << l1[active ^ 1][i] << std::endl; + + } + + // space for our sums of each level. + std::array, 2> mySums; + + // this will be the value of both children of active an parent + // before the active child is updated. We will need to subtract + // this value as the main loop does not distinguish active parents. + std::array inactiveChildValues; + inactiveChildValues[0] = AES::roundEnc(mAesFixedKey.ecbEncBlock(ZeroBlock), ZeroBlock); + inactiveChildValues[1] = mAesFixedKey.ecbEncBlock(ZeroBlock); + + // For all other levels, expand the GGM tree and add in + // the correction along the active path. + for (u64 d = 1; d < mDepth - 1; ++d) + { + // initialized the sums with inactiveChildValue so that + // it will cancel when we expand the actual inactive child. + std::fill(mySums[0].begin(), mySums[0].end(), inactiveChildValues[0]); + std::fill(mySums[1].begin(), mySums[1].end(), inactiveChildValues[1]); + + // We will iterate over each node on this level and + // expand it into it's two children. Note that the + // active node will also be expanded. Later we will just + // overwrite whatever the value was. This is an optimization. + auto width = divCeil(mDomain, 1ull << (mDepth - d)); + + // The already constructed level. Only missing the + // GGM tree node value along the active path. + auto level0 = levels[d]; + assert(level0.size() == width || level0.size() == width + 1); + + // The next level that we want to construct. + auto level1 = levels[d + 1]; + assert(level1.size() == width * 2); + + for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx, childIdx += 2) + { + // The value of the parent. + auto parent = level0[parentIdx]; + + auto& child0 = level1.data()[childIdx]; + auto& child1 = level1.data()[childIdx + 1]; + mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); + + // inspired by the Expand Accumualte idea to + // use + // + // child0 = AES(parent) ^ parent + // child1 = AES(parent) + parent + // + // but instead we are a bit more conservative and + // compute + // + // child0 = AES:Round(AES(parent), parent) + // = AES:Round(AES(parent), 0) ^ parent + // child1 = AES(parent) + parent + // + // That is, we applies an additional AES round function + // to the first child before XORing it with parent. + child0[0] = AES::roundEnc(child1[0], parent[0]); + child0[1] = AES::roundEnc(child1[1], parent[1]); + child0[2] = AES::roundEnc(child1[2], parent[2]); + child0[3] = AES::roundEnc(child1[3], parent[3]); + child0[4] = AES::roundEnc(child1[4], parent[4]); + child0[5] = AES::roundEnc(child1[5], parent[5]); + child0[6] = AES::roundEnc(child1[6], parent[6]); + child0[7] = AES::roundEnc(child1[7], parent[7]); + + // Update the running sums for this level. We keep + // a left and right totals for each level. Note that + // we are actually XOR in the incorrect value of the + // children of the active parent but this will cancel + // with inactiveChildValue thats already there. + mySums[0][0] = mySums[0][0] ^ child0[0]; + mySums[0][1] = mySums[0][1] ^ child0[1]; + mySums[0][2] = mySums[0][2] ^ child0[2]; + mySums[0][3] = mySums[0][3] ^ child0[3]; + mySums[0][4] = mySums[0][4] ^ child0[4]; + mySums[0][5] = mySums[0][5] ^ child0[5]; + mySums[0][6] = mySums[0][6] ^ child0[6]; + mySums[0][7] = mySums[0][7] ^ child0[7]; + + // child1 = AES(parent) + parent + child1[0] = child1[0] + parent[0]; + child1[1] = child1[1] + parent[1]; + child1[2] = child1[2] + parent[2]; + child1[3] = child1[3] + parent[3]; + child1[4] = child1[4] + parent[4]; + child1[5] = child1[5] + parent[5]; + child1[6] = child1[6] + parent[6]; + child1[7] = child1[7] + parent[7]; + + mySums[1][0] = mySums[1][0] ^ child1[0]; + mySums[1][1] = mySums[1][1] ^ child1[1]; + mySums[1][2] = mySums[1][2] ^ child1[2]; + mySums[1][3] = mySums[1][3] ^ child1[3]; + mySums[1][4] = mySums[1][4] ^ child1[4]; + mySums[1][5] = mySums[1][5] ^ child1[5]; + mySums[1][6] = mySums[1][6] ^ child1[6]; + mySums[1][7] = mySums[1][7] ^ child1[7]; + + } + + + // we have to update the non-active child of the active parent. + for (u64 i = 0; i < remTrees; ++i) + { + // the index of the leaf node that is active. + auto leafIdx = points[i + treeIdx]; + + // The index of the active (missing) child node. + auto missingChildIdx = leafIdx >> (mDepth - 1 - d); + + // The index of the active child node sibling. + auto siblingIdx = missingChildIdx ^ 1; + + // The indicator as to the left or right child is inactive + auto notAi = siblingIdx & 1; + + // our sums & OTs cancel and we are leaf with the + // correct value for the inactive child. + level1[siblingIdx][i] = + (*theirSumsIter)[notAi] ^ + mySums[notAi][i] ^ + mBaseOTs(i + treeIdx, mDepth - 1 - d); + + ++theirSumsIter; + + // we have to set the active child to zero so + // the next children are predictable. + level1[missingChildIdx][i] = ZeroBlock; + } + } + + auto d = mDepth - 1; + // The already constructed level. Only missing the + // GGM tree node value along the active path. + auto level0 = levels[d]; + + // The next level of theGGM tree that we are populating. + std::array child; + + // We will iterate over each node on this level and + // expand it into it's two children. Note that the + // active node will also be expanded. Later we will just + // overwrite whatever the value was. This is an optimization. + auto width = divCeil(mDomain, 1ull << (mDepth - d)); + + VecF temp; + ctx.resize(temp, 2); + for (u64 k = 0; k < 2; ++k) + { + ctx.resize(leafSums[k], 8); + ctx.zero(leafSums[k].begin(), leafSums[k].end()); + ctx.fromBlock(temp[k], gGgmAes[k].hashBlock(ZeroBlock)); + ctx.minus(leafSums[k][0], leafSums[k][0], temp[k]); + for (u64 i = 1; i < 8; ++i) + ctx.copy(leafSums[k][i], leafSums[k][0]); + } + + auto outIter = leafLevel.data() + outputOffset; + // for leaf nodes both children should be hashed. + for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) + { + // The value of the parent. + auto parent = level0.data()[parentIdx]; + + for (u64 keep = 0; keep < 2; ++keep, ++childIdx) + { + if constexpr (std::is_same_v) + { + gGgmAes.data()[keep].hashBlocks<8>(parent.data(), outIter); + } + else + { + // Each parent is expanded into the left and right children + // using a different AES fixed-key. Therefore our OWF is: + // + // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); + // + // where each half defines one of the children. + gGgmAes.data()[keep].hashBlocks<8>(parent.data(), child.data()); + + ctx.fromBlock(*(outIter + 0), child.data()[0]); + ctx.fromBlock(*(outIter + 1), child.data()[1]); + ctx.fromBlock(*(outIter + 2), child.data()[2]); + ctx.fromBlock(*(outIter + 3), child.data()[3]); + ctx.fromBlock(*(outIter + 4), child.data()[4]); + ctx.fromBlock(*(outIter + 5), child.data()[5]); + ctx.fromBlock(*(outIter + 6), child.data()[6]); + ctx.fromBlock(*(outIter + 7), child.data()[7]); + } + auto& leafSum = leafSums[keep]; + ctx.plus(leafSum.data()[0], leafSum.data()[0], *(outIter + 0)); + ctx.plus(leafSum.data()[1], leafSum.data()[1], *(outIter + 1)); + ctx.plus(leafSum.data()[2], leafSum.data()[2], *(outIter + 2)); + ctx.plus(leafSum.data()[3], leafSum.data()[3], *(outIter + 3)); + ctx.plus(leafSum.data()[4], leafSum.data()[4], *(outIter + 4)); + ctx.plus(leafSum.data()[5], leafSum.data()[5], *(outIter + 5)); + ctx.plus(leafSum.data()[6], leafSum.data()[6], *(outIter + 6)); + ctx.plus(leafSum.data()[7], leafSum.data()[7], *(outIter + 7)); + + outIter += 8; + assert(outIter <= leafLevel.data() + leafLevel.size()); + } + } + } + else + { + for (u64 k = 0; k < 2; ++k) + { + ctx.resize(leafSums[k], 8); + ctx.zero(leafSums[k].begin(), leafSums[k].end()); + } + } + + // leaf level. + if (programPuncturedPoint) + { + // Now processes the leaf level. This one is special + // because we must XOR in the correction value as + // before but we must also fixed the child value for + // the active child. To do this, we will receive 4 + // values. Two for each case (left active or right active). + //timer.setTimePoint("recv.recvleaf"); + VecF leafOts; + ctx.resize(leafOts, 2); + PRNG otMasker; + + for (u64 j = 0; j < remTrees; ++j) + { + + // The index of the child on the active path. + auto activeChildIdx = points[j + treeIdx]; + + // The index of the other (inactive) child. + auto inactiveChildIdx = activeChildIdx ^ 1; + + // The indicator as to the left or right child is inactive + auto notAi = inactiveChildIdx & 1; + + // offset to the first or second ot message, based on the one we want + auto offset = ctx.template byteSize() * 2 * notAi; + + + // decrypt the ot string + span buff = leafMsg.subspan(offset, ctx.template byteSize() * 2); + leafMsg = leafMsg.subspan(buff.size() * 2); + otMasker.SetSeed(mBaseOTs[j + treeIdx][0], divCeil(buff.size(), sizeof(block))); + for (u64 i = 0; i < buff.size(); ++i) + buff[i] ^= otMasker.get(); + + ctx.deserialize(buff.begin(), buff.end(), leafOts.begin()); + + auto out0 = (activeChildIdx & ~1ull) * 8 + j + outputOffset; + auto out1 = (activeChildIdx | 1ull) * 8 + j + outputOffset; + + ctx.minus(leafLevel[out0], leafOts[0], leafSums[0][j]); + ctx.minus(leafLevel[out1], leafOts[1], leafSums[1][j]); + } + } + else + { + VecF leafOts; + ctx.resize(leafOts, 1); + PRNG otMasker; + + for (u64 j = 0; j < remTrees; ++j) + { + // The index of the child on the active path. + auto activeChildIdx = points[j + treeIdx]; + + // The index of the other (inactive) child. + auto inactiveChildIdx = activeChildIdx ^ 1; + + // The indicator as to the left or right child is inactive + auto notAi = inactiveChildIdx & 1; + + // offset to the first or second ot message, based on the one we want + auto offset = ctx.template byteSize() * notAi; + + // decrypt the ot string + span buff = leafMsg.subspan(offset, ctx.template byteSize()); + leafMsg = leafMsg.subspan(buff.size() * 2); + otMasker.SetSeed(mBaseOTs[j + treeIdx][0], divCeil(buff.size(), sizeof(block))); + for (u64 i = 0; i < buff.size(); ++i) + buff[i] ^= otMasker.get(); + + ctx.deserialize(buff.begin(), buff.end(), leafOts.begin()); + + std::array out{ + (activeChildIdx & ~1ull) * 8 + j + outputOffset, + (activeChildIdx | 1ull) * 8 + j + outputOffset + }; + + auto keep = leafLevel.begin() + out[notAi]; + auto zero = leafLevel.begin() + out[notAi ^ 1]; + + ctx.minus(*keep, leafOts[0], leafSums[notAi][j]); + ctx.zero(zero, zero + 1); + } + } + } + }; } #endif \ No newline at end of file diff --git a/libOTe/TwoChooseOne/Kos/KosOtExtReceiver.cpp b/libOTe/TwoChooseOne/Kos/KosOtExtReceiver.cpp index 8f0552d..fef9024 100644 --- a/libOTe/TwoChooseOne/Kos/KosOtExtReceiver.cpp +++ b/libOTe/TwoChooseOne/Kos/KosOtExtReceiver.cpp @@ -77,7 +77,7 @@ namespace osuCrypto span messages, PRNG& prng, Socket& chl) - { + try { // we are going to process OTs in blocks of 128 * superBlkSize messages. if (hasBaseOts() == false) co_await genBaseOts(prng, chl); @@ -241,6 +241,11 @@ namespace osuCrypto if(mIsMalicious) co_await(chl.send(std::move(uBuff))); } + catch (...) + { + chl.close(); + throw; + } AlignedUnVector KosOtExtReceiver::hash( span message, diff --git a/libOTe/TwoChooseOne/Kos/KosOtExtSender.cpp b/libOTe/TwoChooseOne/Kos/KosOtExtSender.cpp index 910a81f..c50236b 100644 --- a/libOTe/TwoChooseOne/Kos/KosOtExtSender.cpp +++ b/libOTe/TwoChooseOne/Kos/KosOtExtSender.cpp @@ -56,7 +56,7 @@ namespace osuCrypto span> messages, PRNG& prng, Socket& chl) - { + try { if (hasBaseOts() == false) co_await genBaseOts(prng, chl); @@ -269,7 +269,11 @@ namespace osuCrypto } setTimePoint("Kos.send.done"); - + } + catch (...) + { + chl.close(); + throw; } diff --git a/libOTe/TwoChooseOne/KosDot/KosDotExtReceiver.cpp b/libOTe/TwoChooseOne/KosDot/KosDotExtReceiver.cpp index 0b449ad..ec7528d 100644 --- a/libOTe/TwoChooseOne/KosDot/KosDotExtReceiver.cpp +++ b/libOTe/TwoChooseOne/KosDot/KosDotExtReceiver.cpp @@ -61,7 +61,7 @@ namespace osuCrypto span messages, PRNG& prng, Socket& chl) - { + try { auto numOtExt = u64{}; auto numSuperBlocks = u64{}; auto numBlocks = u64{}; @@ -308,5 +308,10 @@ namespace osuCrypto setTimePoint("KosDot.recv.done"); static_assert(gOtExtBaseOtCount == 128, "expecting 128"); } + catch (...) + { + chl.close(); + throw; + } } #endif \ No newline at end of file diff --git a/libOTe/TwoChooseOne/KosDot/KosDotExtSender.cpp b/libOTe/TwoChooseOne/KosDot/KosDotExtSender.cpp index e3c89c6..b9601a4 100644 --- a/libOTe/TwoChooseOne/KosDot/KosDotExtSender.cpp +++ b/libOTe/TwoChooseOne/KosDot/KosDotExtSender.cpp @@ -57,7 +57,7 @@ namespace osuCrypto span> messages, PRNG& prng, Socket& chl) - { + try { auto numOtExt = u64{}; auto numSuperBlocks = u64{}; @@ -319,6 +319,11 @@ namespace osuCrypto static_assert(gOtExtBaseOtCount == 128, "expecting 128"); } } + catch (...) + { + chl.close(); + throw; + } } #endif \ No newline at end of file diff --git a/libOTe/TwoChooseOne/OTExtInterface.cpp b/libOTe/TwoChooseOne/OTExtInterface.cpp index e11bc3d..efd9a37 100644 --- a/libOTe/TwoChooseOne/OTExtInterface.cpp +++ b/libOTe/TwoChooseOne/OTExtInterface.cpp @@ -20,16 +20,21 @@ namespace osuCrypto } task<> OtExtReceiver::genBaseOts(OtSender& base, PRNG& prng, Socket& chl) - { + try { auto count = baseOtCount(); auto msgs = std::vector>{}; msgs.resize(count); co_await base.send(msgs, prng, chl); setBaseOts(msgs); } + catch (...) + { + chl.close(); + throw; + } task<> OtExtSender::genBaseOts(PRNG& prng, Socket& chl) - { + try { #ifdef LIBOTE_HAS_BASE_OT auto base = DefaultBaseOT{}; co_await genBaseOts(base, prng, chl); @@ -37,11 +42,15 @@ namespace osuCrypto throw std::runtime_error("The libOTe library does not have base OTs. Enable them to call this. " LOCATION); co_return; #endif - + } + catch (...) + { + chl.close(); + throw; } task<> OtExtSender::genBaseOts(OtReceiver& base, PRNG& prng, Socket& chl) - { + try { auto count = baseOtCount(); auto msgs = std::vector{}; auto bv = BitVector{}; @@ -52,6 +61,11 @@ namespace osuCrypto setBaseOts(msgs, bv); } + catch (...) + { + chl.close(); + throw; + } task<> OtReceiver::receiveChosen( @@ -59,7 +73,7 @@ namespace osuCrypto span recvMessages, PRNG& prng, Socket& chl) - { + try { auto temp = std::vector>(recvMessages.size()); co_await(receive(choices, recvMessages, prng, chl)); @@ -72,9 +86,14 @@ namespace osuCrypto ++iter; } } + catch (...) + { + chl.close(); + throw; + } task<> OtReceiver::receiveCorrelated(const BitVector& choices, span recvMessages, PRNG& prng, Socket& chl) - { + try { auto temp = std::vector(recvMessages.size()); co_await(receive(choices, recvMessages, prng, chl)); @@ -87,12 +106,17 @@ namespace osuCrypto ++iter; } } + catch (...) + { + chl.close(); + throw; + } task<> OtSender::sendChosen( span> messages, PRNG& prng, Socket& chl) - { + try { auto temp = std::vector>(messages.size()); co_await(send(temp, prng, chl)); @@ -104,4 +128,9 @@ namespace osuCrypto } co_await(chl.send(std::move(temp))); } + catch (...) + { + chl.close(); + throw; + } } diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp index 3f17219..a0c3ff2 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp @@ -152,7 +152,7 @@ namespace osuCrypto PRNG& prng, Socket& chl, bool useOtExtension) - { + try { auto choice = sampleBaseChoiceBits(prng); auto msg = AlignedUnVector{}; @@ -190,7 +190,12 @@ namespace osuCrypto #endif setSilentBaseOts(msg); setTimePoint("recver.gen.done"); - }; + } + catch (...) + { + chl.close(); + throw; + } u64 SilentOtExtReceiver::silentBaseOtCount() const { @@ -296,12 +301,17 @@ namespace osuCrypto span messages, PRNG& prng, Socket& chl) - { + try { auto randChoice = BitVector(messages.size()); co_await silentReceive(randChoice, messages, prng, chl, OTType::Random); randChoice ^= choices; co_await chl.send(std::move(randChoice)); } + catch (...) + { + chl.close(); + throw; + } task<> SilentOtExtReceiver::silentReceive( BitVector& choices, @@ -309,7 +319,7 @@ namespace osuCrypto PRNG& prng, Socket& chl, OTType type) - { + try { auto packing = (type == OTType::Random) ? ChoiceBitPacking::True : ChoiceBitPacking::False; @@ -339,13 +349,18 @@ namespace osuCrypto clear(); } + catch (...) + { + chl.close(); + throw; + } task<> SilentOtExtReceiver::silentReceiveInplace( u64 n, PRNG& prng, Socket& chl, ChoiceBitPacking type) - { + try { auto gapVals = std::vector{}; auto rT = MatrixView{}; @@ -392,7 +407,11 @@ namespace osuCrypto if (mC.size()) mC.resize(mRequestNumOts); } - + catch (...) + { + chl.close(); + throw; + } task<> SilentOtExtReceiver::ferretMalCheck(Socket& chl, PRNG& prng) { diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp index 71b8ab5..5de7bb1 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp @@ -91,7 +91,7 @@ namespace osuCrypto } task<> SilentOtExtSender::genSilentBaseOts(PRNG& prng, Socket& chl, bool useOtExtension) - { + try { auto msg = AlignedUnVector>(silentBaseOtCount()); if (isConfigured() == false) @@ -125,6 +125,11 @@ namespace osuCrypto setTimePoint("sender.gen.done"); } + catch (...) + { + chl.close(); + throw; + } u64 SilentOtExtSender::silentBaseOtCount() const { @@ -195,7 +200,7 @@ namespace osuCrypto span> messages, PRNG& prng, Socket& chl) - { + try { auto correction = BitVector(messages.size()); auto iter = BitIterator{}; auto i = u64{}; @@ -212,16 +217,26 @@ namespace osuCrypto messages[i][1] = temp[bit ^ 1]; } } + catch (...) + { + chl.close(); + throw; + } task<> SilentOtExtSender::silentSend( span> messages, PRNG& prng, Socket& chl) - { + try { co_await(silentSendInplace(prng.get(), messages.size(), prng, chl)); hash(messages, ChoiceBitPacking::True); clear(); } + catch (...) + { + chl.close(); + throw; + } void SilentOtExtSender::hash( span> messages, @@ -313,7 +328,7 @@ namespace osuCrypto u64 n, PRNG& prng, Socket& chl) - { + try { auto delta = AlignedUnVector{}; setTimePoint("sender.expand.enter"); @@ -362,6 +377,11 @@ namespace osuCrypto mB.resize(mRequestNumOts); } + catch (...) + { + chl.close(); + throw RTE_LOC; + } task<> SilentOtExtSender::ferretMalCheck(Socket& chl, PRNG& prng) diff --git a/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalOtExt.cpp b/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalOtExt.cpp index c8967c2..1a4c52a 100644 --- a/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalOtExt.cpp +++ b/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalOtExt.cpp @@ -12,7 +12,7 @@ namespace osuCrypto task<> SoftSpokenMalOtSender::send( span> messages, PRNG& prng, Socket& chl) - { + try { if ((u64)messages.data() % 32) throw std::runtime_error("soft spoken requires the messages to by 32 byte aligned. Consider using AlignedUnVector or AlignedVector." LOCATION); @@ -68,6 +68,11 @@ namespace osuCrypto co_await(mBase.mSubVole.checkResponse(chl)); } + catch (...) + { + chl.close(); + throw; + } task<> SoftSpokenMalOtSender::runBatch(Socket& chl, span messages) { @@ -284,7 +289,7 @@ namespace osuCrypto task<> SoftSpokenMalOtReceiver::receive( const BitVector& choices, span messages, PRNG& prng, Socket& chl) - { + try { if ((u64)messages.data() % 32) throw std::runtime_error("soft spoken requires the messages to by 32 byte aligned. Consider using AlignedUnVector or AlignedVector." LOCATION); @@ -374,6 +379,11 @@ namespace osuCrypto co_await(mBase.mSubVole.sendResponse(chl)); } + catch (...) + { + chl.close(); + throw; + } task<> SoftSpokenMalOtReceiver::runBatch(Socket& chl, span messages, span choices) { diff --git a/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenShOtExt.cpp b/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenShOtExt.cpp index e018089..b7fbee8 100644 --- a/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenShOtExt.cpp +++ b/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenShOtExt.cpp @@ -33,7 +33,7 @@ namespace osuCrypto template task<> SoftSpokenShOtSender::send( span> messages, PRNG& prng, Socket& chl) - { + try { auto numInstances = u64{}; auto numChunks = u64{}; auto chunkSize_ = u64{}; @@ -88,6 +88,11 @@ namespace osuCrypto temp); } } + catch (...) + { + chl.close(); + throw; + } template void SoftSpokenShOtSender::processChunk( @@ -170,7 +175,7 @@ namespace osuCrypto template task<> SoftSpokenShOtReceiver::receive( const BitVector& choices, span messages, PRNG& prng, Socket& chl) - { + try { auto numInstances = u64{}; auto numChunks = u64{}; auto nChunk = u64{}; @@ -239,7 +244,11 @@ namespace osuCrypto if (hasSendBuffer()) co_await sendBuffer(chl); } - + catch (...) + { + chl.close(); + throw; + } template void SoftSpokenShOtReceiver::processChunk( diff --git a/libOTe/Vole/Noisy/NoisyVoleReceiver.h b/libOTe/Vole/Noisy/NoisyVoleReceiver.h index 44e1128..266fe12 100644 --- a/libOTe/Vole/Noisy/NoisyVoleReceiver.h +++ b/libOTe/Vole/Noisy/NoisyVoleReceiver.h @@ -48,7 +48,7 @@ namespace osuCrypto { template task<> receive(VecG& c, VecF& a, PRNG& prng, OtSender& ot, Socket& chl, CoeffCtx ctx) - { + try { auto otMsg = AlignedUnVector>{}; setTimePoint("NoisyVoleReceiver.ot.begin"); @@ -59,6 +59,11 @@ namespace osuCrypto { co_await(receive(c, a, prng, otMsg, chl, ctx)); } + catch (...) + { + chl.close(); + throw; + } // for chosen c, compute a such htat // @@ -68,7 +73,7 @@ namespace osuCrypto { task<> receive(VecG& c, VecF& a, PRNG& _, span> otMsg, Socket& chl, CoeffCtx ctx) - { + try { auto buff = std::vector{}; auto msg = VecF{}; auto temp = VecF{}; @@ -136,6 +141,11 @@ namespace osuCrypto { co_await(chl.send(std::move(buff))); setTimePoint("NoisyVoleReceiver.done"); } + catch (...) + { + chl.close(); + throw; + } }; diff --git a/libOTe/Vole/Noisy/NoisyVoleSender.h b/libOTe/Vole/Noisy/NoisyVoleSender.h index 6d6b136..1337263 100644 --- a/libOTe/Vole/Noisy/NoisyVoleSender.h +++ b/libOTe/Vole/Noisy/NoisyVoleSender.h @@ -53,7 +53,8 @@ namespace osuCrypto { // template task<> send(F delta, FVec& b, PRNG& prng, - OtReceiver& ot, Socket& chl, CoeffCtx ctx) { + OtReceiver& ot, Socket& chl, CoeffCtx ctx) + try { auto bv = ctx.binaryDecomposition(delta); auto otMsg = AlignedUnVector{ }; otMsg.resize(bv.size()); @@ -66,6 +67,11 @@ namespace osuCrypto { co_await(send(delta, b, prng, otMsg, chl, ctx)); } + catch (...) + { + chl.close(); + throw; + } // for chosen delta, compute b such htat // @@ -73,7 +79,8 @@ namespace osuCrypto { // template task<> send(F delta, FVec& b, PRNG& _, - span otMsg, Socket& chl, CoeffCtx ctx) { + span otMsg, Socket& chl, CoeffCtx ctx) + try { auto prng = PRNG{}; auto buffer = std::vector{}; auto msg = VecF{}; @@ -126,6 +133,11 @@ namespace osuCrypto { setTimePoint("NoisyVoleSender.done"); } + catch (...) + { + chl.close(); + throw; + } }; } // namespace osuCrypto diff --git a/libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp b/libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp index 14005ae..8769b6a 100644 --- a/libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp +++ b/libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp @@ -383,7 +383,7 @@ namespace osuCrypto } task<> SmallFieldVoleSender::expand(Socket& chl,PRNG& prng, u64 numThreads) - { + try { auto corrections = std::vector>{}; auto hashes = std::vector>{}; auto seedView = MatrixView{}; @@ -430,10 +430,15 @@ namespace osuCrypto } } + catch (...) + { + chl.close(); + throw; + } task<> SmallFieldVoleReceiver::expand(Socket& chl, PRNG& prng, u64 numThreads) - { + try { auto seeds = AlignedUnVector{}; auto seedsFull = MatrixView{}; auto totals = std::vector>{}; @@ -544,6 +549,11 @@ namespace osuCrypto } } + catch (...) + { + chl.close(); + throw; + } // TODO: Malicious version. Should use an actual hash function for bottom layer of tree.