diff --git a/CMakeLists.txt b/CMakeLists.txt index 2159d148b0..c7a5320e2c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required (VERSION 3.6) +cmake_minimum_required (VERSION 2.8) project("libPSI") diff --git a/cryptoTools/CMakeLists.txt b/cryptoTools/CMakeLists.txt index c22f6deae4..02ff780a2e 100644 --- a/cryptoTools/CMakeLists.txt +++ b/cryptoTools/CMakeLists.txt @@ -81,4 +81,4 @@ target_link_libraries(cryptoTools ${MIRACL_LIB}) target_link_libraries(cryptoTools ${Boost_LIBRARIES}) -#set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +#set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) \ No newline at end of file diff --git a/cryptoTools/Common/ArrayView.h b/cryptoTools/Common/ArrayView.h index a0774e74f4..1a246327d7 100644 --- a/cryptoTools/Common/ArrayView.h +++ b/cryptoTools/Common/ArrayView.h @@ -14,7 +14,7 @@ namespace osuCrypto { :mBegin(begin), mCur(cur), mEnd(end) { if (mCur > mEnd) throw std::runtime_error("iter went past end. " LOCATION); - if (mCur < mBegin - 1) throw std::runtime_error("iter went past begin. " LOCATION); + if (mCur && mCur < mBegin - 1) throw std::runtime_error("iter went past begin. " LOCATION); } T* mBegin, *mCur, *mEnd; @@ -28,17 +28,6 @@ namespace osuCrypto { ++mCur; return ArrayIterator(mBegin, mCur - 1, mEnd); } - - ArrayIterator operator+(int i) { - return ArrayIterator(mBegin, mCur + i, mEnd); - } - - ArrayIterator& operator+=(int i) { - mCur += i; - if (mCur > mEnd) throw std::runtime_error("iter went past end. " LOCATION); - return *this; - } - ArrayIterator& operator--() { --mCur; if (mCur < mBegin - 1) throw std::runtime_error("iter went past end. " LOCATION); @@ -50,20 +39,33 @@ namespace osuCrypto { return ArrayIterator(mBegin, mCur + 1, mEnd); } - ArrayIterator operator-(int i) { + ArrayIterator operator+(i64 i) { + return ArrayIterator(mBegin, mCur + i, mEnd); + } + + ArrayIterator& operator+=(i64 i) { + mCur += i; + if (mCur > mEnd) throw std::runtime_error("iter went past end. " LOCATION); + return *this; + } + + ArrayIterator operator-(i64 i) { return ArrayIterator(mBegin, mCur - i, mEnd); } - ArrayIterator& operator-=(int i) { + ArrayIterator& operator-=(i64 i) { mCur -= i; if (mCur < mBegin - 1) throw std::runtime_error("iter went past end. " LOCATION); return *this; } + i64 operator-(T* i) { + return mCur - i; + } T& operator*() { if (mCur >= mEnd || mCur < mBegin)throw std::runtime_error("deref past begin or end. " LOCATION); - return *mCur; + return *mCur; } T* operator->() @@ -83,7 +85,7 @@ namespace osuCrypto { bool operator==(const ArrayIterator& cmp) { return mCur == cmp.mCur; } bool operator!=(const ArrayIterator& cmp) { return mCur != cmp.mCur; } - ArrayIterator* operator=(const ArrayIterator& cmp) + ArrayIterator& operator=(const ArrayIterator& cmp) { mBegin = cmp.mBegin; mCur = cmp.mCur; mEnd = cmp.mEnd; return *this; } @@ -95,12 +97,13 @@ namespace osuCrypto { template class ArrayView { - + T* mData; u64 mSize; bool mOwner; - public: + public: + typedef T value_type; ArrayView() :mData(nullptr), @@ -138,13 +141,17 @@ namespace osuCrypto { {} //template - + template ArrayView(Iter start, Iter end, typename Iter::iterator_category *p = 0) : mData(&*start), mSize(end - start), mOwner(false) { + //static_assert(std::is_same::value, "Iter iter must have the same value_type as ArrayView"); + //(void*)p; + std::ignore = p; + } ArrayView(T* begin, T* end, bool owner) : @@ -153,14 +160,26 @@ namespace osuCrypto { mOwner(owner) {} - ArrayView(std::vector& container) - : mData(container.data()), - mSize(container.size()), + + template class C, typename... Args> + ArrayView(const C& cont, typename C::value_type* p = 0) : + mData(((C&)cont).data()), + mSize((((C&)cont).end() - ((C&)cont).begin())), mOwner(false) { + std::ignore = p; + //static_assert(std::is_same::value_type, T>::value, "Container cont must have the same value_type as ArrayView"); + //(void*)p; } - template + //ArrayView(std::vector& container) + // : mData(container.data()), + // mSize(container.size()), + // mOwner(false) + //{ + //} + + template ArrayView(std::array& container) : mData(container.data()), mSize(container.size()), @@ -210,7 +229,7 @@ namespace osuCrypto { inline T& operator[](u64 idx) const { #ifndef NDEBUG - if (idx >= mSize) throw std::runtime_error(LOCATION); + if (idx >= mSize) throw std::runtime_error(LOCATION); #endif return mData[idx]; diff --git a/cryptoTools/Common/ByteStream.cpp b/cryptoTools/Common/ByteStream.cpp index b1de392e52..6611514419 100644 --- a/cryptoTools/Common/ByteStream.cpp +++ b/cryptoTools/Common/ByteStream.cpp @@ -60,7 +60,7 @@ namespace osuCrypto { { if (loc > mCapacity) throw std::runtime_error("rt error at " LOCATION); mPutHead = loc; - mGetHead = std::min(mGetHead, mPutHead); + mGetHead = std::min(mGetHead, mPutHead); } u64 ByteStream::tellg()const diff --git a/cryptoTools/Common/Defines.cpp b/cryptoTools/Common/Defines.cpp index 48a9fecec8..bb60885f42 100644 --- a/cryptoTools/Common/Defines.cpp +++ b/cryptoTools/Common/Defines.cpp @@ -5,7 +5,7 @@ namespace osuCrypto { - Timer gTimer; + Timer gTimer(true); const block ZeroBlock = _mm_set_epi64x(0, 0); const block OneBlock = _mm_set_epi64x(0, 1); const block AllOneBlock = _mm_set_epi64x(u64(-1), u64(-1)); diff --git a/cryptoTools/Common/Defines.h b/cryptoTools/Common/Defines.h index 499f59928e..f46c0aa058 100644 --- a/cryptoTools/Common/Defines.h +++ b/cryptoTools/Common/Defines.h @@ -1,5 +1,5 @@ #pragma once -// This file and the associated implementation has been placed in the public domain, waiving all copyright. No restrictions are placed on its use. +// This file and the associated implementation has been placed in the public domain, waiving all copyright. No restrictions are placed on its use. #include #include @@ -20,13 +20,13 @@ #pragma GCC diagnostic ignored "-Wignored-attributes" #endif -#ifdef _MSC_VER +#ifdef _MSC_VER #define __STR2__(x) #x #define __STR1__(x) __STR2__(x) #define TODO(x) __pragma(message (__FILE__ ":"__STR1__(__LINE__) " Warning:TODO - " #x)) -#define ALIGNED(__Declaration, __alignment) __declspec(align(__alignment)) __Declaration +#define ALIGNED(__Declaration, __alignment) __declspec(align(__alignment)) __Declaration #else -#define TODO(x) +#define TODO(x) #define ALIGNED(__Declaration, __alignment) __Declaration __attribute__((aligned (16))) #endif @@ -64,8 +64,11 @@ namespace osuCrypto { } typedef __m128i block; - inline block toBlock(u8*data) - { return _mm_set_epi64x(((u64*)data)[1], ((u64*)data)[0]);} + + inline block toBlock(u8*data) { return _mm_set_epi64x(((u64*)data)[1], ((u64*)data)[0]);} + + inline block toBlock(u64 x) { return _mm_set_epi64x(0,x); } + inline block toBlock(u64 x, u64 y) { return _mm_set_epi64x(x,y); } template using MultiBlock = std::array; @@ -93,7 +96,7 @@ namespace osuCrypto { return _mm_add_epi64(lhs, rhs); } - + #endif template @@ -141,7 +144,7 @@ namespace osuCrypto { } std::ostream& operator<<(std::ostream& out, const block& block); - + template std::ostream& operator<<(std::ostream& out, const MultiBlock& block); diff --git a/cryptoTools/Common/Log.cpp b/cryptoTools/Common/Log.cpp index b59d080577..a6e3ec8cd4 100644 --- a/cryptoTools/Common/Log.cpp +++ b/cryptoTools/Common/Log.cpp @@ -10,6 +10,7 @@ namespace osuCrypto { + std::mutex gIoStreamMtx; void setThreadName(const std::string name) diff --git a/cryptoTools/Common/Log.h b/cryptoTools/Common/Log.h index 9aecc434fa..3f947f7a27 100644 --- a/cryptoTools/Common/Log.h +++ b/cryptoTools/Common/Log.h @@ -8,6 +8,7 @@ namespace osuCrypto { + enum class Color { LightGreen = 2, LightGrey = 3, @@ -41,8 +42,5 @@ namespace osuCrypto void setThreadName(const std::string name); void setThreadName(const char* name); - -} - - \ No newline at end of file +} diff --git a/cryptoTools/Common/MatrixView.h b/cryptoTools/Common/MatrixView.h index 555a8145ea..8964d9aa9f 100644 --- a/cryptoTools/Common/MatrixView.h +++ b/cryptoTools/Common/MatrixView.h @@ -21,7 +21,7 @@ namespace osuCrypto public: - + typedef T value_type; MatrixView() :mData(nullptr), @@ -48,11 +48,7 @@ namespace osuCrypto MatrixView(u64 rowSize, u64 columnSize) : -#ifdef NDEBUG - mData(new T[rowSize * columnSize]), -#else mData(new T[rowSize * columnSize]()), -#endif mSize({ rowSize, columnSize }), mOwner(true) { } @@ -77,13 +73,32 @@ namespace osuCrypto mSize({ (end - start) / numColumns, numColumns }), mOwner(false) { + //static_assert(std::is_same::value, "Iter iter must have the same value_type as ArrayView"); + std::ignore = p; + } - //MatrixView(T* data, u64 rowSize, u64 columnSize) : - // mData(data), - // mSize({ rowSize, columnSize }), + //template + //MatrixView(const C& cont, u64 numColumns, typename C::value_type* p = 0) : + // mData(&*((C&)cont).begin()), + // mSize({ (((C&)cont).end() - ((C&)cont).begin()) / numColumns, numColumns }), // mOwner(false) - //{} + //{ + // static_assert(std::is_same::value, "Container cont must have the same value_type as ArrayView"); + + // (void*)p; + //} + + template class C, typename... Args> + MatrixView(const C& cont, u64 numColumns, typename C::value_type* p = 0) : + mData(&*((C&)cont).begin()), + mSize({ (((C&)cont).end() - ((C&)cont).begin()) / numColumns, numColumns }), + mOwner(false) + { + //static_assert(std::is_same::value, "Container cont must have the same value_type as ArrayView"); + std::ignore = p; + + } ~MatrixView() @@ -129,11 +144,11 @@ namespace osuCrypto }; ArrayIterator end() const { T* e = (T*)mData + (mSize[0] * mSize[1]); - return ArrayIterator(mData, e, e); + return ArrayIterator(mData, e, e); } #else T* begin() const { return mData; }; - T* end() const { return mData + mSize; } + T* end() const { return mData + mSize[0] * mSize[1]; } #endif ArrayView operator[](u64 rowIdx) const diff --git a/cryptoTools/Common/Timer.cpp b/cryptoTools/Common/Timer.cpp index 2e0b9cf68c..4885945f1c 100644 --- a/cryptoTools/Common/Timer.cpp +++ b/cryptoTools/Common/Timer.cpp @@ -9,11 +9,13 @@ namespace osuCrypto const Timer::timeUnit& Timer::setTimePoint(const std::string& msg) { + //if (mLocking) mMtx.lock(); mTimes.push_back(std::make_pair(timeUnit::clock::now(), msg)); - + auto& ret = mTimes.back().first; + //if (mLocking) mMtx.unlock(); //std::cout << msg << " " << std::chrono::duration_cast(mTimes.back().first - mStart).count() << std::endl; - return mTimes.back().first; + return ret; //return mStart; } @@ -30,7 +32,7 @@ namespace osuCrypto auto iter = timer.mTimes.begin(); out << iter->second; - u64 tabs = std::min((u64)4, (u64)4 - (iter->second.size() / 8)); + u64 tabs = std::min((u64)4, (u64)4 - (iter->second.size() / 8)); for (u64 i = 0; i < tabs; ++i) out << "\t"; @@ -42,7 +44,7 @@ namespace osuCrypto { out << iter->second; - tabs = std::min((u64)4, (u64)4 - (iter->second.size() / 8)); + tabs = std::min((u64)4, (u64)4 - (iter->second.size() / 8)); for (u64 i = 0; i < tabs ; ++i) out << "\t"; diff --git a/cryptoTools/Common/Timer.h b/cryptoTools/Common/Timer.h index 9b3085f05d..7eca57834e 100644 --- a/cryptoTools/Common/Timer.h +++ b/cryptoTools/Common/Timer.h @@ -3,7 +3,7 @@ #include #include #include - +#include namespace osuCrypto { @@ -14,10 +14,12 @@ namespace osuCrypto timeUnit mStart; std::list< std::pair> mTimes; - + bool mLocking; + //std::mutex mMtx; public: - Timer() + Timer(bool locking = false) :mStart(Timer::timeUnit::clock::now()) + , mLocking(locking) {} //Timer(const Timer&); diff --git a/cryptoTools/Crypto/Curve.cpp b/cryptoTools/Crypto/Curve.cpp index 0f1bf7f615..2d7ce601cf 100644 --- a/cryptoTools/Crypto/Curve.cpp +++ b/cryptoTools/Crypto/Curve.cpp @@ -374,11 +374,11 @@ namespace osuCrypto #endif if (mCurve->mIsPrimeField) { - return static_cast(epoint_comp(mCurve->mMiracl, mVal, cmp.mVal)); + return epoint_comp(mCurve->mMiracl, mVal, cmp.mVal) != 0; } else { - return static_cast(epoint2_comp(mCurve->mMiracl, mVal, cmp.mVal)); + return epoint2_comp(mCurve->mMiracl, mVal, cmp.mVal) != 0; } } bool EccPoint::operator!=( @@ -942,7 +942,7 @@ namespace osuCrypto //} } - void EccNumber::fromBytes(u8 * src) + void EccNumber::fromBytes(const u8 * src) { bytes_to_big(mCurve->mMiracl, (int)sizeBytes(), (char*)src, mVal); //mIsNres = NresState::nonNres; @@ -958,23 +958,23 @@ namespace osuCrypto //} } - void EccNumber::fromHex(char * src) + void EccNumber::fromHex(const char * src) { auto oldBase = mCurve->mMiracl->IOBASE; mCurve->mMiracl->IOBASE = 16; - cinstr(mCurve->mMiracl, mVal, src); + cinstr(mCurve->mMiracl, mVal, (char*)src); //mIsNres = NresState::nonNres; mCurve->mMiracl->IOBASE = oldBase; } - void EccNumber::fromDec(char * src) + void EccNumber::fromDec(const char * src) { auto oldBase = mCurve->mMiracl->IOBASE; mCurve->mMiracl->IOBASE = 10; - cinstr(mCurve->mMiracl, mVal, src); + cinstr(mCurve->mMiracl, mVal,(char*) src); //mIsNres = NresState::nonNres; mCurve->mMiracl->IOBASE = oldBase; @@ -1109,13 +1109,13 @@ namespace osuCrypto - result = static_cast(ebrick_init( + result = 0 < ebrick_init( mCurve->mMiracl, &mBrick, x, y, mCurve->BA, mCurve->BB, mCurve->getFieldPrime().mVal, - 8, mCurve->mEccpParams.bitCount)); + 8, mCurve->mEccpParams.bitCount); mirkill(x); mirkill(y); @@ -1124,7 +1124,7 @@ namespace osuCrypto { //fe2ec2(point)->getxy(x, y); - result = static_cast(ebrick2_init( + result = 0 < ebrick2_init( mCurve->mMiracl, &mBrick2, copy.mVal->X, @@ -1136,7 +1136,7 @@ namespace osuCrypto mCurve->mEcc2mParams.b, mCurve->mEcc2mParams.c, 8, - mCurve->mEcc2mParams.bitCount)); + mCurve->mEcc2mParams.bitCount); } if (result == 0) diff --git a/cryptoTools/Crypto/Curve.h b/cryptoTools/Crypto/Curve.h index ea7d932bd1..aca2a96f47 100644 --- a/cryptoTools/Crypto/Curve.h +++ b/cryptoTools/Crypto/Curve.h @@ -239,9 +239,9 @@ namespace osuCrypto u64 sizeBytes() const; void toBytes(u8* dest) const; - void fromBytes(u8* src); - void fromHex(char* src); - void fromDec(char* src); + void fromBytes(const u8* src); + void fromHex(const char* src); + void fromDec(const char* src); void randomize(PRNG& prng); void randomize(const block& seed); diff --git a/cryptoTools/Crypto/PRNG.cpp b/cryptoTools/Crypto/PRNG.cpp index d296352857..a4d529c7db 100644 --- a/cryptoTools/Crypto/PRNG.cpp +++ b/cryptoTools/Crypto/PRNG.cpp @@ -1,7 +1,7 @@ #include "PRNG.h" #include #include -#include +#include "Common/Log.h" namespace osuCrypto { diff --git a/cryptoTools/Crypto/sha1.cpp b/cryptoTools/Crypto/sha1.cpp index 8048d5249c..5186306011 100644 --- a/cryptoTools/Crypto/sha1.cpp +++ b/cryptoTools/Crypto/sha1.cpp @@ -185,7 +185,7 @@ namespace osuCrypto //mSha.Update(dataIn, length); while (length) { - u64 step = std::min(length, u64(64) - idx); + u64 step = std::min(length, u64(64) - idx); memcpy(buffer.data() + idx, dataIn, step); diff --git a/cryptoTools/Network/BtAcceptor.cpp b/cryptoTools/Network/BtAcceptor.cpp index 91bf5a3ca3..c020afb97a 100644 --- a/cryptoTools/Network/BtAcceptor.cpp +++ b/cryptoTools/Network/BtAcceptor.cpp @@ -3,7 +3,7 @@ #include "Network/BtChannel.h" #include "Network/Endpoint.h" #include "Common/Log.h" -#include +#include "Common/ByteStream.h" #include "boost/lexical_cast.hpp" diff --git a/cryptoTools/Network/BtChannel.cpp b/cryptoTools/Network/BtChannel.cpp index 1bd1b63470..1c3740dc06 100644 --- a/cryptoTools/Network/BtChannel.cpp +++ b/cryptoTools/Network/BtChannel.cpp @@ -1,7 +1,7 @@ #include "BtChannel.h" #include "Network/BtSocket.h" #include "Network/BtEndpoint.h" -#include "Common/Defines.h" +#include "Common/Defines.h" #include "Common/Log.h" namespace osuCrypto { @@ -204,6 +204,7 @@ namespace osuCrypto { if (mSocket) { mSocket->mTotalSentData = 0; + mSocket->mTotalRecvData = 0; mSocket->mMaxOutstandingSendData = 0; mSocket->mOutstandingSendData = 0; } @@ -214,6 +215,11 @@ namespace osuCrypto { return (mSocket) ? (u64)mSocket->mTotalSentData : 0; } + u64 BtChannel::getTotalDataRecv() const + { + return (mSocket) ? (u64)mSocket->mTotalRecvData : 0; + } + u64 BtChannel::getMaxOutstandingSendData() const { return (mSocket) ? (u64)mSocket->mMaxOutstandingSendData : 0; diff --git a/cryptoTools/Network/BtChannel.h b/cryptoTools/Network/BtChannel.h index 0f89d91939..c75d92a2ce 100644 --- a/cryptoTools/Network/BtChannel.h +++ b/cryptoTools/Network/BtChannel.h @@ -34,6 +34,7 @@ namespace osuCrypto { void resetStats() override; u64 getTotalDataSent() const override; + u64 getTotalDataRecv() const override; u64 getMaxOutstandingSendData() const override; diff --git a/cryptoTools/Network/BtIOService.cpp b/cryptoTools/Network/BtIOService.cpp index 122fd51584..a5236b01fe 100644 --- a/cryptoTools/Network/BtIOService.cpp +++ b/cryptoTools/Network/BtIOService.cpp @@ -115,55 +115,48 @@ namespace osuCrypto if (bytesTransfered != boost::asio::buffer_size(op.mBuffs[0]) || ec) + { + std::cout << ("rt error at " LOCATION " ec=" + ec.message() + ". else bytesTransfered != " + std::to_string(boost::asio::buffer_size(op.mBuffs[0]))) << std::endl; + std::cout << "This could be from the other end closing too early or the connection beign dropped." << std::endl; throw std::runtime_error("rt error at " LOCATION " ec=" + ec.message() + ". else bytesTransfered != " + std::to_string(boost::asio::buffer_size(op.mBuffs[0]))); + } - // Try to set the recv buffer to be the right size. - try { - // We support two types of receives. One where we provide the expected size of the message and one - // where we allow for variable length messages. op->other will be non null in the resize case and allow - // us to resize the ChannelBuffer which will hold the data. - if (op.mOther != nullptr) - { - // Get the ChannelBuffer from the multi purpose other pointer. - ChannelBuffer* mH = (ChannelBuffer*)op.mOther; - - // resize it. This could throw is the channel buffer chooses to. - mH->ChannelBufferResize(op.mSize); + // We support two types of receives. One where we provide the expected size of the message and one + // where we allow for variable length messages. op->other will be non null in the resize case and allow + // us to resize the ChannelBuffer which will hold the data. + if (op.mOther != nullptr) + { + // Get the ChannelBuffer from the multi purpose other pointer. + ChannelBuffer* mH = (ChannelBuffer*)op.mOther; - // set the WSA buffer to point into the channel buffer storage location. - op.mBuffs[1] = boost::asio::buffer((char*)mH->ChannelBufferData(), op.mSize); - } - else - { - // OK, this is the other type of recv where an expected size was provided. op->mWSABufs[1].len - // will contain the expected size and op->mSize contains the size reported in the header. - if (boost::asio::buffer_size(op.mBuffs[1]) != op.mSize) - throw std::runtime_error("The provided buffer does not fit the received message. Expected: " - + std::to_string(boost::asio::buffer_size(op.mBuffs[1])) + ", actual: " + std::to_string(op.mSize)); + // resize it. This could throw is the channel buffer chooses to. + mH->ChannelBufferResize(op.mSize); - } + // set the WSA buffer to point into the channel buffer storage location. + op.mBuffs[1] = boost::asio::buffer((char*)mH->ChannelBufferData(), op.mSize); } - catch (std::exception& e) + else { - // OK, something went wrong with resizing the recv buffer. Lets make our own buffer - std::unique_ptr newBuff(nullptr); - //std::unique_ptr newBuff(new char[op.mSize]); - //op.mBuffs[1] = boost::asio::buffer(newBuff.get(), op.mSize); - - // store the exception and then throw it once the recv is done. - try + // OK, this is the other type of recv where an expected size was provided. op->mWSABufs[1].len + // will contain the expected size and op->mSize contains the size reported in the header. + if (boost::asio::buffer_size(op.mBuffs[1]) != op.mSize) { - throw BadReceiveBufferSize(e.what(), boost::asio::buffer_size(op.mBuffs[1]), std::move(newBuff)); - } - catch (...) { + auto msg = "The provided buffer does not fit the received message. Expected: " + + std::to_string(boost::asio::buffer_size(op.mBuffs[1])) + ", actual: " + std::to_string(op.mSize); + std::cout << msg << std::endl; + + std::unique_ptr newBuff(nullptr); + auto e_ptr = std::make_exception_ptr(BadReceiveBufferSize(msg, op.mSize)); + op.mException = std::current_exception(); op.mPromise->set_exception(op.mException); delete op.mPromise; - return; + return; } } + boost::asio::async_read(socket->mHandle, std::array{ op.mBuffs[1] }, [&op, socket, this](const boost::system::error_code& ec, u64 bytesTransfered) @@ -308,10 +301,12 @@ namespace osuCrypto switch (op.mType) { case BoostIOOperation::Type::RecvData: + { - if (op.mOther == nullptr && boost::asio::buffer_size(op.mBuffs[1]) == 0) + if (op.mOther == nullptr && op.mSize == 0) throw std::runtime_error("rt error at " LOCATION); + } case BoostIOOperation::Type::CloseRecv: { @@ -319,6 +314,7 @@ namespace osuCrypto socket->mRecvStrand.post([this, socket, op]() { // the queue must be guarded from concurrent access, so add the op within the strand + socket->mTotalRecvData += op.mSize; // queue up the operation. socket->mRecvQueue.push_back(op); diff --git a/cryptoTools/Network/BtIOService.h b/cryptoTools/Network/BtIOService.h index 787e6a5666..042fd4612a 100644 --- a/cryptoTools/Network/BtIOService.h +++ b/cryptoTools/Network/BtIOService.h @@ -28,14 +28,20 @@ namespace osuCrypto public: std::string mWhat; u64 mLength; - std::unique_ptr mData; + //std::shared_ptr mData; - BadReceiveBufferSize(std::string what, u64 length, std::unique_ptr&& data) + BadReceiveBufferSize(std::string what, u64 length) : mWhat(what), - mLength(length), - mData(std::move(data)) + mLength(length) { } + + BadReceiveBufferSize(const BadReceiveBufferSize& src) = default; + BadReceiveBufferSize(BadReceiveBufferSize&& src) = default; + // : mWhat(src.mWhat) + // , mLength(src.mLength) + //{ + //} }; diff --git a/cryptoTools/Network/BtSocket.h b/cryptoTools/Network/BtSocket.h index 2e29cab217..6487d66269 100644 --- a/cryptoTools/Network/BtSocket.h +++ b/cryptoTools/Network/BtSocket.h @@ -76,7 +76,7 @@ namespace osuCrypto { std::deque mSendQueue, mRecvQueue; bool mStopped; - std::atomic mOutstandingSendData, mMaxOutstandingSendData, mTotalSentData; + std::atomic mOutstandingSendData, mMaxOutstandingSendData, mTotalSentData, mTotalRecvData; }; inline BtSocket::BtSocket(BtIOService& ios) : @@ -86,7 +86,8 @@ namespace osuCrypto { mStopped(false), mOutstandingSendData(0), mMaxOutstandingSendData(0), - mTotalSentData(0) + mTotalSentData(0), + mTotalRecvData(0) {} diff --git a/cryptoTools/Network/Channel.h b/cryptoTools/Network/Channel.h index 8dbc66c2d5..3e78b91e0f 100644 --- a/cryptoTools/Network/Channel.h +++ b/cryptoTools/Network/Channel.h @@ -47,6 +47,7 @@ namespace osuCrypto { virtual void resetStats() {}; virtual u64 getTotalDataSent() const = 0; + virtual u64 getTotalDataRecv() const = 0; virtual u64 getMaxOutstandingSendData() const = 0; diff --git a/frontend/OtBinMain.cpp b/frontend/OtBinMain.cpp index a2b858e783..0c58b8dcd9 100644 --- a/frontend/OtBinMain.cpp +++ b/frontend/OtBinMain.cpp @@ -1,4 +1,3 @@ -#include "bloomFilterMain.h" #include "Network/BtEndpoint.h" #include "OPPRF/OPPRFReceiver.h" @@ -31,909 +30,1036 @@ std::vector sendSet; std::vector mSet; u64 nParties(3); -#if 0 -void BarkOPRSend() +void Channel_test() { - Log::out << "dsfds" << Log::endl; + std::string name("psi"); - setThreadName("CP_Test_Thread"); - u64 numThreads(1); + BtIOService ios(0); + BtEndpoint ep0(ios, "localhost", 1212, false, name); + BtEndpoint ep1(ios, "localhost", 1212, true, name); + u8 dummy = 1; + u8 revDummy; + std::vector recvChl{ &ep0.addChannel(name, name) }; + std::vector sendChl{ &ep1.addChannel(name, name) }; - std::fstream online, offline; - online.open("./online.txt", online.trunc | online.out); - offline.open("./offline.txt", offline.trunc | offline.out); + std::thread thrd([&]() { + sendChl[0]->asyncSend(&dummy, 1); + }); - std::cout << "role = sender (" << numThreads << ") otBin" << std::endl; + recvChl[0]->recv(&revDummy, 1); + std::cout << static_cast(revDummy) << std::endl; - std::string name("psi"); + sendChl[0]->close(); + recvChl[0]->close(); - BtIOService ios(0); - BtEndpoint sendEP(ios, "localhost", 1213, true, name); + ep0.stop(); + ep1.stop(); + ios.stop(); +} +void Channel_party_test(u64 myIdx) +{ + u64 setSize = 1 << 5, psiSecParam = 40, bitSize = 128, numThreads = 1; + PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); - std::vector sendChls_(numThreads); - for (u64 i = 0; i < numThreads; ++i) - { - sendChls_[i] = &sendEP.addChannel("chl" + std::to_string(i), "chl" + std::to_string(i)); - } - u8 dummy[1]; + std::vector dummy(nParties); + std::vector revDummy(nParties); - senderGetLatency(*sendChls_[0]); - sendChls_[0]->resetStats(); - LinearCode code; - // code.loadBinFile(SOLUTION_DIR "/../libOTe/libOTe/Tools/bch511.bin"); + std::string name("psi"); + BtIOService ios(0); + + int btCount = nParties; + std::vector ep(nParties); - //for (auto pow : {/* 8,12,*/ 16/*, 20 */ }) - for (auto pow : pows) + for (u64 i = 0; i < nParties; ++i) { - - for (auto cc : threadss) + dummy[i] = myIdx * 10 + i; + if (i < myIdx) { - std::vector sendChls; - - if (pow == 8) - cc = std::min(8, cc); + u32 port = i * 10 + myIdx;//get the same port; i=1 & pIdx=2 =>port=102 + ep[i].start(ios, "localhost", port, false, name); //channel bwt i and pIdx, where i is sender + } + else if (i > myIdx) + { + u32 port = myIdx * 10 + i;//get the same port; i=2 & pIdx=1 =>port=102 + ep[i].start(ios, "localhost", port, true, name); //channel bwt i and pIdx, where i is receiver + } + } - //std::cout << "numTHreads = " << cc << std::endl; - sendChls.insert(sendChls.begin(), sendChls_.begin(), sendChls_.begin() + cc); + std::vector> chls(nParties); - u64 offlineTimeTot(0); - u64 onlineTimeTot(0); - //for (u64 numThreads = 1; numThreads < 129; numThreads *= 2) - for (u64 jj = 0; jj < numTrial; jj++) + for (u64 i = 0; i < nParties; ++i) + { + if (i != myIdx) { + chls[i].resize(numThreads); + for (u64 j = 0; j < numThreads; ++j) { + //chls[i][j] = &ep[i].addChannel("chl" + std::to_string(j), "chl" + std::to_string(j)); + chls[i][j] = &ep[i].addChannel(name, name); + } + } + } - //u64 repeatCount = 1; - u64 setSize = (1 << pow), psiSecParam = 40; - PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); - - - - sendSet.resize(setSize); - for (u64 i = 0; i < setSize; ++i) - { - sendSet[i] = prng.get(); - } - std::cout << "s\n"; - std::cout << sendSet[5] << std::endl; + std::mutex printMtx1, printMtx2; + std::vector pThrds(nParties); + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + { + pThrds[pIdx] = std::thread([&, pIdx]() { + if (pIdx < myIdx) { -#ifdef OOS - OosNcoOtReceiver otRecv(code); - OosNcoOtSender otSend(code); -#else - OPPRFReceiver otRecv; - KkrtNcoOtSender otSend; -#endif - BarkOPRFSender sendPSIs; - //gTimer.reset(); + chls[pIdx][0]->asyncSend(&dummy[pIdx], 1); + std::lock_guard lock(printMtx1); + std::cout << "s: " << myIdx << " -> " << pIdx << " : " << static_cast(dummy[pIdx]) << std::endl; - sendChls[0]->asyncSend(dummy, 1); - sendChls[0]->recv(dummy, 1); - u64 otIdx = 0; - //std::cout << "sender init" << std::endl; - sendPSIs.init(setSize, psiSecParam, 128, sendChls, otSend, prng.get()); - //std::cout << "s\n"; - // std::cout << otSend.mGens[5].mSeed << std::endl; + } + else if (pIdx > myIdx) { + chls[pIdx][0]->recv(&revDummy[pIdx], 1); + std::lock_guard lock(printMtx2); + std::cout << "r: " << myIdx << " <- " << pIdx << " : " << static_cast(revDummy[pIdx]) << std::endl; - //return; - sendChls[0]->asyncSend(dummy, 1); - sendChls[0]->recv(dummy, 1); - //std::cout << "sender init done" << std::endl; + } + }); + } - sendPSIs.sendInput(sendSet, sendChls); - // sendPSIs.mBins.print(); + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + { + // if(pIdx!=myIdx) + pThrds[pIdx].join(); + } - u64 dataSent = 0; - for (u64 g = 0; g < sendChls.size(); ++g) - { - dataSent += sendChls[g]->getTotalDataSent(); - } - //std::accumulate(sendChls[0]->getTotalDataSent()) - //std::cout << setSize << " " << dataSent / std::pow(2, 20) << " byte " << std::endl; - for (u64 g = 0; g < sendChls.size(); ++g) - sendChls[g]->resetStats(); - //std::cout << gTimer << std::endl; + for (u64 i = 0; i < nParties; ++i) + { + if (i != myIdx) + { + for (u64 j = 0; j < numThreads; ++j) + { + chls[i][j]->close(); } - } - - } - for (u64 i = 0; i < numThreads; ++i) + + for (u64 i = 0; i < nParties; ++i) { - sendChls_[i]->close();// = &sendEP.addChannel("chl" + std::to_string(i), "chl" + std::to_string(i)); + if (i != myIdx) + ep[i].stop(); } - //sendChl.close(); - //recvChl.close(); - sendEP.stop(); ios.stop(); } -void BarkOPRFRecv() -{ - setThreadName("CP_Test_Thread"); - u64 numThreads(1); +void party(u64 myIdx, u64 nParties, u64 setSize, std::vector& mSet) +{ + //nParties = 4; + std::fstream runtime; + if (myIdx == 0) + runtime.open("./runtime" + nParties, runtime.trunc | runtime.out); - std::fstream online, offline; - online.open("./online.txt", online.trunc | online.out); - offline.open("./offline.txt", offline.trunc | offline.out); + u64 offlineAvgTime(0), hashingAvgTime(0), getOPRFAvgTime(0), + ss2DirAvgTime(0), ssRoundAvgTime(0), intersectionAvgTime(0), onlineAvgTime(0); + u64 psiSecParam = 40, bitSize = 128, numThreads = 1; + PRNG prng(_mm_set_epi32(4253465, 3434565, myIdx, myIdx)); std::string name("psi"); - BtIOService ios(0); - BtEndpoint recvEP(ios, "localhost", 1213, false, name); - - LinearCode code; - // code.loadBinFile(SOLUTION_DIR "/../libOTe/libOTe/Tools/bch511.bin"); + int btCount = nParties; + std::vector ep(nParties); - std::vector recvChls_(numThreads); - for (u64 i = 0; i < numThreads; ++i) + for (u64 i = 0; i < nParties; ++i) { - recvChls_[i] = &recvEP.addChannel("chl" + std::to_string(i), "chl" + std::to_string(i)); + if (i < myIdx) + { + u32 port = 1120 + i * 100 + myIdx;;//get the same port; i=1 & pIdx=2 =>port=102 + ep[i].start(ios, "localhost", port, false, name); //channel bwt i and pIdx, where i is sender + } + else if (i > myIdx) + { + u32 port = 1120 + myIdx * 100 + i;//get the same port; i=2 & pIdx=1 =>port=102 + ep[i].start(ios, "localhost", port, true, name); //channel bwt i and pIdx, where i is receiver + } } - std::cout << "role = recv(" << numThreads << ") otBin" << std::endl; - u8 dummy[1]; - recverGetLatency(*recvChls_[0]); + std::vector> chls(nParties); - //for (auto pow : {/* 8,12,*/16/*,20*/ }) - for (auto pow : pows) + for (u64 i = 0; i < nParties; ++i) { - for (auto cc : threadss) - { - std::vector recvChls; - - if (pow == 8) - cc = std::min(8, cc); - - u64 setSize = (1 << pow), psiSecParam = 40; - - std::cout << "numTHreads = " << cc << " n=" << setSize << std::endl; - - recvChls.insert(recvChls.begin(), recvChls_.begin(), recvChls_.begin() + cc); - - u64 offlineTimeTot(0); - u64 onlineTimeTot(0); - //for (u64 numThreads = 1; numThreads < 129; numThreads *= 2) - for (u64 jj = 0; jj < numTrial; jj++) + if (i != myIdx) { + chls[i].resize(numThreads); + for (u64 j = 0; j < numThreads; ++j) { + //chls[i][j] = &ep[i].addChannel("chl" + std::to_string(j), "chl" + std::to_string(j)); + chls[i][j] = &ep[i].addChannel(name, name); + } + } + } - //u64 repeatCount = 1; - PRNG prng(_mm_set_epi32(42553465, 343452565, 2364435, 23923587)); - - - std::vector recvSet(setSize); - - + u64 maskSize = roundUpTo(psiSecParam + 2 * std::log(setSize) - 1, 8) / 8; + for (u64 idxTrial = 0; idxTrial < numTrial; idxTrial++) + { + std::vector set(setSize); + std::vector> sendPayLoads(nParties), recvPayLoads(nParties); - for (u64 i = 0; i < setSize; ++i) + for (u64 i = 0; i < setSize; ++i) + { + set[i] = mSet[i]; + } + PRNG prng1(_mm_set_epi32(4253465, 3434565, 234435, myIdx)); + set[0] = prng1.get();; + for (u64 idxP = 0; idxP < nParties; ++idxP) + { + sendPayLoads[idxP].resize(setSize); + recvPayLoads[idxP].resize(setSize); + for (u64 i = 0; i < setSize; ++i) + sendPayLoads[idxP][i] = prng.get(); + } + u64 nextNeighbor = (myIdx + 1) % nParties; + u64 prevNeighbor = (myIdx - 1 + nParties) % nParties; + //sum share of other party =0 => compute the share to his neighbor = sum of other shares + if (myIdx != 0) { + for (u64 i = 0; i < setSize; ++i) + { + block sum = ZeroBlock; + for (u64 idxP = 0; idxP < nParties; ++idxP) { - recvSet[i] = prng.get(); - // sendSet[i];// = prng.get(); + if ((idxP != myIdx && idxP != nextNeighbor)) + sum = sum ^ sendPayLoads[idxP][i]; } - for (u64 i = 1; i < 3; ++i) + sendPayLoads[nextNeighbor][i] = sum; + + } + } + else + for (u64 i = 0; i < setSize; ++i) + { + sendPayLoads[myIdx][i] = ZeroBlock; + for (u64 idxP = 0; idxP < nParties; ++idxP) { - recvSet[i] = sendSet[i]; + if (idxP != myIdx) + sendPayLoads[myIdx][i] = sendPayLoads[myIdx][i] ^ sendPayLoads[idxP][i]; } + } - for (u64 i = setSize - 3; i < setSize; ++i) +#ifdef PRINT + std::cout << IoStream::lock; + if (myIdx != 0) { + for (u64 i = 0; i < setSize; ++i) + { + block check = ZeroBlock; + for (u64 idxP = 0; idxP < nParties; ++idxP) { - recvSet[i] = sendSet[i]; + if (idxP != myIdx) + check = check ^ sendPayLoads[idxP][i]; } - - std::cout << "s\n"; - std::cout << recvSet[5] << std::endl; -#ifdef OOS - OosNcoOtReceiver otRecv(code); - OosNcoOtSender otSend(code); -#else - KkrtNcoOtReceiver otRecv; + if (memcmp((u8*)&check, &ZeroBlock, sizeof(block))) + std::cout << "Error ss values: myIdx: " << myIdx + << " value: " << check << std::endl; + } + } + else + for (u64 i = 0; i < setSize; ++i) + { + block check = ZeroBlock; + for (u64 idxP = 0; idxP < nParties; ++idxP) + { + check = check ^ sendPayLoads[idxP][i]; + } + if (memcmp((u8*)&check, &ZeroBlock, sizeof(block))) + std::cout << "Error ss values: myIdx: " << myIdx + << " value: " << check << std::endl; + } + std::cout << IoStream::unlock; #endif - BarkOPRFReceiver recvPSIs; - - - recvChls[0]->recv(dummy, 1); - gTimer.reset(); - recvChls[0]->asyncSend(dummy, 1); - - u64 otIdx = 0; - - Timer timer; - auto start = timer.setTimePoint("start"); - recvPSIs.init(setSize, psiSecParam, 128, recvChls, otRecv, ZeroBlock); - /*std::cout << "r\n"; - std::cout << otRecv.mGens[5][0].mSeed << std::endl; - std::cout << otRecv.mGens[5][1].mSeed << std::endl;*/ - - //return; - - - //std::vector sss(recvChls.size()); - //for (u64 g = 0; g < recvChls.size(); ++g) - //{ - // sss[g] = recvChls[g]->getTotalDataSent(); - //} - - recvChls[0]->asyncSend(dummy, 1); - recvChls[0]->recv(dummy, 1); - auto mid = timer.setTimePoint("init"); - - - recvPSIs.sendInput(recvSet, recvChls); - //recvPSIs.mBins.print(); - - - auto end = timer.setTimePoint("done"); - - auto offlineTime = std::chrono::duration_cast(mid - start).count(); - auto onlineTime = std::chrono::duration_cast(end - mid).count(); - - - offlineTimeTot += offlineTime; - onlineTimeTot += onlineTime; - //auto byteSent = recvChls[0]->getTotalDataSent() *recvChls.size(); - - u64 dataSent = 0; - for (u64 g = 0; g < recvChls.size(); ++g) - { - dataSent += recvChls[g]->getTotalDataSent(); - //std::cout << "chl[" << g << "] " << recvChls[g]->getTotalDataSent() << " " << sss[g] << std::endl; - } - - double time = offlineTime + onlineTime; - time /= 1000; - auto Mbps = dataSent * 8 / time / (1 << 20); - - std::cout << setSize << " " << offlineTime << " " << onlineTime << " " << Mbps << " Mbps " << (dataSent / std::pow(2.0, 20)) << " MB" << std::endl; - - for (u64 g = 0; g < recvChls.size(); ++g) - recvChls[g]->resetStats(); - - //std::cout << "threads = " << numThreads << std::endl << timer << std::endl << std::endl << std::endl; - - - //std::cout << numThreads << std::endl; - //std::cout << timer << std::endl; - - // std::cout << gTimer << std::endl; - - //if (recv.mIntersection.size() != setSize) - // throw std::runtime_error(""); + std::vector otRecv(nParties); + std::vector otSend(nParties); + std::vector send(nParties - myIdx - 1); + std::vector recv(myIdx); + binSet bins; + std::vector pThrds(nParties); + //########################## + //### Offline Phasing + //########################## + Timer timer; + auto start = timer.setTimePoint("start"); + bins.init(myIdx, nParties, setSize, psiSecParam); + u64 otCountSend = bins.mSimpleBins.mBins.size(); + u64 otCountRecv = bins.mCuckooBins.mBins.size(); + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + { + pThrds[pIdx] = std::thread([&, pIdx]() { + if (pIdx < myIdx) { + //I am a receiver if other party idx < mine + recv[pIdx].init(nParties, setSize, psiSecParam, bitSize, chls[pIdx], otCountRecv, otRecv[pIdx], otSend[pIdx], ZeroBlock, true); + } + else if (pIdx > myIdx) { + send[pIdx - myIdx - 1].init(nParties, setSize, psiSecParam, bitSize, chls[pIdx], otCountSend, otSend[pIdx], otRecv[pIdx], prng.get(), true); + } + }); + } + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + pThrds[pIdx].join(); + auto initDone = timer.setTimePoint("initDone"); +#ifdef PRINT + std::cout << IoStream::lock; + if (myIdx == 0) + { + Log::out << otSend[2].mGens[0].get() << Log::endl; + if (otRecv[2].hasBaseOts()) + { + Log::out << otRecv[2].mGens[0][0].get() << Log::endl; + Log::out << otRecv[2].mGens[0][1].get() << Log::endl; } - - - - online << onlineTimeTot / numTrial << "-"; - offline << offlineTimeTot / numTrial << "-"; - + Log::out << "------------" << Log::endl; } - } + if (myIdx == 2) + { + if (otSend[0].hasBaseOts()) + Log::out << otSend[0].mGens[0].get() << Log::endl; - for (u64 i = 0; i < numThreads; ++i) - { - recvChls_[i]->close();// = &recvEP.addChannel("chl" + std::to_string(i), "chl" + std::to_string(i)); - } - //sendChl.close(); - //recvChl.close(); + Log::out << otRecv[0].mGens[0][0].get() << Log::endl; + Log::out << otRecv[0].mGens[0][1].get() << Log::endl; + } + std::cout << IoStream::unlock; +#endif - recvEP.stop(); + //########################## + //### Hashing + //########################## + bins.hashing2Bins(set, 1); - ios.stop(); -} + //if(myIdx==0) + // bins.mSimpleBins.print(myIdx, true, false, false, false); + //if (myIdx == 2) + // bins.mCuckooBins.print(myIdx, true, false, false); -void OPPRFSend() -{ - Log::out << "dsfds" << Log::endl; + auto hashingDone = timer.setTimePoint("hashingDone"); + //########################## + //### Online Phasing - compute OPRF + //########################## - setThreadName("CP_Test_Thread"); - u64 numThreads(1); + pThrds.clear(); + pThrds.resize(nParties); + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + { + pThrds[pIdx] = std::thread([&, pIdx]() { + if (pIdx < myIdx) { + //I am a receiver if other party idx < mine + recv[pIdx].getOPRFkeys(pIdx, bins, chls[pIdx], true); + } + else if (pIdx > myIdx) { + send[pIdx - myIdx - 1].getOPRFKeys(pIdx, bins, chls[pIdx], true); + } + }); + } - std::fstream online, offline; - online.open("./online.txt", online.trunc | online.out); - offline.open("./offline.txt", offline.trunc | offline.out); + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + pThrds[pIdx].join(); + //if (myIdx == 0) + //{ + // bins.mSimpleBins.print(2, true, true, false, false); + // //bins.mCuckooBins.print(2, true, false, false); + // Log::out << "------------" << Log::endl; + //} + //if (myIdx == 2) + //{ + // //bins.mSimpleBins.print(myIdx, true, false, false, false); + // bins.mCuckooBins.print(0, true, true, false); + //} + auto getOPRFDone = timer.setTimePoint("getOPRFDone"); - std::cout << "role = sender (" << numThreads << ") otBin" << std::endl; + //########################## + //### online phasing - secretsharing + //########################## + pThrds.clear(); + pThrds.resize(nParties); - std::string name("psi"); + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + { + pThrds[pIdx] = std::thread([&, pIdx]() { + if ((pIdx < myIdx && pIdx != prevNeighbor)) { + //I am a receiver if other party idx < mine + recv[pIdx].revSecretSharing(pIdx, bins, recvPayLoads[pIdx], chls[pIdx]); + recv[pIdx].sendSecretSharing(pIdx, bins, sendPayLoads[pIdx], chls[pIdx]); + } + else if (pIdx > myIdx && pIdx != nextNeighbor) { + send[pIdx - myIdx - 1].sendSecretSharing(pIdx, bins, sendPayLoads[pIdx], chls[pIdx]); + send[pIdx - myIdx - 1].revSecretSharing(pIdx, bins, recvPayLoads[pIdx], chls[pIdx]); + } - BtIOService ios(0); - BtEndpoint sendEP(ios, "localhost", 1213, true, name); + else if (pIdx == prevNeighbor && myIdx != 0) { + recv[pIdx].sendSecretSharing(pIdx, bins, sendPayLoads[pIdx], chls[pIdx]); + } + else if (pIdx == nextNeighbor && myIdx != nParties - 1) + { + send[pIdx - myIdx - 1].revSecretSharing(pIdx, bins, recvPayLoads[pIdx], chls[pIdx]); + } - std::vector sendChls_(numThreads); + else if (pIdx == nParties - 1 && myIdx == 0) { + send[pIdx - myIdx - 1].sendSecretSharing(pIdx, bins, sendPayLoads[pIdx], chls[pIdx]); + } - for (u64 i = 0; i < numThreads; ++i) - { - sendChls_[i] = &sendEP.addChannel("chl" + std::to_string(i), "chl" + std::to_string(i)); - } - u8 dummy[1]; + else if (pIdx == 0 && myIdx == nParties - 1) + { + recv[pIdx].revSecretSharing(pIdx, bins, recvPayLoads[pIdx], chls[pIdx]); + } - senderGetLatency(*sendChls_[0]); - sendChls_[0]->resetStats(); + }); + } - LinearCode code; - // code.loadBinFile(SOLUTION_DIR "/../libOTe/libOTe/Tools/bch511.bin"); + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + pThrds[pIdx].join(); - //for (auto pow : {/* 8,12,*/ 16/*, 20 */ }) - for (auto pow : pows) - { + auto getSSDone2Dir = timer.setTimePoint("getSSDone2Dir"); - for (auto cc : threadss) +#ifdef PRINT + std::cout << IoStream::lock; + if (myIdx == 0) { - std::vector sendChls; - - if (pow == 8) - cc = std::min(8, cc); - - //std::cout << "numTHreads = " << cc << std::endl; - - sendChls.insert(sendChls.begin(), sendChls_.begin(), sendChls_.begin() + cc); - - u64 offlineTimeTot(0); - u64 onlineTimeTot(0); - //for (u64 numThreads = 1; numThreads < 129; numThreads *= 2) - for (u64 jj = 0; jj < numTrial; jj++) + for (int i = 0; i < 3; i++) { + block temp = ZeroBlock; + memcpy((u8*)&temp, (u8*)&sendPayLoads[2][i], maskSize); + Log::out << "s " << myIdx << " - 2: Idx" << i << " - " << temp << Log::endl; - //u64 repeatCount = 1; - u64 setSize = (1 << pow), psiSecParam = 40; - PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); - - + block temp1 = ZeroBlock; + memcpy((u8*)&temp1, (u8*)&recvPayLoads[2][i], maskSize); + Log::out << "r " << myIdx << " - 2: Idx" << i << " - " << temp1 << Log::endl; + } + Log::out << "------------" << Log::endl; + } + if (myIdx == 2) + { + for (int i = 0; i < 3; i++) + { + block temp = ZeroBlock; + memcpy((u8*)&temp, (u8*)&recvPayLoads[0][i], maskSize); + Log::out << "r " << myIdx << " - 0: Idx" << i << " - " << temp << Log::endl; - sendSet.resize(setSize); + block temp1 = ZeroBlock; + memcpy((u8*)&temp1, (u8*)&sendPayLoads[0][i], maskSize); + Log::out << "s " << myIdx << " - 0: Idx" << i << " - " << temp1 << Log::endl; + } + Log::out << "------------" << Log::endl; + } + std::cout << IoStream::unlock; +#endif + //########################## + //### online phasing - secretsharing - round + //########################## - for (u64 i = 0; i < setSize; ++i) + if (myIdx == 0) + { + // Xor the received shares + for (u64 i = 0; i < setSize; ++i) + { + for (u64 idxP = 0; idxP < nParties; ++idxP) { - sendSet[i] = prng.get(); + if (idxP != myIdx && idxP != prevNeighbor) + sendPayLoads[nextNeighbor][i] = sendPayLoads[nextNeighbor][i] ^ recvPayLoads[idxP][i]; } + } - std::cout << "s\n"; - std::cout << sendSet[5] << std::endl; - -#ifdef OOS - OosNcoOtReceiver otRecv(code); - OosNcoOtSender otSend(code); -#else - OPPRFReceiver otRecv; - KkrtNcoOtSender otSend; -#endif - OPPRFSender sendPSIs; - - //gTimer.reset(); - - sendChls[0]->asyncSend(dummy, 1); - sendChls[0]->recv(dummy, 1); - u64 otIdx = 0; - //std::cout << "sender init" << std::endl; - // sendPSIs.init(setSize, psiSecParam, 128, sendChls, otSend, prng.get()); - //std::cout << "s\n"; - // std::cout << otSend.mGens[5].mSeed << std::endl; - + send[nextNeighbor].sendSecretSharing(nextNeighbor, bins, sendPayLoads[nextNeighbor], chls[nextNeighbor]); + send[nextNeighbor - myIdx - 1].revSecretSharing(prevNeighbor, bins, recvPayLoads[prevNeighbor], chls[prevNeighbor]); - //return; - sendChls[0]->asyncSend(dummy, 1); - sendChls[0]->recv(dummy, 1); - //std::cout << "sender init done" << std::endl; + } + else if (myIdx == nParties - 1) + { + recv[prevNeighbor].revSecretSharing(prevNeighbor, bins, recvPayLoads[prevNeighbor], chls[prevNeighbor]); - //sendPSIs.sendInput(sendSet.data(), sendSet.size(), sendChls); - // sendPSIs.sendInput(sendSet, sendChls); + //Xor the received shares + for (u64 i = 0; i < setSize; ++i) + { + sendPayLoads[nextNeighbor][i] = sendPayLoads[nextNeighbor][i] ^ recvPayLoads[prevNeighbor][i]; + for (u64 idxP = 0; idxP < nParties; ++idxP) + { + if (idxP != myIdx && idxP != prevNeighbor) + sendPayLoads[nextNeighbor][i] = sendPayLoads[nextNeighbor][i] ^ recvPayLoads[idxP][i]; + } + } - // sendPSIs.mBins.print(); + recv[nextNeighbor].sendSecretSharing(nextNeighbor, bins, sendPayLoads[nextNeighbor], chls[nextNeighbor]); - u64 dataSent = 0; - for (u64 g = 0; g < sendChls.size(); ++g) + } + else + { + recv[prevNeighbor].revSecretSharing(prevNeighbor, bins, recvPayLoads[prevNeighbor], chls[prevNeighbor]); + //Xor the received shares + for (u64 i = 0; i < setSize; ++i) + { + sendPayLoads[nextNeighbor][i] = sendPayLoads[nextNeighbor][i] ^ recvPayLoads[prevNeighbor][i]; + for (u64 idxP = 0; idxP < nParties; ++idxP) { - dataSent += sendChls[g]->getTotalDataSent(); + if (idxP != myIdx && idxP != prevNeighbor) + sendPayLoads[nextNeighbor][i] = sendPayLoads[nextNeighbor][i] ^ recvPayLoads[idxP][i]; } + } + send[nextNeighbor - myIdx - 1].sendSecretSharing(nextNeighbor, bins, sendPayLoads[nextNeighbor], chls[nextNeighbor]); + } - //std::accumulate(sendChls[0]->getTotalDataSent()) + auto getSSDoneRound = timer.setTimePoint("getSSDoneRound"); - //std::cout << setSize << " " << dataSent / std::pow(2, 20) << " byte " << std::endl; - for (u64 g = 0; g < sendChls.size(); ++g) - sendChls[g]->resetStats(); - //std::cout << gTimer << std::endl; +#ifdef PRINT + std::cout << IoStream::lock; + if (myIdx == 0) + { + for (int i = 0; i < 5; i++) + { + block temp = ZeroBlock; + memcpy((u8*)&temp, (u8*)&sendPayLoads[1][i], maskSize); + Log::out << myIdx << " - " << temp << Log::endl; + //Log::out << recvPayLoads[2][i] << Log::endl; + } + Log::out << "------------" << Log::endl; + } + if (myIdx == 1) + { + for (int i = 0; i < 5; i++) + { + block temp = ZeroBlock; + memcpy((u8*)&temp, (u8*)&recvPayLoads[0][i], maskSize); + Log::out << myIdx << " - " << temp << Log::endl; + //Log::out << sendPayLoads[0][i] << Log::endl; } - } + std::cout << IoStream::unlock; +#endif + //########################## + //### online phasing - compute intersection + //########################## - } - for (u64 i = 0; i < numThreads; ++i) - { - sendChls_[i]->close();// = &sendEP.addChannel("chl" + std::to_string(i), "chl" + std::to_string(i)); - } - //sendChl.close(); - //recvChl.close(); + if (myIdx == 0) { + std::vector mIntersection; + u64 maskSize = roundUpTo(psiSecParam + 2 * std::log(setSize) - 1, 8) / 8; + for (u64 i = 0; i < setSize; ++i) + { + if (!memcmp((u8*)&sendPayLoads[myIdx][i], &recvPayLoads[prevNeighbor][i], maskSize)) + { + mIntersection.push_back(i); + } + } + Log::out << "mIntersection.size(): " << mIntersection.size() << Log::endl; + } + auto getIntersection = timer.setTimePoint("getIntersection"); - sendEP.stop(); - ios.stop(); -} - -void OPPRFRecv() -{ - - setThreadName("CP_Test_Thread"); - u64 numThreads(1); - - std::fstream online, offline; - online.open("./online.txt", online.trunc | online.out); - offline.open("./offline.txt", offline.trunc | offline.out); + if (myIdx == 0) { + auto offlineTime = std::chrono::duration_cast(initDone - start).count(); + auto hashingTime = std::chrono::duration_cast(hashingDone - initDone).count(); + auto getOPRFTime = std::chrono::duration_cast(getOPRFDone - hashingDone).count(); + auto ss2DirTime = std::chrono::duration_cast(getSSDone2Dir - getOPRFDone).count(); + auto ssRoundTime = std::chrono::duration_cast(getSSDoneRound - getSSDone2Dir).count(); + auto intersectionTime = std::chrono::duration_cast(getIntersection - getSSDoneRound).count(); + double onlineTime = hashingTime + getOPRFTime + ss2DirTime + ssRoundTime + intersectionTime; - std::string name("psi"); + double time = offlineTime + onlineTime; + time /= 1000; - BtIOService ios(0); - BtEndpoint recvEP(ios, "localhost", 1213, false, name); + std::cout << "setSize: " << setSize << "\n" + << "offlineTime: " << offlineTime << " ms\n" + << "hashingTime: " << hashingTime << " ms\n" + << "getOPRFTime: " << getOPRFTime << " ms\n" + << "ss2DirTime: " << ss2DirTime << " ms\n" + << "ssRoundTime: " << ssRoundTime << " ms\n" + << "intersection: " << intersectionTime << " ms\n" + << "onlineTime: " << onlineTime << " ms\n" + << "Total time: " << time << " s\n" + << "------------------\n"; - LinearCode code; - // code.loadBinFile(SOLUTION_DIR "/../libOTe/libOTe/Tools/bch511.bin"); + offlineAvgTime += offlineTime; + hashingAvgTime += hashingTime; + getOPRFAvgTime += getOPRFTime; + ss2DirAvgTime += ss2DirTime; + ssRoundAvgTime += ssRoundTime; + intersectionAvgTime += intersectionTime; + onlineAvgTime += onlineTime; - std::vector recvChls_(numThreads); - for (u64 i = 0; i < numThreads; ++i) - { - recvChls_[i] = &recvEP.addChannel("chl" + std::to_string(i), "chl" + std::to_string(i)); + } + } - std::cout << "role = recv(" << numThreads << ") otBin" << std::endl; - u8 dummy[1]; - recverGetLatency(*recvChls_[0]); + if (myIdx == 0) { + double avgTime = (offlineAvgTime + onlineAvgTime); + avgTime /= 1000; + std::cout << "=========avg==========\n" + << "setSize: " << setSize << "\n" + << "offlineTime: " << offlineAvgTime / numTrial << " ms\n" + << "hashingTime: " << hashingAvgTime / numTrial << " ms\n" + << "getOPRFTime: " << getOPRFAvgTime / numTrial << " ms\n" + << "ss2DirTime: " << ss2DirAvgTime << " ms\n" + << "ssRoundTime: " << ssRoundAvgTime << " ms\n" + << "intersection: " << intersectionAvgTime / numTrial << " ms\n" + << "onlineTime: " << onlineAvgTime / numTrial << " ms\n" + << "Total time: " << avgTime / numTrial << " s\n"; + runtime << "setSize: " << setSize << "\n" + << "offlineTime: " << offlineAvgTime / numTrial << " ms\n" + << "hashingTime: " << hashingAvgTime / numTrial << " ms\n" + << "getOPRFTime: " << getOPRFAvgTime / numTrial << " ms\n" + << "ss2DirTime: " << ss2DirAvgTime << " ms\n" + << "ssRoundTime: " << ssRoundAvgTime << " ms\n" + << "intersection: " << intersectionAvgTime / numTrial << " ms\n" + << "onlineTime: " << onlineAvgTime / numTrial << " ms\n" + << "Total time: " << avgTime / numTrial << " s\n"; + runtime.close(); + } - //for (auto pow : {/* 8,12,*/16/*,20*/ }) - for (auto pow : pows) + for (u64 i = 0; i < nParties; ++i) { - for (auto cc : threadss) + if (i != myIdx) { - std::vector recvChls; - - if (pow == 8) - cc = std::min(8, cc); - - u64 setSize = (1 << pow), psiSecParam = 40; - - std::cout << "numTHreads = " << cc << " n=" << setSize << std::endl; - - recvChls.insert(recvChls.begin(), recvChls_.begin(), recvChls_.begin() + cc); - - u64 offlineTimeTot(0); - u64 onlineTimeTot(0); - //for (u64 numThreads = 1; numThreads < 129; numThreads *= 2) - for (u64 jj = 0; jj < numTrial; jj++) + for (u64 j = 0; j < numThreads; ++j) { - - //u64 repeatCount = 1; - PRNG prng(_mm_set_epi32(42553465, 343452565, 2364435, 23923587)); - - - std::vector recvSet(setSize); - - - - - for (u64 i = 0; i < setSize; ++i) - { - recvSet[i] = prng.get(); - // sendSet[i];// = prng.get(); - } - for (u64 i = 1; i < 3; ++i) - { - recvSet[i] = sendSet[i]; - } - - for (u64 i = setSize - 3; i < setSize; ++i) - { - recvSet[i] = sendSet[i]; - } - - std::cout << "s\n"; - std::cout << recvSet[5] << std::endl; -#ifdef OOS - OosNcoOtReceiver otRecv(code); - OosNcoOtSender otSend(code); -#else - KkrtNcoOtReceiver otRecv; -#endif - OPPRFReceiver recvPSIs; - - - recvChls[0]->recv(dummy, 1); - gTimer.reset(); - recvChls[0]->asyncSend(dummy, 1); - - u64 otIdx = 0; - - - Timer timer; - auto start = timer.setTimePoint("start"); - // recvPSIs.init(setSize, psiSecParam, 128, recvChls, otRecv, ZeroBlock); - - /*std::cout << "r\n"; - std::cout << otRecv.mGens[5][0].mSeed << std::endl; - std::cout << otRecv.mGens[5][1].mSeed << std::endl;*/ - - //return; - - - //std::vector sss(recvChls.size()); - //for (u64 g = 0; g < recvChls.size(); ++g) - //{ - // sss[g] = recvChls[g]->getTotalDataSent(); - //} - - recvChls[0]->asyncSend(dummy, 1); - recvChls[0]->recv(dummy, 1); - auto mid = timer.setTimePoint("init"); - - - // recvPSIs.sendInput(recvSet.data(), recvSet.size(), recvChls); - // recvPSIs.sendInput(recvSet, recvChls); - //recvPSIs.mBins.print(); - - - auto end = timer.setTimePoint("done"); - - auto offlineTime = std::chrono::duration_cast(mid - start).count(); - auto onlineTime = std::chrono::duration_cast(end - mid).count(); - - - offlineTimeTot += offlineTime; - onlineTimeTot += onlineTime; - //auto byteSent = recvChls[0]->getTotalDataSent() *recvChls.size(); - - u64 dataSent = 0; - for (u64 g = 0; g < recvChls.size(); ++g) - { - dataSent += recvChls[g]->getTotalDataSent(); - //std::cout << "chl[" << g << "] " << recvChls[g]->getTotalDataSent() << " " << sss[g] << std::endl; - } - - double time = offlineTime + onlineTime; - time /= 1000; - auto Mbps = dataSent * 8 / time / (1 << 20); - - std::cout << setSize << " " << offlineTime << " " << onlineTime << " " << Mbps << " Mbps " << (dataSent / std::pow(2.0, 20)) << " MB" << std::endl; - - for (u64 g = 0; g < recvChls.size(); ++g) - recvChls[g]->resetStats(); - - //std::cout << "threads = " << numThreads << std::endl << timer << std::endl << std::endl << std::endl; - - - //std::cout << numThreads << std::endl; - //std::cout << timer << std::endl; - - // std::cout << gTimer << std::endl; - - //if (recv.mIntersection.size() != setSize) - // throw std::runtime_error(""); - - - - - - - + chls[i][j]->close(); } - - - - online << onlineTimeTot / numTrial << "-"; - offline << offlineTimeTot / numTrial << "-"; - } } - for (u64 i = 0; i < numThreads; ++i) + for (u64 i = 0; i < nParties; ++i) { - recvChls_[i]->close();// = &recvEP.addChannel("chl" + std::to_string(i), "chl" + std::to_string(i)); + if (i != myIdx) + ep[i].stop(); } - //sendChl.close(); - //recvChl.close(); - recvEP.stop(); ios.stop(); } - -void OPPRF2_EmptrySet_Test() +void party3(u64 myIdx, u64 setSize, u64 nTrials) { - u64 setSize = 1 << 20, psiSecParam = 40, bitSize = 128, numParties = 2; + std::fstream runtime; + if (myIdx == 0) + runtime.open("./runtime3.txt", runtime.trunc | runtime.out); + + u64 offlineAvgTime(0), hashingAvgTime(0), getOPRFAvgTime(0), + secretSharingAvgTime(0), intersectionAvgTime(0), onlineAvgTime(0); + + u64 psiSecParam = 40, bitSize = 128, numThreads = 1; PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); - std::vector sendSet(setSize), recvSet(setSize); - std::vector> sendPayLoads(numParties), recvPayLoads(numParties); + std::string name("psi"); + BtIOService ios(0); + int btCount = nParties; + std::vector ep(nParties); + u64 offlineTimeTot(0); + u64 onlineTimeTot(0); + Timer timer; - for (u64 i = 0; i < setSize; ++i) + for (u64 i = 0; i < nParties; ++i) { - sendSet[i] = prng.get(); - recvSet[i] = sendSet[i]; - } - - for (u64 j = 0; j < numParties; ++j) { - sendPayLoads[j].resize(setSize); - recvPayLoads[j].resize(setSize); - for (u64 i = 0; i < setSize; ++i) + if (i < myIdx) { - sendPayLoads[j][i] = prng.get(); + u32 port = 1120 + i * 100 + myIdx;//get the same port; i=1 & pIdx=2 =>port=102 + ep[i].start(ios, "localhost", port, false, name); //channel bwt i and pIdx, where i is sender + } + else if (i > myIdx) + { + u32 port = 1120 + myIdx * 100 + i;//get the same port; i=2 & pIdx=1 =>port=102 + ep[i].start(ios, "localhost", port, true, name); //channel bwt i and pIdx, where i is receiver } } - for (u64 i = 1; i < 3; ++i) - { - recvSet[i] = sendSet[i]; - } - - std::string name("psi"); - - BtIOService ios(0); - BtEndpoint ep0(ios, "localhost", 1212, true, name); - BtEndpoint ep1(ios, "localhost", 1212, false, name); - - - std::vector recvChl{ &ep1.addChannel(name, name) }; - std::vector sendChl{ &ep0.addChannel(name, name) }; - - KkrtNcoOtReceiver otRecv0, otRecv1; - KkrtNcoOtSender otSend0, otSend1; + std::vector> chls(nParties); + for (u64 i = 0; i < nParties; ++i) + { + if (i != myIdx) { + chls[i].resize(numThreads); + for (u64 j = 0; j < numThreads; ++j) + { + //chls[i][j] = &ep[i].addChannel("chl" + std::to_string(j), "chl" + std::to_string(j)); + chls[i][j] = &ep[i].addChannel(name, name); + } + } + } - OPPRFSender send; - OPPRFReceiver recv; - std::thread thrd([&]() { - - - send.init(numParties, setSize, psiSecParam, bitSize, sendChl, otSend0, otRecv1, prng.get()); - //send.hash2Bins(sendSet, sendChl); - //send.getOPRFKeys(1, sendChl); - //send.sendSecretSharing(1, sendPayLoads[1], sendChl); - //send.revSecretSharing(1, recvPayLoads[1], sendChl); - //Log::out << "send.mSimpleBins.print(true, false, false,false);" << Log::endl; - //send.mSimpleBins.print(1, true, true, true, true); - //Log::out << "send.mCuckooBins.print(true, false, false);" << Log::endl; - //send.mCuckooBins.print(1,true, true, false); + PRNG prngSame(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); + PRNG prngDiff(_mm_set_epi32(434653, 23, myIdx, myIdx)); + u64 expected_intersection; + u64 num_intersection; + double dataSent=0, Mbps=0, dateRecv=0, MbpsRecv=0; + for (u64 idxTrial = 0; idxTrial < nTrials; idxTrial++) + { + std::vector set(setSize); + block blk_rand = prngSame.get(); + expected_intersection = (*(u64*)&blk_rand) % setSize; - }); - Timer timer; - auto start = timer.setTimePoint("start"); - recv.init(numParties, setSize, psiSecParam, bitSize, recvChl, otRecv0, otSend1, ZeroBlock); - - auto mid = timer.setTimePoint("init"); - //recv.hash2Bins(recvSet, recvChl); - auto mid_hashing = timer.setTimePoint("hashing"); - //recv.getOPRFkeys(0, recvChl); - //recv.revSecretSharing(0, recvPayLoads[0], recvChl); - //recv.sendSecretSharing(0, sendPayLoads[0], recvChl); - auto end = timer.setTimePoint("done"); - //Log::out << "recv.mCuckooBins.print(true, false, false);" << Log::endl; - //recv.mCuckooBins.print(0, true, true, false); - //Log::out << "recv.mSimpleBins.print(true, false, false,false);" << Log::endl; - //recv.mSimpleBins.print(0,true, true, true, true); - - auto offlineTime = std::chrono::duration_cast(mid - start).count(); - auto hashingTime = std::chrono::duration_cast(mid_hashing - mid).count(); - auto onlineTime = std::chrono::duration_cast(end - mid).count(); - - double time = offlineTime + onlineTime; - time /= 1000; - - std::cout << setSize << " " << offlineTime << " " << onlineTime - << " " << hashingTime - << " " << time << std::endl; - - for (u64 i = 1; i < recvPayLoads[0].size(); ++i) - { - if (memcmp((u8*)&recvPayLoads[0][i], &sendPayLoads[1][i], sizeof(block))) + for (u64 i = 0; i < expected_intersection; ++i) { - Log::out << "recvPayLoads[i] != sendPayLoads[i]" << Log::endl; - Log::out << recvSet[i] << Log::endl; - Log::out << sendSet[i] << Log::endl; - Log::out << i << Log::endl; + set[i] = prngSame.get(); } - } - for (u64 i = 1; i < recvPayLoads[1].size(); ++i) - { - if (memcmp((u8*)&recvPayLoads[1][i], &sendPayLoads[0][i], sizeof(block))) + for (u64 i = expected_intersection; i < setSize; ++i) { - Log::out << "recvPayLoads[i] != sendPayLoads[i]" << Log::endl; - Log::out << recvSet[i] << Log::endl; - Log::out << sendSet[i] << Log::endl; - Log::out << i << Log::endl; + set[i] = prngDiff.get(); } - } -#ifdef PRINT - std::cout << IoStream::lock; - for (u64 i = 1; i < recvPayLoads.size(); ++i) - { - Log::out << recvPayLoads[i] << Log::endl; - Log::out << sendPayLoads[i] << Log::endl; - if (memcmp((u8*)&recvPayLoads[i], &sendPayLoads[i], sizeof(block))) + std::vector sendPayLoads(setSize); + std::vector recvPayLoads(setSize); + + //only P0 genaretes secret sharing + if (myIdx == 0) { - Log::out << "recvPayLoads[i] != sendPayLoads[i]" << Log::endl; - Log::out << recvSet[i] << Log::endl; - Log::out << sendSet[i] << Log::endl; - Log::out << i << Log::endl; + for (u64 i = 0; i < setSize; ++i) + sendPayLoads[i] = prng.get(); } - } + std::vector otRecv(nParties); + std::vector otSend(nParties); - std::cout << IoStream::unlock; + OPPRFSender send; + OPPRFReceiver recv; + binSet bins; - std::cout << IoStream::lock; - Log::out << otSend0.mT.size()[1] << Log::endl; - Log::out << otSend1.mT.size()[1] << Log::endl; - Log::out << otSend0.mGens[0].get() << Log::endl; - Log::out << otRecv0.mGens[0][0].get() << Log::endl; - Log::out << otRecv0.mGens[0][1].get() << Log::endl; - Log::out << "------------" << Log::endl; - Log::out << otSend1.mGens[0].get() << Log::endl; - Log::out << otRecv1.mGens[0][0].get() << Log::endl; - Log::out << otRecv1.mGens[0][1].get() << Log::endl; - std::cout << IoStream::unlock; + std::vector pThrds(nParties); -#endif + //########################## + //### Offline Phasing + //########################## - thrd.join(); + auto start = timer.setTimePoint("start"); + bins.init(myIdx, nParties, setSize, psiSecParam); + u64 otCountSend = bins.mSimpleBins.mBins.size(); + u64 otCountRecv = bins.mCuckooBins.mBins.size(); + u64 nextNeibough = (myIdx + 1) % nParties; + u64 prevNeibough = (myIdx - 1 + nParties) % nParties; + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + { + pThrds[pIdx] = std::thread([&, pIdx]() { + if (pIdx == nextNeibough) { + //I am a sender to my next neigbour + send.init(nParties, setSize, psiSecParam, bitSize, chls[pIdx], otCountSend, otSend[pIdx], otRecv[pIdx], prng.get(), false); + } + else if (pIdx == prevNeibough) { + //I am a recv to my previous neigbour + recv.init(nParties, setSize, psiSecParam, bitSize, chls[pIdx], otCountRecv, otRecv[pIdx], otSend[pIdx], ZeroBlock, false); + } + }); + } + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + pThrds[pIdx].join(); +#ifdef PRINT + std::cout << IoStream::lock; + if (myIdx == 0) + { + Log::out << "------0------" << Log::endl; + Log::out << otSend[1].mGens[0].get() << Log::endl; + Log::out << otRecv[2].mGens[0][0].get() << Log::endl; + Log::out << otRecv[2].mGens[0][1].get() << Log::endl; + } + if (myIdx == 1) + { + Log::out << "------1------" << Log::endl; + Log::out << otRecv[0].mGens[0][0].get() << Log::endl; + Log::out << otRecv[0].mGens[0][1].get() << Log::endl; + Log::out << otSend[2].mGens[0].get() << Log::endl; + } - sendChl[0]->close(); - recvChl[0]->close(); - - ep0.stop(); - ep1.stop(); - ios.stop(); -} + if (myIdx == 2) + { + Log::out << "------2------" << Log::endl; + Log::out << otRecv[1].mGens[0][0].get() << Log::endl; + Log::out << otRecv[1].mGens[0][1].get() << Log::endl; + Log::out << otSend[0].mGens[0].get() << Log::endl; + } + std::cout << IoStream::unlock; #endif -void Channel_test() -{ - std::string name("psi"); + auto initDone = timer.setTimePoint("initDone"); - BtIOService ios(0); - BtEndpoint ep0(ios, "localhost", 1212, false, name); - BtEndpoint ep1(ios, "localhost", 1212, true, name); - u8 dummy = 1; - u8 revDummy; - std::vector recvChl{ &ep0.addChannel(name, name) }; - std::vector sendChl{ &ep1.addChannel(name, name) }; + //########################## + //### Hashing + //########################## + bins.hashing2Bins(set, nParties); + //bins.mSimpleBins.print(myIdx, true, false, false, false); + //bins.mCuckooBins.print(myIdx, true, false, false); - std::thread thrd([&]() { - sendChl[0]->asyncSend(&dummy, 1); - }); + auto hashingDone = timer.setTimePoint("hashingDone"); + //########################## + //### Online Phasing - compute OPRF + //########################## - recvChl[0]->recv(&revDummy, 1); - std::cout << static_cast(revDummy) << std::endl; + pThrds.clear(); + pThrds.resize(nParties); + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + { + pThrds[pIdx] = std::thread([&, pIdx]() { - sendChl[0]->close(); - recvChl[0]->close(); + if (pIdx == nextNeibough) { + //I am a sender to my next neigbour + send.getOPRFKeys(pIdx, bins, chls[pIdx], false); + } + else if (pIdx == prevNeibough) { + //I am a recv to my previous neigbour + recv.getOPRFkeys(pIdx, bins, chls[pIdx], false); + } + }); + } - ep0.stop(); - ep1.stop(); - ios.stop(); -} -void Channel_party_test(u64 myIdx) -{ - u64 setSize = 1 << 5, psiSecParam = 40, bitSize = 128, numThreads = 1; - PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + pThrds[pIdx].join(); + //if (myIdx == 2) + //{ + // //bins.mSimpleBins.print(2, true, true, false, false); + // bins.mCuckooBins.print(1, true, true, false); + // Log::out << "------------" << Log::endl; + //} + //if (myIdx == 1) + //{ + // bins.mSimpleBins.print(2, true, true, false, false); + // //bins.mCuckooBins.print(0, true, true, false); + //} - std::vector dummy(nParties); - std::vector revDummy(nParties); + auto getOPRFDone = timer.setTimePoint("getOPRFDone"); - std::string name("psi"); - BtIOService ios(0); + //########################## + //### online phasing - secretsharing + //########################## - int btCount = nParties; - std::vector ep(nParties); + pThrds.clear(); + pThrds.resize(nParties - 1); - for (u64 i = 0; i < nParties; ++i) - { - dummy[i] = myIdx * 10 + i; - if (i < myIdx) + if (myIdx == 0) { - u32 port = i * 10 + myIdx;//get the same port; i=1 & pIdx=2 =>port=102 - ep[i].start(ios, "localhost", port, false, name); //channel bwt i and pIdx, where i is sender + //for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + //{ + // pThrds[pIdx] = std::thread([&, pIdx]() { + // if (pIdx == 0) { + // send.sendSecretSharing(nextNeibough, bins, sendPayLoads, chls[nextNeibough]); + // } + // else if (pIdx == 1) { + // recv.revSecretSharing(prevNeibough, bins, recvPayLoads, chls[prevNeibough]); + // } + // }); + //} + send.sendSecretSharing(nextNeibough, bins, sendPayLoads, chls[nextNeibough]); + recv.revSecretSharing(prevNeibough, bins, recvPayLoads, chls[prevNeibough]); } - else if (i > myIdx) + else { - u32 port = myIdx * 10 + i;//get the same port; i=2 & pIdx=1 =>port=102 - ep[i].start(ios, "localhost", port, true, name); //channel bwt i and pIdx, where i is receiver + /*for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + { + pThrds[pIdx] = std::thread([&, pIdx]() { + if (pIdx == 0) { + recv.revSecretSharing(prevNeibough, bins, recvPayLoads, chls[prevNeibough]); + } + else if (pIdx == 1) { + send.sendSecretSharing(nextNeibough, bins, recvPayLoads, chls[nextNeibough]); + } + }); + } */ + recv.revSecretSharing(prevNeibough, bins, recvPayLoads, chls[prevNeibough]); + //sendPayLoads = recvPayLoads; + send.sendSecretSharing(nextNeibough, bins, recvPayLoads, chls[nextNeibough]); + } - } + /*for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + pThrds[pIdx].join();*/ - std::vector> chls(nParties); + auto getSSDone = timer.setTimePoint("getSSDone"); - for (u64 i = 0; i < nParties; ++i) - { - if (i != myIdx) { - chls[i].resize(numThreads); - for (u64 j = 0; j < numThreads; ++j) +#ifdef PRINT + std::cout << IoStream::lock; + if (myIdx == 0) + { + for (int i = 0; i < 5; i++) { - //chls[i][j] = &ep[i].addChannel("chl" + std::to_string(j), "chl" + std::to_string(j)); - chls[i][j] = &ep[i].addChannel(name, name); + Log::out << sendPayLoads[i] << Log::endl; + //Log::out << recvPayLoads[2][i] << Log::endl; } + Log::out << "------------" << Log::endl; } - } + if (myIdx == 1) + { + for (int i = 0; i < 5; i++) + { + //Log::out << recvPayLoads[i] << Log::endl; + Log::out << sendPayLoads[i] << Log::endl; + } + } + if (myIdx == 2) + { + for (int i = 0; i < 5; i++) + { + Log::out << sendPayLoads[i] << Log::endl; + } + } + std::cout << IoStream::unlock; +#endif + //########################## + //### online phasing - compute intersection + //########################## + std::vector mIntersection; - std::mutex printMtx1, printMtx2; - std::vector pThrds(nParties); - for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) - { - pThrds[pIdx] = std::thread([&, pIdx]() { - if (pIdx < myIdx) { + if (myIdx == 0) { + + u64 maskSize = roundUpTo(psiSecParam + 2 * std::log(setSize) - 1, 8) / 8; + for (u64 i = 0; i < setSize; ++i) + { + if (!memcmp((u8*)&sendPayLoads[i], &recvPayLoads[i], maskSize)) + { + mIntersection.push_back(i); + } + } + Log::out << "mIntersection.size(): " << mIntersection.size() << Log::endl; + } + auto getIntersection = timer.setTimePoint("getIntersection"); + num_intersection = mIntersection.size(); - chls[pIdx][0]->asyncSend(&dummy[pIdx], 1); - std::lock_guard lock(printMtx1); - std::cout << "s: " << myIdx << " -> " << pIdx << " : " << static_cast(dummy[pIdx]) << std::endl; + if (myIdx == 0) { + auto offlineTime = std::chrono::duration_cast(initDone - start).count(); + auto hashingTime = std::chrono::duration_cast(hashingDone - initDone).count(); + auto getOPRFTime = std::chrono::duration_cast(getOPRFDone - hashingDone).count(); + auto secretSharingTime = std::chrono::duration_cast(getSSDone - getOPRFDone).count(); + auto intersectionTime = std::chrono::duration_cast(getIntersection - getSSDone).count(); + + double onlineTime = hashingTime + getOPRFTime + secretSharingTime + intersectionTime; + + double time = offlineTime + onlineTime; + time /= 1000; + + + dataSent = 0; + dateRecv = 0; + Mbps = 0; + MbpsRecv = 0; + + for (u64 i = 0; i < nParties; ++i) + { + if (i != myIdx) { + chls[i].resize(numThreads); + for (u64 j = 0; j < numThreads; ++j) + { + dataSent += chls[i][j]->getTotalDataSent(); + dateRecv += chls[i][j]->getTotalDataRecv(); + } + } } - else if (pIdx > myIdx) { - chls[pIdx][0]->recv(&revDummy[pIdx], 1); - std::lock_guard lock(printMtx2); - std::cout << "r: " << myIdx << " <- " << pIdx << " : " << static_cast(revDummy[pIdx]) << std::endl; + Mbps = dataSent * 8 / time / (1 << 20); + MbpsRecv = dataSent * 8 / time / (1 << 20); + for (u64 i = 0; i < nParties; ++i) + { + if (i != myIdx) { + chls[i].resize(numThreads); + for (u64 j = 0; j < numThreads; ++j) + { + chls[i][j]->resetStats(); + } + } } - }); - } + + Log::out << "#Output Intersection: " << num_intersection << Log::endl; + Log::out << "#Expected Intersection: " << expected_intersection << Log::endl; - for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) - { - // if(pIdx!=myIdx) - pThrds[pIdx].join(); - } + std::cout << "(ROUND OPPRF) numParty: " << nParties + << " setSize: " << setSize << "\n" + << "offlineTime: " << offlineTime << " ms\n" + << "hashingTime: " << hashingTime << " ms\n" + << "getOPRFTime: " << getOPRFTime << " ms\n" + << "secretSharing: " << secretSharingTime << " ms\n" + << "intersection: " << intersectionTime << " ms\n" + << "onlineTime: " << onlineTime << " ms\n" + << "Bandwidth: Send: " << Mbps << " Mbps,\t Recv: " << MbpsRecv << " Mbps\n" + << "Total time: " << time << " s\n" + << "Total Comm: Send:" << (dataSent / std::pow(2.0, 20)) << " MB" + << "\t Recv: " << (dateRecv / std::pow(2.0, 20)) << " MB\n" + << "------------------\n"; + offlineAvgTime += offlineTime; + hashingAvgTime += hashingTime; + getOPRFAvgTime += getOPRFTime; + secretSharingAvgTime += secretSharingTime; + intersectionAvgTime += intersectionTime; + onlineAvgTime += onlineTime; + } + } + if (myIdx == 0) { + double avgTime = (offlineAvgTime + onlineAvgTime); + avgTime /= 1000; + std::cout << "=========avg==========\n" + << "(ROUND OPPRF) numParty: " << nParties + << " setSize: " << setSize + << " nTrials:" << nTrials << "\n" + << "offlineTime: " << offlineAvgTime / nTrials << " ms\n" + << "hashingTime: " << hashingAvgTime / nTrials << " ms\n" + << "getOPRFTime: " << getOPRFAvgTime / nTrials << " ms\n" + << "secretSharing: " << secretSharingAvgTime / nTrials << " ms\n" + << "intersection: " << intersectionAvgTime / nTrials << " ms\n" + << "onlineTime: " << onlineAvgTime / nTrials << " ms\n" + << "Bandwidth: Send: " << Mbps << " Mbps,\t Recv: " << MbpsRecv << " Mbps\n" + << "Total time: " << time << " s\n" + << "Total Comm: Send:" << (dataSent / std::pow(2.0, 20)) << " MB" + << "\t Recv: " << (dateRecv / std::pow(2.0, 20)) << " MB\n" + << "------------------\n"; + + runtime << "(ROUND OPPRF) numParty: " << nParties + << " setSize: " << setSize + << " nTrials:" << nTrials << "\n" + << "offlineTime: " << offlineAvgTime / nTrials << " ms\n" + << "hashingTime: " << hashingAvgTime / nTrials << " ms\n" + << "getOPRFTime: " << getOPRFAvgTime / nTrials << " ms\n" + << "secretSharing: " << secretSharingAvgTime / nTrials << " ms\n" + << "intersection: " << intersectionAvgTime / nTrials << " ms\n" + << "onlineTime: " << onlineAvgTime / nTrials << " ms\n" + << "Bandwidth: Send: " << Mbps << " Mbps,\t Recv: " << MbpsRecv << " Mbps\n" + << "Total time: " << time << " s\n" + << "Total Comm: Send:" << (dataSent / std::pow(2.0, 20)) << " MB" + << "\t Recv: " << (dateRecv / std::pow(2.0, 20)) << " MB\n" + << "------------------\n"; + runtime.close(); + } for (u64 i = 0; i < nParties; ++i) { @@ -956,19 +1082,29 @@ void Channel_party_test(u64 myIdx) ios.stop(); } - -void party(u64 myIdx, u64 nParties, u64 setSize, std::vector& mSet) +void party2(u64 myIdx, u64 setSize) { - //nParties = 4; - std::fstream runtime; - if (myIdx == 0) - runtime.open("./runtime" + nParties, runtime.trunc | runtime.out); + nParties = 2; + u64 psiSecParam = 40, bitSize = 128, numThreads = 1; + PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); - u64 offlineAvgTime(0), hashingAvgTime(0), getOPRFAvgTime(0), - ss2DirAvgTime(0), ssRoundAvgTime(0), intersectionAvgTime(0), onlineAvgTime(0); + std::vector set(setSize); + for (u64 i = 0; i < setSize; ++i) + set[i] = mSet[i]; + + PRNG prng1(_mm_set_epi32(4253465, myIdx, myIdx, myIdx)); //for test + set[0] = prng1.get();; + + std::vector sendPayLoads(setSize); + std::vector recvPayLoads(setSize); + + //only P0 genaretes secret sharing + if (myIdx == 0) + { + for (u64 i = 0; i < setSize; ++i) + sendPayLoads[i] = prng.get(); + } - u64 psiSecParam = 40, bitSize = 128, numThreads = 1; - PRNG prng(_mm_set_epi32(4253465, 3434565, myIdx, myIdx)); std::string name("psi"); BtIOService ios(0); @@ -976,20 +1112,25 @@ void party(u64 myIdx, u64 nParties, u64 setSize, std::vector& mSet) int btCount = nParties; std::vector ep(nParties); + u64 offlineTimeTot(0); + u64 onlineTimeTot(0); + Timer timer; + for (u64 i = 0; i < nParties; ++i) { if (i < myIdx) { - u32 port = 1120 + i * 100 + myIdx;;//get the same port; i=1 & pIdx=2 =>port=102 + u32 port = 1210 + i * 10 + myIdx;//get the same port; i=1 & pIdx=2 =>port=102 ep[i].start(ios, "localhost", port, false, name); //channel bwt i and pIdx, where i is sender } else if (i > myIdx) { - u32 port = 1120 + myIdx * 100 + i;//get the same port; i=2 & pIdx=1 =>port=102 + u32 port = 1210 + myIdx * 10 + i;//get the same port; i=2 & pIdx=1 =>port=102 ep[i].start(ios, "localhost", port, true, name); //channel bwt i and pIdx, where i is receiver } } + std::vector> chls(nParties); for (u64 i = 0; i < nParties; ++i) @@ -1004,429 +1145,186 @@ void party(u64 myIdx, u64 nParties, u64 setSize, std::vector& mSet) } } - u64 maskSize = roundUpTo(psiSecParam + 2 * std::log(setSize) - 1, 8) / 8; - - for (u64 idxTrial = 0; idxTrial < numTrial; idxTrial++) - { - std::vector set(setSize); - std::vector> sendPayLoads(nParties), recvPayLoads(nParties); - - for (u64 i = 0; i < setSize; ++i) - { - set[i] = mSet[i]; - } - PRNG prng1(_mm_set_epi32(4253465, 3434565, 234435, myIdx)); - set[0] = prng1.get();; - for (u64 idxP = 0; idxP < nParties; ++idxP) - { - sendPayLoads[idxP].resize(setSize); - recvPayLoads[idxP].resize(setSize); - for (u64 i = 0; i < setSize; ++i) - sendPayLoads[idxP][i] = prng.get(); - } - u64 nextNeighbor = (myIdx + 1) % nParties; - u64 prevNeighbor = (myIdx - 1 + nParties) % nParties; - //sum share of other party =0 => compute the share to his neighbor = sum of other shares - if (myIdx != 0) { - for (u64 i = 0; i < setSize; ++i) - { - block sum = ZeroBlock; - for (u64 idxP = 0; idxP < nParties; ++idxP) - { - if ((idxP != myIdx && idxP != nextNeighbor)) - sum = sum ^ sendPayLoads[idxP][i]; - } - sendPayLoads[nextNeighbor][i] = sum; - - } - } - else - for (u64 i = 0; i < setSize; ++i) - { - sendPayLoads[myIdx][i] = ZeroBlock; - for (u64 idxP = 0; idxP < nParties; ++idxP) - { - if (idxP != myIdx) - sendPayLoads[myIdx][i] = sendPayLoads[myIdx][i] ^ sendPayLoads[idxP][i]; - } - } + std::vector otRecv(nParties); + std::vector otSend(nParties); -#ifdef PRINT - std::cout << IoStream::lock; - if (myIdx != 0) { - for (u64 i = 0; i < setSize; ++i) - { - block check = ZeroBlock; - for (u64 idxP = 0; idxP < nParties; ++idxP) - { - if (idxP != myIdx) - check = check ^ sendPayLoads[idxP][i]; - } - if (memcmp((u8*)&check, &ZeroBlock, sizeof(block))) - std::cout << "Error ss values: myIdx: " << myIdx - << " value: " << check << std::endl; - } - } - else - for (u64 i = 0; i < setSize; ++i) - { - block check = ZeroBlock; - for (u64 idxP = 0; idxP < nParties; ++idxP) - { - check = check ^ sendPayLoads[idxP][i]; - } - if (memcmp((u8*)&check, &ZeroBlock, sizeof(block))) - std::cout << "Error ss values: myIdx: " << myIdx - << " value: " << check << std::endl; - } - std::cout << IoStream::unlock; -#endif + OPPRFSender send; + OPPRFReceiver recv; + binSet bins; + std::vector pThrds(nParties); - std::vector otRecv(nParties); - std::vector otSend(nParties); + //########################## + //### Offline Phasing + //########################## - std::vector send(nParties - myIdx - 1); - std::vector recv(myIdx); - binSet bins; + auto start = timer.setTimePoint("start"); - std::vector pThrds(nParties); + bins.init(myIdx, nParties, setSize, psiSecParam); + u64 otCountSend = bins.mSimpleBins.mBins.size(); + u64 otCountRecv = bins.mCuckooBins.mBins.size(); - //########################## - //### Offline Phasing - //########################## - Timer timer; - auto start = timer.setTimePoint("start"); - bins.init(myIdx, nParties, setSize, psiSecParam); - u64 otCountSend = bins.mSimpleBins.mBins.size(); - u64 otCountRecv = bins.mCuckooBins.mBins.size(); - for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) - { - pThrds[pIdx] = std::thread([&, pIdx]() { - if (pIdx < myIdx) { - //I am a receiver if other party idx < mine - recv[pIdx].init(nParties, setSize, psiSecParam, bitSize, chls[pIdx], otCountRecv, otRecv[pIdx], otSend[pIdx], ZeroBlock, true); - } - else if (pIdx > myIdx) { - send[pIdx - myIdx - 1].init(nParties, setSize, psiSecParam, bitSize, chls[pIdx], otCountSend, otSend[pIdx], otRecv[pIdx], prng.get(), true); - } - }); - } + if (myIdx == 0) { + //I am a sender to my next neigbour + send.init(nParties, setSize, psiSecParam, bitSize, chls[1], otCountSend, otSend[1], otRecv[1], prng.get(), false); - for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) - pThrds[pIdx].join(); + } + else if (myIdx == 1) { + //I am a recv to my previous neigbour + recv.init(nParties, setSize, psiSecParam, bitSize, chls[0], otCountRecv, otRecv[0], otSend[0], ZeroBlock, false); + } - auto initDone = timer.setTimePoint("initDone"); #ifdef PRINT - std::cout << IoStream::lock; - if (myIdx == 0) - { - Log::out << otSend[2].mGens[0].get() << Log::endl; - if (otRecv[2].hasBaseOts()) - { - Log::out << otRecv[2].mGens[0][0].get() << Log::endl; - Log::out << otRecv[2].mGens[0][1].get() << Log::endl; - } - Log::out << "------------" << Log::endl; - } - if (myIdx == 2) - { - if (otSend[0].hasBaseOts()) - Log::out << otSend[0].mGens[0].get() << Log::endl; + std::cout << IoStream::lock; + if (myIdx == 0) + { + Log::out << "------0------" << Log::endl; + Log::out << otSend[1].mGens[0].get() << Log::endl; + Log::out << otRecv[2].mGens[0][0].get() << Log::endl; + Log::out << otRecv[2].mGens[0][1].get() << Log::endl; + } + if (myIdx == 1) + { + Log::out << "------1------" << Log::endl; + Log::out << otRecv[0].mGens[0][0].get() << Log::endl; + Log::out << otRecv[0].mGens[0][1].get() << Log::endl; + Log::out << otSend[2].mGens[0].get() << Log::endl; + } - Log::out << otRecv[0].mGens[0][0].get() << Log::endl; - Log::out << otRecv[0].mGens[0][1].get() << Log::endl; - } - std::cout << IoStream::unlock; + if (myIdx == 2) + { + Log::out << "------2------" << Log::endl; + Log::out << otRecv[1].mGens[0][0].get() << Log::endl; + Log::out << otRecv[1].mGens[0][1].get() << Log::endl; + Log::out << otSend[0].mGens[0].get() << Log::endl; + } + std::cout << IoStream::unlock; #endif - //########################## - //### Hashing - //########################## - bins.hashing2Bins(set, 1); - - //if(myIdx==0) - // bins.mSimpleBins.print(myIdx, true, false, false, false); - //if (myIdx == 2) - // bins.mCuckooBins.print(myIdx, true, false, false); - - auto hashingDone = timer.setTimePoint("hashingDone"); - //########################## - //### Online Phasing - compute OPRF - //########################## + auto initDone = timer.setTimePoint("initDone"); - pThrds.clear(); - pThrds.resize(nParties); - for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) - { - pThrds[pIdx] = std::thread([&, pIdx]() { - if (pIdx < myIdx) { - //I am a receiver if other party idx < mine - recv[pIdx].getOPRFkeys(pIdx, bins, chls[pIdx], true); - } - else if (pIdx > myIdx) { - send[pIdx - myIdx - 1].getOPRFKeys(pIdx, bins, chls[pIdx], true); - } - }); - } + //########################## + //### Hashing + //########################## + bins.hashing2Bins(set, 1); + //bins.mSimpleBins.print(myIdx, true, false, false, false); + //bins.mCuckooBins.print(myIdx, true, false, false); - for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) - pThrds[pIdx].join(); + auto hashingDone = timer.setTimePoint("hashingDone"); - //if (myIdx == 0) - //{ - // bins.mSimpleBins.print(2, true, true, false, false); - // //bins.mCuckooBins.print(2, true, false, false); - // Log::out << "------------" << Log::endl; - //} - //if (myIdx == 2) - //{ - // //bins.mSimpleBins.print(myIdx, true, false, false, false); - // bins.mCuckooBins.print(0, true, true, false); - //} + //########################## + //### Online Phasing - compute OPRF + //########################## - auto getOPRFDone = timer.setTimePoint("getOPRFDone"); - //########################## - //### online phasing - secretsharing - //########################## - pThrds.clear(); - pThrds.resize(nParties); + if (myIdx == 0) { + //I am a sender to my next neigbour + send.getOPRFKeys(1, bins, chls[1], false); + } + else if (myIdx == 1) { + //I am a recv to my previous neigbour + recv.getOPRFkeys(0, bins, chls[0], false); + } - for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) - { - pThrds[pIdx] = std::thread([&, pIdx]() { - if ((pIdx < myIdx && pIdx != prevNeighbor)) { - //I am a receiver if other party idx < mine - recv[pIdx].revSecretSharing(pIdx, bins, recvPayLoads[pIdx], chls[pIdx]); - recv[pIdx].sendSecretSharing(pIdx, bins, sendPayLoads[pIdx], chls[pIdx]); - } - else if (pIdx > myIdx && pIdx != nextNeighbor) { - send[pIdx - myIdx - 1].sendSecretSharing(pIdx, bins, sendPayLoads[pIdx], chls[pIdx]); - send[pIdx - myIdx - 1].revSecretSharing(pIdx, bins, recvPayLoads[pIdx], chls[pIdx]); - } - else if (pIdx == prevNeighbor && myIdx != 0) { - recv[pIdx].sendSecretSharing(pIdx, bins, sendPayLoads[pIdx], chls[pIdx]); - } - else if (pIdx == nextNeighbor && myIdx != nParties - 1) - { - send[pIdx - myIdx - 1].revSecretSharing(pIdx, bins, recvPayLoads[pIdx], chls[pIdx]); - } + //if (myIdx == 2) + //{ + // //bins.mSimpleBins.print(2, true, true, false, false); + // bins.mCuckooBins.print(1, true, true, false); + // Log::out << "------------" << Log::endl; + //} + //if (myIdx == 1) + //{ + // bins.mSimpleBins.print(2, true, true, false, false); + // //bins.mCuckooBins.print(0, true, true, false); + //} - else if (pIdx == nParties - 1 && myIdx == 0) { - send[pIdx - myIdx - 1].sendSecretSharing(pIdx, bins, sendPayLoads[pIdx], chls[pIdx]); - } + auto getOPRFDone = timer.setTimePoint("getOPRFDone"); - else if (pIdx == 0 && myIdx == nParties - 1) - { - recv[pIdx].revSecretSharing(pIdx, bins, recvPayLoads[pIdx], chls[pIdx]); - } - }); - } + //########################## + //### online phasing - secretsharing + //########################## + if (myIdx == 0) + { + // send.sendSecretSharing(nextNeibough, bins, sendPayLoads, chls[nextNeibough]); + // recv.revSecretSharing(prevNeibough, bins, recvPayLoads, chls[prevNeibough]); - for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) - pThrds[pIdx].join(); + } + else + { + // recv.revSecretSharing(prevNeibough, bins, recvPayLoads, chls[prevNeibough]); + //sendPayLoads = recvPayLoads; + // send.sendSecretSharing(nextNeibough, bins, recvPayLoads, chls[nextNeibough]); + } - auto getSSDone2Dir = timer.setTimePoint("getSSDone2Dir"); #ifdef PRINT - std::cout << IoStream::lock; - if (myIdx == 0) - { - for (int i = 0; i < 3; i++) - { - block temp = ZeroBlock; - memcpy((u8*)&temp, (u8*)&sendPayLoads[2][i], maskSize); - Log::out << "s " << myIdx << " - 2: Idx" << i << " - " << temp << Log::endl; - - block temp1 = ZeroBlock; - memcpy((u8*)&temp1, (u8*)&recvPayLoads[2][i], maskSize); - Log::out << "r " << myIdx << " - 2: Idx" << i << " - " << temp1 << Log::endl; - } - Log::out << "------------" << Log::endl; - } - if (myIdx == 2) - { - for (int i = 0; i < 3; i++) - { - block temp = ZeroBlock; - memcpy((u8*)&temp, (u8*)&recvPayLoads[0][i], maskSize); - Log::out << "r " << myIdx << " - 0: Idx" << i << " - " << temp << Log::endl; - - block temp1 = ZeroBlock; - memcpy((u8*)&temp1, (u8*)&sendPayLoads[0][i], maskSize); - Log::out << "s " << myIdx << " - 0: Idx" << i << " - " << temp1 << Log::endl; - } - Log::out << "------------" << Log::endl; - } - std::cout << IoStream::unlock; -#endif - //########################## - //### online phasing - secretsharing - round - //########################## - - if (myIdx == 0) - { - // Xor the received shares - for (u64 i = 0; i < setSize; ++i) - { - for (u64 idxP = 0; idxP < nParties; ++idxP) - { - if (idxP != myIdx && idxP != prevNeighbor) - sendPayLoads[nextNeighbor][i] = sendPayLoads[nextNeighbor][i] ^ recvPayLoads[idxP][i]; - } - } - - send[nextNeighbor].sendSecretSharing(nextNeighbor, bins, sendPayLoads[nextNeighbor], chls[nextNeighbor]); - send[nextNeighbor - myIdx - 1].revSecretSharing(prevNeighbor, bins, recvPayLoads[prevNeighbor], chls[prevNeighbor]); - - } - else if (myIdx == nParties - 1) - { - recv[prevNeighbor].revSecretSharing(prevNeighbor, bins, recvPayLoads[prevNeighbor], chls[prevNeighbor]); - - //Xor the received shares - for (u64 i = 0; i < setSize; ++i) - { - sendPayLoads[nextNeighbor][i] = sendPayLoads[nextNeighbor][i] ^ recvPayLoads[prevNeighbor][i]; - for (u64 idxP = 0; idxP < nParties; ++idxP) - { - if (idxP != myIdx && idxP != prevNeighbor) - sendPayLoads[nextNeighbor][i] = sendPayLoads[nextNeighbor][i] ^ recvPayLoads[idxP][i]; - } - } - - recv[nextNeighbor].sendSecretSharing(nextNeighbor, bins, sendPayLoads[nextNeighbor], chls[nextNeighbor]); - - } - else + std::cout << IoStream::lock; + if (myIdx == 0) + { + for (int i = 0; i < 5; i++) { - recv[prevNeighbor].revSecretSharing(prevNeighbor, bins, recvPayLoads[prevNeighbor], chls[prevNeighbor]); - //Xor the received shares - for (u64 i = 0; i < setSize; ++i) - { - sendPayLoads[nextNeighbor][i] = sendPayLoads[nextNeighbor][i] ^ recvPayLoads[prevNeighbor][i]; - for (u64 idxP = 0; idxP < nParties; ++idxP) - { - if (idxP != myIdx && idxP != prevNeighbor) - sendPayLoads[nextNeighbor][i] = sendPayLoads[nextNeighbor][i] ^ recvPayLoads[idxP][i]; - } - } - send[nextNeighbor - myIdx - 1].sendSecretSharing(nextNeighbor, bins, sendPayLoads[nextNeighbor], chls[nextNeighbor]); + Log::out << sendPayLoads[i] << Log::endl; + //Log::out << recvPayLoads[2][i] << Log::endl; } - - auto getSSDoneRound = timer.setTimePoint("getSSDoneRound"); - - -#ifdef PRINT - std::cout << IoStream::lock; - if (myIdx == 0) + Log::out << "------------" << Log::endl; + } + if (myIdx == 1) + { + for (int i = 0; i < 5; i++) { - for (int i = 0; i < 5; i++) - { - block temp = ZeroBlock; - memcpy((u8*)&temp, (u8*)&sendPayLoads[1][i], maskSize); - Log::out << myIdx << " - " << temp << Log::endl; - //Log::out << recvPayLoads[2][i] << Log::endl; - } - Log::out << "------------" << Log::endl; + //Log::out << recvPayLoads[i] << Log::endl; + Log::out << sendPayLoads[i] << Log::endl; } - if (myIdx == 1) + } + if (myIdx == 2) + { + for (int i = 0; i < 5; i++) { - for (int i = 0; i < 5; i++) - { - block temp = ZeroBlock; - memcpy((u8*)&temp, (u8*)&recvPayLoads[0][i], maskSize); - Log::out << myIdx << " - " << temp << Log::endl; - //Log::out << sendPayLoads[0][i] << Log::endl; - } + Log::out << sendPayLoads[i] << Log::endl; } - std::cout << IoStream::unlock; + } + std::cout << IoStream::unlock; #endif - //########################## - //### online phasing - compute intersection - //########################## - - if (myIdx == 0) { - std::vector mIntersection; - u64 maskSize = roundUpTo(psiSecParam + 2 * std::log(setSize) - 1, 8) / 8; - for (u64 i = 0; i < setSize; ++i) - { - if (!memcmp((u8*)&sendPayLoads[myIdx][i], &recvPayLoads[prevNeighbor][i], maskSize)) - { - mIntersection.push_back(i); - } - } - Log::out << "mIntersection.size(): " << mIntersection.size() << Log::endl; - } - auto getIntersection = timer.setTimePoint("getIntersection"); - - - if (myIdx == 0) { - auto offlineTime = std::chrono::duration_cast(initDone - start).count(); - auto hashingTime = std::chrono::duration_cast(hashingDone - initDone).count(); - auto getOPRFTime = std::chrono::duration_cast(getOPRFDone - hashingDone).count(); - auto ss2DirTime = std::chrono::duration_cast(getSSDone2Dir - getOPRFDone).count(); - auto ssRoundTime = std::chrono::duration_cast(getSSDoneRound - getSSDone2Dir).count(); - auto intersectionTime = std::chrono::duration_cast(getIntersection - getSSDoneRound).count(); - - double onlineTime = hashingTime + getOPRFTime + ss2DirTime + ssRoundTime + intersectionTime; - - double time = offlineTime + onlineTime; - time /= 1000; - - std::cout << "setSize: " << setSize << "\n" - << "offlineTime: " << offlineTime << " ms\n" - << "hashingTime: " << hashingTime << " ms\n" - << "getOPRFTime: " << getOPRFTime << " ms\n" - << "ss2DirTime: " << ss2DirTime << " ms\n" - << "ssRoundTime: " << ssRoundTime << " ms\n" - << "intersection: " << intersectionTime << " ms\n" - << "onlineTime: " << onlineTime << " ms\n" - << "Total time: " << time << " s\n" - << "------------------\n"; - - - offlineAvgTime += offlineTime; - hashingAvgTime += hashingTime; - getOPRFAvgTime += getOPRFTime; - ss2DirAvgTime += ss2DirTime; - ssRoundAvgTime += ssRoundTime; - intersectionAvgTime += intersectionTime; - onlineAvgTime += onlineTime; + //########################## + //### online phasing - compute intersection + //########################## + if (myIdx == 0) { + std::vector mIntersection; + u64 maskSize = roundUpTo(psiSecParam + 2 * std::log(setSize) - 1, 8) / 8; + for (u64 i = 0; i < setSize; ++i) + { + if (!memcmp((u8*)&sendPayLoads[i], &recvPayLoads[i], maskSize)) + { + // mIntersection.push_back(i); + } } - + Log::out << "mIntersection.size(): " << mIntersection.size() << Log::endl; } + auto end = timer.setTimePoint("getOPRFDone"); + if (myIdx == 0) { - double avgTime = (offlineAvgTime + onlineAvgTime); - avgTime /= 1000; - std::cout << "=========avg==========\n" - << "setSize: " << setSize << "\n" - << "offlineTime: " << offlineAvgTime / numTrial << " ms\n" - << "hashingTime: " << hashingAvgTime / numTrial << " ms\n" - << "getOPRFTime: " << getOPRFAvgTime / numTrial << " ms\n" - << "ss2DirTime: " << ss2DirAvgTime << " ms\n" - << "ssRoundTime: " << ssRoundAvgTime << " ms\n" - << "intersection: " << intersectionAvgTime / numTrial << " ms\n" - << "onlineTime: " << onlineAvgTime / numTrial << " ms\n" - << "Total time: " << avgTime / numTrial << " s\n"; - runtime << "setSize: " << setSize << "\n" - << "offlineTime: " << offlineAvgTime / numTrial << " ms\n" - << "hashingTime: " << hashingAvgTime / numTrial << " ms\n" - << "getOPRFTime: " << getOPRFAvgTime / numTrial << " ms\n" - << "ss2DirTime: " << ss2DirAvgTime << " ms\n" - << "ssRoundTime: " << ssRoundAvgTime << " ms\n" - << "intersection: " << intersectionAvgTime / numTrial << " ms\n" - << "onlineTime: " << onlineAvgTime / numTrial << " ms\n" - << "Total time: " << avgTime / numTrial << " s\n"; - runtime.close(); + auto offlineTime = std::chrono::duration_cast(initDone - start).count(); + auto hashingTime = std::chrono::duration_cast(hashingDone - initDone).count(); + auto getOPRFTime = std::chrono::duration_cast(getOPRFDone - hashingDone).count(); + auto endTime = std::chrono::duration_cast(end - getOPRFDone).count(); + + double time = offlineTime + hashingTime + getOPRFTime + endTime; + time /= 1000; + + std::cout << "setSize: " << setSize << "\n" + << "offlineTime: " << offlineTime << "\n" + << "hashingTime: " << hashingTime << "\n" + << "getOPRFTime: " << getOPRFTime << "\n" + << "secretSharing: " << endTime << "\n" + << "onlineTime: " << hashingTime + getOPRFTime + endTime << "\n" + << "time: " << time << std::endl; } for (u64 i = 0; i < nParties; ++i) @@ -1449,34 +1347,65 @@ void party(u64 myIdx, u64 nParties, u64 setSize, std::vector& mSet) ios.stop(); } +bool is_in_dual_area(u64 startIdx, u64 endIdx, u64 numIdx, u64 checkIdx) { + bool res = false; + if (startIdx <= endIdx) + { + if (startIdx <= checkIdx && checkIdx <= endIdx) + res = true; + } + else //crosing 0, e.i, areas: startIdx....n-1, 0...endIdx + { + if ((0 <= checkIdx && checkIdx <= endIdx) //0...endIdx + || (startIdx <= checkIdx && checkIdx <= numIdx)) + //startIdx...n-1 + res = true; + } + return res; +} -void party3(u64 myIdx, u64 setSize, std::vector& mSet) +//leader is n-1 +void tparty(u64 myIdx, u64 nParties, u64 tParties, u64 setSize, u64 nTrials) { std::fstream runtime; + u64 leaderIdx = nParties - 1; //leader party + if (myIdx == 0) - runtime.open("./runtime3.txt", runtime.trunc | runtime.out); + runtime.open("./runtime_client.txt", runtime.app | runtime.out); + + if (myIdx == leaderIdx) + runtime.open("./runtime_leader.txt", runtime.app | runtime.out); + + +#pragma region setup + + u64 ttParties= tParties; + if (tParties == nParties - 1)//it is sufficient to prevent n-2 ssClientTimecorrupted parties since if n-1 corrupted and only now the part of intersection if all has x, i.e. x is in intersection. + ttParties = tParties - 1; + else if (tParties < 1) //make sure to do ss with at least one client + ttParties = 1; + + u64 nSS = nParties - 1; //n-2 parties joinly operated secrete sharing + int tSS = ttParties; //ss with t next parties, and last for leader => t+1 + u64 offlineAvgTime(0), hashingAvgTime(0), getOPRFAvgTime(0), - secretSharingAvgTime(0), intersectionAvgTime(0), onlineAvgTime(0); + ss2DirAvgTime(0), ssRoundAvgTime(0), intersectionAvgTime(0), onlineAvgTime(0); u64 psiSecParam = 40, bitSize = 128, numThreads = 1; - PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); + PRNG prng(_mm_set_epi32(4253465, 3434565, myIdx, myIdx)); std::string name("psi"); BtIOService ios(0); - int btCount = nParties; - std::vector ep(nParties); - u64 offlineTimeTot(0); - u64 onlineTimeTot(0); - Timer timer; + std::vector ep(nParties); for (u64 i = 0; i < nParties; ++i) { if (i < myIdx) { - u32 port = 1120 + i * 100 + myIdx;//get the same port; i=1 & pIdx=2 =>port=102 + u32 port = 1120 + i * 100 + myIdx;;//get the same port; i=1 & pIdx=2 =>port=102 ep[i].start(ios, "localhost", port, false, name); //channel bwt i and pIdx, where i is sender } else if (i > myIdx) @@ -1486,573 +1415,822 @@ void party3(u64 myIdx, u64 setSize, std::vector& mSet) } } - std::vector> chls(nParties); + std::vector dummy(nParties); + std::vector revDummy(nParties); for (u64 i = 0; i < nParties; ++i) { + dummy[i] = myIdx * 10 + i; + if (i != myIdx) { chls[i].resize(numThreads); for (u64 j = 0; j < numThreads; ++j) { //chls[i][j] = &ep[i].addChannel("chl" + std::to_string(j), "chl" + std::to_string(j)); chls[i][j] = &ep[i].addChannel(name, name); + //chls[i][j].mEndpoint; + + + } } } - for (u64 idxTrial = 0; idxTrial < numTrial; idxTrial++) + + u64 maskSize = roundUpTo(psiSecParam + 2 * std::log(setSize) - 1, 8) / 8; + u64 nextNeighbor = (myIdx + 1) % nParties; + u64 prevNeighbor = (myIdx - 1 + nParties) % nParties; + u64 num_intersection; + double dataSent, Mbps, MbpsRecv, dataRecv ; +#pragma endregion + + PRNG prngSame(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); + PRNG prngDiff(_mm_set_epi32(434653, 23, myIdx, myIdx)); + u64 expected_intersection; + + for (u64 idxTrial = 0; idxTrial < nTrials; idxTrial++) { +#pragma region input std::vector set(setSize); - for (u64 i = 0; i < setSize; ++i) - set[i] = mSet[i]; - PRNG prng1(_mm_set_epi32(4253465, myIdx, myIdx, myIdx)); //for test - set[0] = prng1.get();; + std::vector> + sendPayLoads(ttParties + 1), //include the last PayLoads to leader + recvPayLoads(ttParties); //received form clients - std::vector sendPayLoads(setSize); - std::vector recvPayLoads(setSize); + block blk_rand = prngSame.get(); + expected_intersection = (*(u64*)&blk_rand) % setSize; - //only P0 genaretes secret sharing - if (myIdx == 0) + for (u64 i = 0; i < expected_intersection; ++i) { - for (u64 i = 0; i < setSize; ++i) - sendPayLoads[i] = prng.get(); + set[i] = prngSame.get(); } - std::vector otRecv(nParties); - std::vector otSend(nParties); - - OPPRFSender send; - OPPRFReceiver recv; - binSet bins; - - std::vector pThrds(nParties); - - //########################## - //### Offline Phasing - //########################## + for (u64 i = expected_intersection; i < setSize; ++i) + { + set[i] = prngDiff.get(); + } - auto start = timer.setTimePoint("start"); - bins.init(myIdx, nParties, setSize, psiSecParam); - u64 otCountSend = bins.mSimpleBins.mBins.size(); - u64 otCountRecv = bins.mCuckooBins.mBins.size(); + if (myIdx != leaderIdx) {//generate share of zero for leader myIDx!=n-1 + for (u64 idxP = 0; idxP < ttParties; ++idxP) + { + sendPayLoads[idxP].resize(setSize); + for (u64 i = 0; i < setSize; ++i) + { + sendPayLoads[idxP][i] = prng.get(); + } + } - u64 nextNeibough = (myIdx + 1) % nParties; - u64 prevNeibough = (myIdx - 1 + nParties) % nParties; + sendPayLoads[ttParties].resize(setSize); //share to leader at second phase + for (u64 i = 0; i < setSize; ++i) + { + sendPayLoads[ttParties][i] = ZeroBlock; + for (u64 idxP = 0; idxP < ttParties; ++idxP) + { + sendPayLoads[ttParties][i] = + sendPayLoads[ttParties][i] ^ sendPayLoads[idxP][i]; + } + } + for (u64 idxP = 0; idxP < recvPayLoads.size(); ++idxP) + { + recvPayLoads[idxP].resize(setSize); + } - for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + } + else { - pThrds[pIdx] = std::thread([&, pIdx]() { - if (pIdx == nextNeibough) { - //I am a sender to my next neigbour - send.init(nParties, setSize, psiSecParam, bitSize, chls[pIdx], otCountSend, otSend[pIdx], otRecv[pIdx], prng.get(), false); + //leader: dont send; only receive ss from clients + sendPayLoads.resize(0);// + recvPayLoads.resize(nParties - 1); + for (u64 idxP = 0; idxP < recvPayLoads.size(); ++idxP) + { + recvPayLoads[idxP].resize(setSize); + } - } - else if (pIdx == prevNeibough) { - //I am a recv to my previous neigbour - recv.init(nParties, setSize, psiSecParam, bitSize, chls[pIdx], otCountRecv, otRecv[pIdx], otSend[pIdx], ZeroBlock, false); - } - }); } - for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) - pThrds[pIdx].join(); -#ifdef PRINT +#ifdef PRINT std::cout << IoStream::lock; - if (myIdx == 0) - { - Log::out << "------0------" << Log::endl; - Log::out << otSend[1].mGens[0].get() << Log::endl; - Log::out << otRecv[2].mGens[0][0].get() << Log::endl; - Log::out << otRecv[2].mGens[0][1].get() << Log::endl; + if (myIdx != leaderIdx) { + for (u64 i = 0; i < setSize; ++i) + { + block check = ZeroBlock; + for (u64 idxP = 0; idxP < ttParties + 1; ++idxP) + { + //if (idxP != myIdx) + check = check ^ sendPayLoads[idxP][i]; + } + if (memcmp((u8*)&check, &ZeroBlock, sizeof(block))) + std::cout << "Error ss values: myIdx: " << myIdx + << " value: " << check << std::endl; + } } - if (myIdx == 1) - { - Log::out << "------1------" << Log::endl; - Log::out << otRecv[0].mGens[0][0].get() << Log::endl; - Log::out << otRecv[0].mGens[0][1].get() << Log::endl; - Log::out << otSend[2].mGens[0].get() << Log::endl; + std::cout << IoStream::unlock; +#endif +#pragma endregion + u64 num_threads = nParties - 1; //except P0, and my + bool isDual = true; + u64 idx_start_dual = 0; + u64 idx_end_dual = 0; + u64 t_prev_shift = tSS; + + if (myIdx != leaderIdx) { + if (2 * tSS < nSS) + { + num_threads = 2 * tSS + 1; + isDual = false; + } + else { + idx_start_dual = (myIdx - tSS + nSS) % nSS; + idx_end_dual = (myIdx + tSS) % nSS; + } + + /*std::cout << IoStream::lock; + std::cout << myIdx << "| " << idx_start_dual << " " << idx_end_dual << "\n"; + std::cout << IoStream::unlock;*/ } + std::vector pThrds(num_threads); - if (myIdx == 2) + std::vector otRecv(nParties); + std::vector otSend(nParties); + std::vector send(nParties); + std::vector recv(nParties); + + if (myIdx == leaderIdx) { - Log::out << "------2------" << Log::endl; - Log::out << otRecv[1].mGens[0][0].get() << Log::endl; - Log::out << otRecv[1].mGens[0][1].get() << Log::endl; - Log::out << otSend[0].mGens[0].get() << Log::endl; + /*otRecv.resize(nParties - 1); + otSend.resize(nParties - 1); + send.resize(nParties - 1); + recv.resize(nParties - 1);*/ + pThrds.resize(nParties - 1); } - std::cout << IoStream::unlock; -#endif - auto initDone = timer.setTimePoint("initDone"); + + + binSet bins; //########################## - //### Hashing + //### Offline Phasing //########################## - bins.hashing2Bins(set, nParties); - //bins.mSimpleBins.print(myIdx, true, false, false, false); - //bins.mCuckooBins.print(myIdx, true, false, false); + Timer timer; + auto start = timer.setTimePoint("start"); + bins.init(myIdx, nParties, setSize, psiSecParam); + u64 otCountSend = bins.mSimpleBins.mBins.size(); + u64 otCountRecv = bins.mCuckooBins.mBins.size(); - auto hashingDone = timer.setTimePoint("hashingDone"); +#pragma region base OT //########################## - //### Online Phasing - compute OPRF + //### Base OT //########################## - pThrds.clear(); - pThrds.resize(nParties); - for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + if (myIdx != leaderIdx) { - pThrds[pIdx] = std::thread([&, pIdx]() { + for (u64 pIdx = 0; pIdx < tSS; ++pIdx) + { + u64 prevIdx = (myIdx - pIdx - 1 + nSS) % nSS; - if (pIdx == nextNeibough) { - //I am a sender to my next neigbour - send.getOPRFKeys(pIdx, bins, chls[pIdx], false); - } - else if (pIdx == prevNeibough) { - //I am a recv to my previous neigbour - recv.getOPRFkeys(pIdx, bins, chls[pIdx], false); - } - }); - } + if (!(isDual && is_in_dual_area(idx_start_dual, idx_end_dual, nSS, prevIdx))) + { + u64 thr = t_prev_shift + pIdx; - for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) - pThrds[pIdx].join(); + pThrds[thr] = std::thread([&, prevIdx, thr]() { - //if (myIdx == 2) - //{ - // //bins.mSimpleBins.print(2, true, true, false, false); - // bins.mCuckooBins.print(1, true, true, false); - // Log::out << "------------" << Log::endl; - //} - //if (myIdx == 1) - //{ - // bins.mSimpleBins.print(2, true, true, false, false); - // //bins.mCuckooBins.print(0, true, true, false); - //} + //chls[prevIdx][0]->recv(&revDummy[prevIdx], 1); + //std::cout << IoStream::lock; + //std::cout << myIdx << "| : " << "| thr[" << thr << "]:" << prevIdx << " --> " << myIdx << ": " << static_cast(revDummy[prevIdx]) << "\n"; + //std::cout << IoStream::unlock; - auto getOPRFDone = timer.setTimePoint("getOPRFDone"); + //prevIdx << " --> " << myIdx + recv[prevIdx].init(nParties, setSize, psiSecParam, bitSize, chls[prevIdx], otCountRecv, otRecv[prevIdx], otSend[prevIdx], ZeroBlock, false); - //########################## - //### online phasing - secretsharing - //########################## + }); - pThrds.clear(); - pThrds.resize(nParties - 1); - if (myIdx == 0) - { - //for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) - //{ - // pThrds[pIdx] = std::thread([&, pIdx]() { - // if (pIdx == 0) { - // send.sendSecretSharing(nextNeibough, bins, sendPayLoads, chls[nextNeibough]); - // } - // else if (pIdx == 1) { - // recv.revSecretSharing(prevNeibough, bins, recvPayLoads, chls[prevNeibough]); - // } - // }); - //} - send.sendSecretSharing(nextNeibough, bins, sendPayLoads, chls[nextNeibough]); - recv.revSecretSharing(prevNeibough, bins, recvPayLoads, chls[prevNeibough]); + + } + } + + for (u64 pIdx = 0; pIdx < tSS; ++pIdx) + { + u64 nextIdx = (myIdx + pIdx + 1) % nSS; + + if ((isDual && is_in_dual_area(idx_start_dual, idx_end_dual, nSS, nextIdx))) { + + pThrds[pIdx] = std::thread([&, nextIdx, pIdx]() { + + + //dual myIdx << " <-> " << nextIdx + if (myIdx < nextIdx) + { + //chls[nextIdx][0]->asyncSend(&dummy[nextIdx], 1); + //std::cout << IoStream::lock; + //std::cout << myIdx << "| d: " << "| thr[" << pIdx << "]:" << myIdx << " <->> " << nextIdx << ": " << static_cast(dummy[nextIdx]) << "\n"; + //std::cout << IoStream::unlock; + + send[nextIdx].init(nParties, setSize, psiSecParam, bitSize, chls[nextIdx], otCountSend, otSend[nextIdx], otRecv[nextIdx], prng.get(), true); + } + else if (myIdx > nextIdx) //by index + { + /* chls[nextIdx][0]->recv(&revDummy[nextIdx], 1); + + std::cout << IoStream::lock; + std::cout << myIdx << "| d: " << "| thr[" << pIdx << "]:" << myIdx << " <<-> " << nextIdx << ": " << static_cast(revDummy[nextIdx]) << "\n"; + std::cout << IoStream::unlock;*/ + + recv[nextIdx].init(nParties, setSize, psiSecParam, bitSize, chls[nextIdx], otCountRecv, otRecv[nextIdx], otSend[nextIdx], ZeroBlock, true); + } + }); + + } + else + { + pThrds[pIdx] = std::thread([&, nextIdx, pIdx]() { + + //chls[nextIdx][0]->asyncSend(&dummy[nextIdx], 1); + //std::cout << IoStream::lock; + //std::cout << myIdx << "| : " << "| thr[" << pIdx << "]:" << myIdx << " -> " << nextIdx << ": " << static_cast(dummy[nextIdx]) << "\n"; + //std::cout << IoStream::unlock; + send[nextIdx].init(nParties, setSize, psiSecParam, bitSize, chls[nextIdx], otCountSend, otSend[nextIdx], otRecv[nextIdx], prng.get(), false); + }); + } + } + + //last thread for connecting with leader + u64 tLeaderIdx = pThrds.size() - 1; + pThrds[pThrds.size() - 1] = std::thread([&, leaderIdx]() { + + // chls[leaderIdx][0]->asyncSend(&dummy[leaderIdx], 1); + + //std::cout << IoStream::lock; + //std::cout << myIdx << "| : " << "| thr[" << pThrds.size() - 1 << "]:" << myIdx << " --> " << leaderIdx << ": " << static_cast(dummy[leaderIdx]) << "\n"; + //std::cout << IoStream::unlock; + + send[leaderIdx].init(nParties, setSize, psiSecParam, bitSize, chls[leaderIdx], otCountSend, otSend[leaderIdx], otRecv[leaderIdx], prng.get(), false); + }); + } else - { - /*for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + { //leader party + + for (u64 pIdx = 0; pIdx < nSS; ++pIdx) { pThrds[pIdx] = std::thread([&, pIdx]() { - if (pIdx == 0) { - recv.revSecretSharing(prevNeibough, bins, recvPayLoads, chls[prevNeibough]); - } - else if (pIdx == 1) { - send.sendSecretSharing(nextNeibough, bins, recvPayLoads, chls[nextNeibough]); - } + /* chls[pIdx][0]->recv(&revDummy[pIdx], 1); + std::cout << IoStream::lock; + std::cout << myIdx << "| : " << "| thr[" << pIdx << "]:" << pIdx << " --> " << myIdx << ": " << static_cast(revDummy[pIdx]) << "\n"; + std::cout << IoStream::unlock;*/ + + recv[pIdx].init(nParties, setSize, psiSecParam, bitSize, chls[pIdx], otCountRecv, otRecv[pIdx], otSend[pIdx], ZeroBlock, false); }); - } */ - recv.revSecretSharing(prevNeibough, bins, recvPayLoads, chls[prevNeibough]); - //sendPayLoads = recvPayLoads; - send.sendSecretSharing(nextNeibough, bins, recvPayLoads, chls[nextNeibough]); + } } - /*for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) - pThrds[pIdx].join();*/ + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + pThrds[pIdx].join(); + + auto initDone = timer.setTimePoint("initDone"); - auto getSSDone = timer.setTimePoint("getSSDone"); #ifdef PRINT std::cout << IoStream::lock; if (myIdx == 0) { - for (int i = 0; i < 5; i++) + Log::out << myIdx << "| -> " << otSend[1].mGens[0].get() << Log::endl; + if (otRecv[1].hasBaseOts()) { - Log::out << sendPayLoads[i] << Log::endl; - //Log::out << recvPayLoads[2][i] << Log::endl; + Log::out << myIdx << "| <- " << otRecv[1].mGens[0][0].get() << Log::endl; + Log::out << myIdx << "| <- " << otRecv[1].mGens[0][1].get() << Log::endl; } Log::out << "------------" << Log::endl; } if (myIdx == 1) { - for (int i = 0; i < 5; i++) - { - //Log::out << recvPayLoads[i] << Log::endl; - Log::out << sendPayLoads[i] << Log::endl; - } + if (otSend[0].hasBaseOts()) + Log::out << myIdx << "| -> " << otSend[0].mGens[0].get() << Log::endl; + + Log::out << myIdx << "| <- " << otRecv[0].mGens[0][0].get() << Log::endl; + Log::out << myIdx << "| <- " << otRecv[0].mGens[0][1].get() << Log::endl; } - if (myIdx == 2) + + if (isDual) { - for (int i = 0; i < 5; i++) + if (myIdx == 0) { - Log::out << sendPayLoads[i] << Log::endl; + Log::out << myIdx << "| <->> " << otSend[tSS].mGens[0].get() << Log::endl; + if (otRecv[tSS].hasBaseOts()) + { + Log::out << myIdx << "| <<-> " << otRecv[tSS].mGens[0][0].get() << Log::endl; + Log::out << myIdx << "| <<-> " << otRecv[tSS].mGens[0][1].get() << Log::endl; + } + Log::out << "------------" << Log::endl; + } + if (myIdx == tSS) + { + if (otSend[0].hasBaseOts()) + Log::out << myIdx << "| <->> " << otSend[0].mGens[0].get() << Log::endl; + + Log::out << myIdx << "| <<-> " << otRecv[0].mGens[0][0].get() << Log::endl; + Log::out << myIdx << "| <<-> " << otRecv[0].mGens[0][1].get() << Log::endl; } } std::cout << IoStream::unlock; #endif +#pragma endregion + + //########################## - //### online phasing - compute intersection + //### Hashing //########################## + bins.hashing2Bins(set, 1); - if (myIdx == 0) { - std::vector mIntersection; - u64 maskSize = roundUpTo(psiSecParam + 2 * std::log(setSize) - 1, 8) / 8; - for (u64 i = 0; i < setSize; ++i) + /*if(myIdx==0) + bins.mSimpleBins.print(myIdx, true, false, false, false); + if (myIdx == 1) + bins.mCuckooBins.print(myIdx, true, false, false);*/ + + auto hashingDone = timer.setTimePoint("hashingDone"); + +#pragma region compute OPRF + + //########################## + //### Online Phasing - compute OPRF + //########################## + + pThrds.clear(); + pThrds.resize(num_threads); + if (myIdx == leaderIdx) + { + pThrds.resize(nParties - 1); + } + + if (myIdx != leaderIdx) + { + for (u64 pIdx = 0; pIdx < tSS; ++pIdx) { - if (!memcmp((u8*)&sendPayLoads[i], &recvPayLoads[i], maskSize)) + u64 prevIdx = (myIdx - pIdx - 1 + nSS) % nSS; + + if (!(isDual && is_in_dual_area(idx_start_dual, idx_end_dual, nSS, prevIdx))) { - mIntersection.push_back(i); + u64 thr = t_prev_shift + pIdx; + + pThrds[thr] = std::thread([&, prevIdx]() { + + //prevIdx << " --> " << myIdx + recv[prevIdx].getOPRFkeys(prevIdx, bins, chls[prevIdx], false); + + }); } } - Log::out << "mIntersection.size(): " << mIntersection.size() << Log::endl; - } - auto getIntersection = timer.setTimePoint("getIntersection"); + for (u64 pIdx = 0; pIdx < tSS; ++pIdx) + { + u64 nextIdx = (myIdx + pIdx + 1) % nSS; + + if ((isDual && is_in_dual_area(idx_start_dual, idx_end_dual, nSS, nextIdx))) { + + pThrds[pIdx] = std::thread([&, nextIdx]() { + //dual myIdx << " <-> " << nextIdx + if (myIdx < nextIdx) + { + send[nextIdx].getOPRFKeys(nextIdx, bins, chls[nextIdx], true); + } + else if (myIdx > nextIdx) //by index + { + recv[nextIdx].getOPRFkeys(nextIdx, bins, chls[nextIdx], true); + } + }); + } + else + { + pThrds[pIdx] = std::thread([&, nextIdx]() { + send[nextIdx].getOPRFKeys(nextIdx, bins, chls[nextIdx], false); + }); + } + } - if (myIdx == 0) { - auto offlineTime = std::chrono::duration_cast(initDone - start).count(); - auto hashingTime = std::chrono::duration_cast(hashingDone - initDone).count(); - auto getOPRFTime = std::chrono::duration_cast(getOPRFDone - hashingDone).count(); - auto secretSharingTime = std::chrono::duration_cast(getSSDone - getOPRFDone).count(); - auto intersectionTime = std::chrono::duration_cast(getIntersection - getSSDone).count(); + //last thread for connecting with leader + pThrds[pThrds.size() - 1] = std::thread([&, leaderIdx]() { + send[leaderIdx].getOPRFKeys(leaderIdx, bins, chls[leaderIdx], false); + }); - double onlineTime = hashingTime + getOPRFTime + secretSharingTime + intersectionTime; + } + else + { //leader party + for (u64 pIdx = 0; pIdx < nSS; ++pIdx) + { + pThrds[pIdx] = std::thread([&, pIdx]() { + recv[pIdx].getOPRFkeys(pIdx, bins, chls[pIdx], false); + }); + } + } - double time = offlineTime + onlineTime; - time /= 1000; + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + pThrds[pIdx].join(); - std::cout << "setSize: " << setSize << "\n" - << "offlineTime: " << offlineTime << " ms\n" - << "hashingTime: " << hashingTime << " ms\n" - << "getOPRFTime: " << getOPRFTime << " ms\n" - << "secretSharing: " << secretSharingTime << " ms\n" - << "intersection: " << intersectionTime << " ms\n" - << "onlineTime: " << onlineTime << " ms\n" - << "Total time: " << time << " s\n" - << "------------------\n"; + auto getOPRFDone = timer.setTimePoint("getOPRFDone"); - offlineAvgTime += offlineTime; - hashingAvgTime += hashingTime; - getOPRFAvgTime += getOPRFTime; - secretSharingAvgTime += secretSharingTime; - intersectionAvgTime += intersectionTime; - onlineAvgTime += onlineTime; - } - } - if (myIdx == 0) { - double avgTime = (offlineAvgTime + onlineAvgTime); - avgTime /= 1000; - std::cout << "=========avg==========\n" - << "setSize: " << setSize << "\n" - << "offlineTime: " << offlineAvgTime / numTrial << " ms\n" - << "hashingTime: " << hashingAvgTime / numTrial << " ms\n" - << "getOPRFTime: " << getOPRFAvgTime / numTrial << " ms\n" - << "secretSharing: " << secretSharingAvgTime / numTrial << " ms\n" - << "intersection: " << intersectionAvgTime / numTrial << " ms\n" - << "onlineTime: " << onlineAvgTime / numTrial << " ms\n" - << "Total time: " << avgTime / numTrial << " s\n"; - runtime << "setSize: " << setSize << "\n" - << "offlineTime: " << offlineAvgTime / numTrial << " ms\n" - << "hashingTime: " << hashingAvgTime / numTrial << " ms\n" - << "getOPRFTime: " << getOPRFAvgTime / numTrial << " ms\n" - << "secretSharing: " << secretSharingAvgTime / numTrial << " ms\n" - << "intersection: " << intersectionAvgTime / numTrial << " ms\n" - << "onlineTime: " << onlineAvgTime / numTrial << " ms\n" - << "Total time: " << avgTime / numTrial << " s\n"; - runtime.close(); - } +#ifdef BIN_PRINT - for (u64 i = 0; i < nParties; ++i) - { - if (i != myIdx) + if (myIdx == 0) { - for (u64 j = 0; j < numThreads; ++j) + bins.mSimpleBins.print(1, true, true, false, false); + } + if (myIdx == 1) + { + bins.mCuckooBins.print(0, true, true, false); + } + + if (isDual) + { + if (myIdx == 0) { - chls[i][j]->close(); + bins.mCuckooBins.print(tSS, true, true, false); + } + if (myIdx == tSS) + { + bins.mSimpleBins.print(0, true, true, false, false); } } - } - for (u64 i = 0; i < nParties; ++i) - { - if (i != myIdx) - ep[i].stop(); - } +#endif +#pragma endregion +#pragma region SS - ios.stop(); -} + //########################## + //### online phasing - secretsharing + //########################## -void party2(u64 myIdx, u64 setSize) -{ - nParties = 2; - u64 psiSecParam = 40, bitSize = 128, numThreads = 1; - PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); + pThrds.clear(); - std::vector set(setSize); - for (u64 i = 0; i < setSize; ++i) - set[i] = mSet[i]; + if (myIdx != leaderIdx) + { + pThrds.resize(num_threads); + for (u64 pIdx = 0; pIdx < tSS; ++pIdx) + { + u64 prevIdx = (myIdx - pIdx - 1 + nSS) % nSS; - PRNG prng1(_mm_set_epi32(4253465, myIdx, myIdx, myIdx)); //for test - set[0] = prng1.get();; + if (!(isDual && is_in_dual_area(idx_start_dual, idx_end_dual, nSS, prevIdx))) + { + u64 thr = t_prev_shift + pIdx; - std::vector sendPayLoads(setSize); - std::vector recvPayLoads(setSize); + pThrds[thr] = std::thread([&, prevIdx, pIdx]() { - //only P0 genaretes secret sharing - if (myIdx == 0) - { - for (u64 i = 0; i < setSize; ++i) - sendPayLoads[i] = prng.get(); - } + //prevIdx << " --> " << myIdx + recv[prevIdx].revSecretSharing(prevIdx, bins, recvPayLoads[pIdx], chls[prevIdx]); + }); + } + } - std::string name("psi"); - BtIOService ios(0); + for (u64 pIdx = 0; pIdx < tSS; ++pIdx) + { + u64 nextIdx = (myIdx + pIdx + 1) % nSS; - int btCount = nParties; - std::vector ep(nParties); + if ((isDual && is_in_dual_area(idx_start_dual, idx_end_dual, nSS, nextIdx))) { - u64 offlineTimeTot(0); - u64 onlineTimeTot(0); - Timer timer; + pThrds[pIdx] = std::thread([&, nextIdx, pIdx]() { + //dual myIdx << " <-> " << nextIdx + //send OPRF can receive payload + if (myIdx < nextIdx) + { + send[nextIdx].sendSecretSharing(nextIdx, bins, sendPayLoads[pIdx], chls[nextIdx]); - for (u64 i = 0; i < nParties; ++i) - { - if (i < myIdx) + send[nextIdx].revSecretSharing(nextIdx, bins, recvPayLoads[pIdx], chls[nextIdx]); + } + else if (myIdx > nextIdx) //by index + { + recv[nextIdx].revSecretSharing(nextIdx, bins, recvPayLoads[pIdx], chls[nextIdx]); + + recv[nextIdx].sendSecretSharing(nextIdx, bins, sendPayLoads[pIdx], chls[nextIdx]); + + } + }); + + } + else + { + pThrds[pIdx] = std::thread([&, nextIdx, pIdx]() { + send[nextIdx].sendSecretSharing(nextIdx, bins, sendPayLoads[pIdx], chls[nextIdx]); + }); + } + } + + //last thread for connecting with leader + pThrds[pThrds.size() - 1] = std::thread([&, leaderIdx]() { + //send[leaderIdx].getOPRFKeys(leaderIdx, bins, chls[leaderIdx], false); + }); + + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + pThrds[pIdx].join(); + } + + auto getSsClientsDone = timer.setTimePoint("secretsharingClientDone"); + + +#ifdef PRINT + std::cout << IoStream::lock; + if (myIdx == 0) { - u32 port = 1210 + i * 10 + myIdx;//get the same port; i=1 & pIdx=2 =>port=102 - ep[i].start(ios, "localhost", port, false, name); //channel bwt i and pIdx, where i is sender + for (int i = 0; i < 3; i++) + { + block temp = ZeroBlock; + memcpy((u8*)&temp, (u8*)&sendPayLoads[0][i], maskSize); + Log::out << myIdx << "| -> 1: (" << i << ", " << temp << ")" << Log::endl; + } + Log::out << "------------" << Log::endl; } - else if (i > myIdx) + if (myIdx == 1) { - u32 port = 1210 + myIdx * 10 + i;//get the same port; i=2 & pIdx=1 =>port=102 - ep[i].start(ios, "localhost", port, true, name); //channel bwt i and pIdx, where i is receiver + for (int i = 0; i < 3; i++) + { + block temp = ZeroBlock; + memcpy((u8*)&temp, (u8*)&recvPayLoads[0][i], maskSize); + Log::out << myIdx << "| <- 0: (" << i << ", " << temp << ")" << Log::endl; + } + Log::out << "------------" << Log::endl; + } + + if (isDual) + { + /*if (myIdx == 0) + { + for (int i = 0; i < 3; i++) + { + block temp = ZeroBlock; + memcpy((u8*)&temp, (u8*)&recvPayLoads[tSS][i], maskSize); + Log::out << myIdx << "| <- "<< tSS<<": (" << i << ", " << temp << ")" << Log::endl; + } + Log::out << "------------" << Log::endl; + } + if (myIdx == tSS) + { + for (int i = 0; i < 3; i++) + { + block temp = ZeroBlock; + memcpy((u8*)&temp, (u8*)&sendPayLoads[0][i], maskSize); + Log::out << myIdx << "| -> 0: (" << i << ", " << temp << ")" << Log::endl; + } + Log::out << "------------" << Log::endl; + }*/ } - } + std::cout << IoStream::unlock; +#endif +#pragma endregion - std::vector> chls(nParties); + //########################## + //### online phasing - send XOR of zero share to leader + //########################## + pThrds.clear(); - for (u64 i = 0; i < nParties; ++i) - { - if (i != myIdx) { - chls[i].resize(numThreads); - for (u64 j = 0; j < numThreads; ++j) + if (myIdx != leaderIdx) + { + + for (u64 i = 0; i < setSize; ++i) { - //chls[i][j] = &ep[i].addChannel("chl" + std::to_string(j), "chl" + std::to_string(j)); - chls[i][j] = &ep[i].addChannel(name, name); + //xor all received share + for (u64 idxP = 0; idxP < ttParties; ++idxP) + { + sendPayLoads[ttParties][i] = sendPayLoads[ttParties][i] ^ recvPayLoads[idxP][i]; + } } + //send to leader + send[leaderIdx].sendSecretSharing(leaderIdx, bins, sendPayLoads[ttParties], chls[leaderIdx]); } - } + else + { + pThrds.resize(nParties - 1); - std::vector otRecv(nParties); - std::vector otSend(nParties); + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) { + pThrds[pIdx] = std::thread([&, pIdx]() { + recv[pIdx].revSecretSharing(pIdx, bins, recvPayLoads[pIdx], chls[pIdx]); + }); + } - OPPRFSender send; - OPPRFReceiver recv; - binSet bins; + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + pThrds[pIdx].join(); + } - std::vector pThrds(nParties); - //########################## - //### Offline Phasing - //########################## + auto getSSLeaderDone = timer.setTimePoint("leaderGetXorDone"); - auto start = timer.setTimePoint("start"); - bins.init(myIdx, nParties, setSize, psiSecParam); - u64 otCountSend = bins.mSimpleBins.mBins.size(); - u64 otCountRecv = bins.mCuckooBins.mBins.size(); + //########################## + //### online phasing - compute intersection + //########################## + std::vector mIntersection; + if (myIdx == leaderIdx) { + + u64 maskSize = roundUpTo(psiSecParam + 2 * std::log(setSize) - 1, 8) / 8; - if (myIdx == 0) { - //I am a sender to my next neigbour - send.init(nParties, setSize, psiSecParam, bitSize, chls[1], otCountSend, otSend[1], otRecv[1], prng.get(), false); + for (u64 i = 0; i < setSize; ++i) + { - } - else if (myIdx == 1) { - //I am a recv to my previous neigbour - recv.init(nParties, setSize, psiSecParam, bitSize, chls[0], otCountRecv, otRecv[0], otSend[0], ZeroBlock, false); - } + //xor all received share + block sum = ZeroBlock; + for (u64 idxP = 0; idxP < nParties - 1; ++idxP) + { + sum = sum ^ recvPayLoads[idxP][i]; + } + if (!memcmp((u8*)&ZeroBlock, &sum, maskSize)) + { + mIntersection.push_back(i); + } + } -#ifdef PRINT - std::cout << IoStream::lock; - if (myIdx == 0) - { - Log::out << "------0------" << Log::endl; - Log::out << otSend[1].mGens[0].get() << Log::endl; - Log::out << otRecv[2].mGens[0][0].get() << Log::endl; - Log::out << otRecv[2].mGens[0][1].get() << Log::endl; - } - if (myIdx == 1) - { - Log::out << "------1------" << Log::endl; - Log::out << otRecv[0].mGens[0][0].get() << Log::endl; - Log::out << otRecv[0].mGens[0][1].get() << Log::endl; - Log::out << otSend[2].mGens[0].get() << Log::endl; - } + } + auto getIntersection = timer.setTimePoint("getIntersection"); - if (myIdx == 2) - { - Log::out << "------2------" << Log::endl; - Log::out << otRecv[1].mGens[0][0].get() << Log::endl; - Log::out << otRecv[1].mGens[0][1].get() << Log::endl; - Log::out << otSend[0].mGens[0].get() << Log::endl; - } - std::cout << IoStream::unlock; -#endif + std::cout << IoStream::lock; - auto initDone = timer.setTimePoint("initDone"); + if (myIdx == 0 || myIdx == 1 || myIdx == leaderIdx) { + auto offlineTime = std::chrono::duration_cast(initDone - start).count(); + auto hashingTime = std::chrono::duration_cast(hashingDone - initDone).count(); + auto getOPRFTime = std::chrono::duration_cast(getOPRFDone - hashingDone).count(); + auto ssClientTime = std::chrono::duration_cast(getSsClientsDone - getOPRFDone).count(); + auto ssServerTime = std::chrono::duration_cast(getSSLeaderDone - getSsClientsDone).count(); + auto intersectionTime = std::chrono::duration_cast(getIntersection - getSSLeaderDone).count(); - //########################## - //### Hashing - //########################## - bins.hashing2Bins(set, 1); - //bins.mSimpleBins.print(myIdx, true, false, false, false); - //bins.mCuckooBins.print(myIdx, true, false, false); + double onlineTime = hashingTime + getOPRFTime + ssClientTime + ssServerTime + intersectionTime; - auto hashingDone = timer.setTimePoint("hashingDone"); + double time = offlineTime + onlineTime; + time /= 1000; - //########################## - //### Online Phasing - compute OPRF - //########################## + dataSent = 0; + dataRecv = 0; + Mbps = 0; + MbpsRecv = 0; + for (u64 i = 0; i < nParties; ++i) + { + if (i != myIdx) { + chls[i].resize(numThreads); + for (u64 j = 0; j < numThreads; ++j) + { + dataSent += chls[i][j]->getTotalDataSent(); + dataRecv += chls[i][j]->getTotalDataRecv(); + } + } + } - if (myIdx == 0) { - //I am a sender to my next neigbour - send.getOPRFKeys(1, bins, chls[1], false); - } - else if (myIdx == 1) { - //I am a recv to my previous neigbour - recv.getOPRFkeys(0, bins, chls[0], false); - } + Mbps = dataSent * 8 / time / (1 << 20); + MbpsRecv = dataRecv * 8 / time / (1 << 20); + for (u64 i = 0; i < nParties; ++i) + { + if (i != myIdx) { + chls[i].resize(numThreads); + for (u64 j = 0; j < numThreads; ++j) + { + chls[i][j]->resetStats(); + } + } + } + + if (myIdx == 0 || myIdx == 1) + { + std::cout << "Client Idx: " << myIdx << "\n"; + } + else + { + std::cout << "\nLeader Idx: " << myIdx << "\n"; + } - //if (myIdx == 2) - //{ - // //bins.mSimpleBins.print(2, true, true, false, false); - // bins.mCuckooBins.print(1, true, true, false); - // Log::out << "------------" << Log::endl; - //} - //if (myIdx == 1) - //{ - // bins.mSimpleBins.print(2, true, true, false, false); - // //bins.mCuckooBins.print(0, true, true, false); - //} + if (myIdx == leaderIdx) { + Log::out << "#Output Intersection: " << mIntersection.size() << Log::endl; + Log::out << "#Expected Intersection: " << expected_intersection << Log::endl; + num_intersection = mIntersection.size(); + } - auto getOPRFDone = timer.setTimePoint("getOPRFDone"); + std::cout << "setSize: " << setSize << "\n" + << "offlineTime: " << offlineTime << " ms\n" + << "hashingTime: " << hashingTime << " ms\n" + << "getOPRFTime: " << getOPRFTime << " ms\n" + << "ss2DirTime: " << ssClientTime << " ms\n" + << "ssRoundTime: " << ssServerTime << " ms\n" + << "intersection: " << intersectionTime << " ms\n" + << "onlineTime: " << onlineTime << " ms\n" + << "Bandwidth: Send: " << Mbps << " Mbps,\t Recv: " << MbpsRecv << " Mbps\n" + << "Total time: " << time << " s\n" + << "Total Comm: Send:" << (dataSent / std::pow(2.0, 20)) << " MB" + << "\t Recv: " << (dataRecv / std::pow(2.0, 20)) << " MB\n" + << "------------------\n"; + + - //########################## - //### online phasing - secretsharing - //########################## - if (myIdx == 0) - { - // send.sendSecretSharing(nextNeibough, bins, sendPayLoads, chls[nextNeibough]); - // recv.revSecretSharing(prevNeibough, bins, recvPayLoads, chls[prevNeibough]); + offlineAvgTime += offlineTime; + hashingAvgTime += hashingTime; + getOPRFAvgTime += getOPRFTime; + ss2DirAvgTime += ssClientTime; + ssRoundAvgTime += ssServerTime; + intersectionAvgTime += intersectionTime; + onlineAvgTime += onlineTime; + } + std::cout << IoStream::unlock; } - else - { - // recv.revSecretSharing(prevNeibough, bins, recvPayLoads, chls[prevNeibough]); - //sendPayLoads = recvPayLoads; - // send.sendSecretSharing(nextNeibough, bins, recvPayLoads, chls[nextNeibough]); - } - -#ifdef PRINT std::cout << IoStream::lock; - if (myIdx == 0) - { - for (int i = 0; i < 5; i++) - { - Log::out << sendPayLoads[i] << Log::endl; - //Log::out << recvPayLoads[2][i] << Log::endl; - } - Log::out << "------------" << Log::endl; - } - if (myIdx == 1) - { - for (int i = 0; i < 5; i++) - { - //Log::out << recvPayLoads[i] << Log::endl; - Log::out << sendPayLoads[i] << Log::endl; - } - } - if (myIdx == 2) - { - for (int i = 0; i < 5; i++) - { - Log::out << sendPayLoads[i] << Log::endl; - } - } - std::cout << IoStream::unlock; -#endif + if (myIdx == 0 || myIdx == leaderIdx) { + double avgTime = (offlineAvgTime + onlineAvgTime); + avgTime /= 1000; - //########################## - //### online phasing - compute intersection - //########################## + std::cout << "=========avg==========\n"; + runtime << "=========avg==========\n"; + runtime << "numParty: " << nParties + << " numCorrupted: " << tParties + << " setSize: " << setSize + << " nTrials:" << nTrials << "\n"; - if (myIdx == 0) { - std::vector mIntersection; - u64 maskSize = roundUpTo(psiSecParam + 2 * std::log(setSize) - 1, 8) / 8; - for (u64 i = 0; i < setSize; ++i) + if (myIdx == 0) { - if (!memcmp((u8*)&sendPayLoads[i], &recvPayLoads[i], maskSize)) - { - // mIntersection.push_back(i); - } + std::cout << "Client Idx: " << myIdx << "\n"; + runtime << "Client Idx: " << myIdx << "\n"; + } - Log::out << "mIntersection.size(): " << mIntersection.size() << Log::endl; - } - auto end = timer.setTimePoint("getOPRFDone"); + else + { + std::cout << "Leader Idx: " << myIdx << "\n"; + Log::out << "#Output Intersection: " << num_intersection << Log::endl; + Log::out << "#Expected Intersection: " << expected_intersection << Log::endl; + runtime << "Leader Idx: " << myIdx << "\n"; + runtime << "#Output Intersection: " << num_intersection << "\n"; + runtime << "#Expected Intersection: " << expected_intersection << "\n"; + } - if (myIdx == 0) { - auto offlineTime = std::chrono::duration_cast(initDone - start).count(); - auto hashingTime = std::chrono::duration_cast(hashingDone - initDone).count(); - auto getOPRFTime = std::chrono::duration_cast(getOPRFDone - hashingDone).count(); - auto endTime = std::chrono::duration_cast(end - getOPRFDone).count(); - double time = offlineTime + hashingTime + getOPRFTime + endTime; - time /= 1000; - std::cout << "setSize: " << setSize << "\n" - << "offlineTime: " << offlineTime << "\n" - << "hashingTime: " << hashingTime << "\n" - << "getOPRFTime: " << getOPRFTime << "\n" - << "secretSharing: " << endTime << "\n" - << "onlineTime: " << hashingTime + getOPRFTime + endTime << "\n" - << "time: " << time << std::endl; + std::cout << "numParty: " << nParties + << " numCorrupted: " << tParties + << " setSize: " << setSize + << " nTrials:" << nTrials << "\n" + << "offlineTime: " << offlineAvgTime / nTrials << " ms\n" + << "hashingTime: " << hashingAvgTime / nTrials << " ms\n" + << "getOPRFTime: " << getOPRFAvgTime / nTrials << " ms\n" + << "ssClientTime: " << ss2DirAvgTime / nTrials << " ms\n" + << "ssLeaderTime: " << ssRoundAvgTime / nTrials << " ms\n" + << "intersection: " << intersectionAvgTime / nTrials << " ms\n" + << "onlineTime: " << onlineAvgTime / nTrials << " ms\n" + << "Bandwidth: Send: " << Mbps << " Mbps,\t Recv: " << MbpsRecv << " Mbps\n" + << "Total time: " << avgTime << " s\n" + << "Total Comm: Send:" << (dataSent / std::pow(2.0, 20)) << " MB" + << "\t Recv: " << (dataRecv / std::pow(2.0, 20)) << " MB\n" + << "------------------\n"; + + runtime << "offlineTime: " << offlineAvgTime / nTrials << " ms\n" + << "hashingTime: " << hashingAvgTime / nTrials << " ms\n" + << "getOPRFTime: " << getOPRFAvgTime / nTrials << " ms\n" + << "ssClientTime: " << ss2DirAvgTime / nTrials << " ms\n" + << "ssLeaderTime: " << ssRoundAvgTime / nTrials << " ms\n" + << "intersection: " << intersectionAvgTime / nTrials << " ms\n" + << "onlineTime: " << onlineAvgTime / nTrials << " ms\n" + << "Bandwidth: Send: " << Mbps << " Mbps,\t Recv: " << MbpsRecv << " Mbps\n" + << "Total time: " << avgTime << " s\n" + << "Total Comm: Send:" << (dataSent / std::pow(2.0, 20)) << " MB" + << "\t Recv: " << (dataRecv / std::pow(2.0, 20)) << " MB\n" + << "------------------\n"; + runtime.close(); } + std::cout << IoStream::unlock; + /*if (myIdx == 0) { + double avgTime = (offlineAvgTime + onlineAvgTime); + avgTime /= 1000; + std::cout << "=========avg==========\n" + << "setSize: " << setSize << "\n" + << "offlineTime: " << offlineAvgTime / numTrial << " ms\n" + << "hashingTime: " << hashingAvgTime / numTrial << " ms\n" + << "getOPRFTime: " << getOPRFAvgTime / numTrial << " ms\n" + << "ss2DirTime: " << ss2DirAvgTime << " ms\n" + << "ssRoundTime: " << ssRoundAvgTime << " ms\n" + << "intersection: " << intersectionAvgTime / numTrial << " ms\n" + << "onlineTime: " << onlineAvgTime / numTrial << " ms\n" + << "Total time: " << avgTime / numTrial << " s\n"; + runtime << "setSize: " << setSize << "\n" + << "offlineTime: " << offlineAvgTime / numTrial << " ms\n" + << "hashingTime: " << hashingAvgTime / numTrial << " ms\n" + << "getOPRFTime: " << getOPRFAvgTime / numTrial << " ms\n" + << "ss2DirTime: " << ss2DirAvgTime << " ms\n" + << "ssRoundTime: " << ssRoundAvgTime << " ms\n" + << "intersection: " << intersectionAvgTime / numTrial << " ms\n" + << "onlineTime: " << onlineAvgTime / numTrial << " ms\n" + << "Total time: " << avgTime / numTrial << " s\n"; + runtime.close(); + } + */ for (u64 i = 0; i < nParties; ++i) { if (i != myIdx) @@ -2073,6 +2251,29 @@ void party2(u64 myIdx, u64 setSize) ios.stop(); } +void OPPRFnt_EmptrySet_Test_Main() +{ + u64 setSize = 1 << 5, psiSecParam = 40, bitSize = 128; + + u64 nParties = 5; + u64 tParties = 2; + + + std::vector pThrds(nParties); + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + { + { + pThrds[pIdx] = std::thread([&, pIdx]() { + // Channel_party_test(pIdx); + tparty(pIdx, nParties, tParties, setSize, 1); + }); + } + } + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + pThrds[pIdx].join(); + + +} void OPPRFn_EmptrySet_Test_Main() { @@ -2100,20 +2301,14 @@ void OPPRFn_EmptrySet_Test_Main() void OPPRF3_EmptrySet_Test_Main() { - u64 setSize = 1 << 20, psiSecParam = 40, bitSize = 128; - PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); - mSet.resize(setSize); - for (u64 i = 0; i < setSize; ++i) - { - mSet[i] = prng.get(); - } + u64 setSize = 1 << 5, psiSecParam = 40, bitSize = 128; nParties = 3; std::vector pThrds(nParties); for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) { pThrds[pIdx] = std::thread([&, pIdx]() { // Channel_party_test(pIdx); - party3(pIdx, setSize, mSet); + party3(pIdx, setSize, 1); }); } for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) diff --git a/frontend/OtBinMain.h b/frontend/OtBinMain.h index 838c176911..2218db2078 100644 --- a/frontend/OtBinMain.h +++ b/frontend/OtBinMain.h @@ -13,6 +13,8 @@ void OPPRF3_EmptrySet_Test_Main(); void OPPRFn_EmptrySet_Test_Main(); void OPPRF2_EmptrySet_Test_Main(); void Bit_Position_Random_Test(); -void party3(u64 myIdx, u64 setSize, std::vector& mSet); +void OPPRFnt_EmptrySet_Test_Main(); +void party3(u64 myIdx, u64 setSize, u64 nTrials); void party(u64 myIdx, u64 nParties, u64 setSize, std::vector& mSet); +void tparty(u64 myIdx, u64 nParties, u64 tParties, u64 setSize, u64 nTrials); //void OPPRFn_EmptrySet_Test(); \ No newline at end of file diff --git a/frontend/bitPosition.cpp b/frontend/bitPosition.cpp index 6db10808f1..6aad43706a 100644 --- a/frontend/bitPosition.cpp +++ b/frontend/bitPosition.cpp @@ -1,4 +1,3 @@ -#include "bloomFilterMain.h" #include "Network/BtEndpoint.h" #include "OPPRF/OPPRFReceiver.h" diff --git a/frontend/main.cpp b/frontend/main.cpp index 17f4f302e5..8e3934c6e8 100644 --- a/frontend/main.cpp +++ b/frontend/main.cpp @@ -33,54 +33,60 @@ int main(int argc, char** argv) //Bit_Position_Random_Test(); //return 0; //OPPRF2_EmptrySet_Test_Main(); - OPPRFn_EmptrySet_Test_Main(); - return 0; + //OPPRFn_EmptrySet_Test_Main(); + //OPPRF3_EmptrySet_Test_Main(); + //return 0; + //OPPRFnt_EmptrySet_Test_Main(); + //OPPRFnt_EmptrySet_Test_Main(); + + u64 trials=1; + std::vector mSet; u64 setSize = 1 << 20, psiSecParam = 40, bitSize = 128; - u64 nParties=3; + u64 nParties, tParties; + u64 roundOPPRF; PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); mSet.resize(setSize); for (u64 i = 0; i < setSize; ++i) - { mSet[i] = prng.get(); - } - /*if (argc == 3) { - if (argv[1][0] == '-' && argv[1][1] == 'p' && atoi(argv[2]) == 0) { - party3(0, setSize, mSet); - } - else if (argv[1][0] == '-' && argv[1][1] == 'p' && atoi(argv[2]) == 1) { - party3(1, setSize, mSet); - } - else if (argv[1][0] == '-' && argv[1][1] == 'p' && atoi(argv[2]) == 2) { - party3(2, setSize, mSet); - } - else { - usage(argv[0]); - } - }*/ - - - if (argc == 5) { + if (argc == 7) { if (argv[1][0] == '-' && argv[1][1] == 'n') nParties = atoi(argv[2]); + if (nParties == 3) { - if (argv[3][0] == '-' && argv[3][1] == 'p') { - u64 pIdx = atoi(argv[4]); - party3(pIdx, setSize, mSet); + if (argv[3][0] == '-' && argv[3][1] == 'r') + roundOPPRF = atoi(argv[4]); + + if (argv[5][0] == '-' && argv[5][1] == 'p') { + u64 pIdx = atoi(argv[6]); + if(roundOPPRF==1) + party3(pIdx, setSize, trials); + else + tparty(pIdx, nParties, 2, setSize, trials); } } else { - if (argv[3][0] == '-' && argv[3][1] == 'p') { - u64 pIdx = atoi(argv[4]); - party(pIdx, nParties, setSize, mSet); + if (argv[3][0] == '-' && argv[3][1] == 't') + tParties = atoi(argv[4]); + + if (argv[5][0] == '-' && argv[5][1] == 'p') { + u64 pIdx = atoi(argv[6]); + std::cout << "pIdx: " << pIdx << "\n"; + tparty(pIdx, nParties, tParties, setSize, trials); } } } + else if (argc == 2) { + if (argv[1][0] == '-' && argv[1][1] == 'u') + { + OPPRFnt_EmptrySet_Test_Main(); + } + } else { usage(argv[0]); } diff --git a/frontend/runtime3.txt b/frontend/runtime3.txt new file mode 100644 index 0000000000..33df214f75 --- /dev/null +++ b/frontend/runtime3.txt @@ -0,0 +1,8 @@ +setSize: 32 +offlineTime: 4052 ms +hashingTime: 15 ms +getOPRFTime: 74 ms +secretSharing: 166 ms +intersection: 0 ms +onlineTime: 255 ms +Total time: 4.308 s diff --git a/frontend/runtime_client.txt b/frontend/runtime_client.txt new file mode 100644 index 0000000000..ade111b16d --- /dev/null +++ b/frontend/runtime_client.txt @@ -0,0 +1,169 @@ +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Client Idx: 0 +offlineTime: 2586 ms +hashingTime: 4 ms +getOPRFTime: 112 ms +ssClientTime: 73 ms +ssLeaderTime: 27 ms +intersection: 0 ms +onlineTime: 216 ms +Total time: 2.802 s +data/second: 0.272338 Mbps +Total data: 0.0953865 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Client Idx: 0 +offlineTime: 2689 ms +hashingTime: 4 ms +getOPRFTime: 47 ms +ssClientTime: 79 ms +ssLeaderTime: 29 ms +intersection: 0 ms +onlineTime: 159 ms +Total time: 2.848 s +data/second: 0.26794 Mbps +Total data: 0.0953865 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Client Idx: 0 +offlineTime: 2678 ms +hashingTime: 4 ms +getOPRFTime: 46 ms +ssClientTime: 77 ms +ssLeaderTime: 26 ms +intersection: 0 ms +onlineTime: 153 ms +Total time: 2.831 s +data/second: 0.269549 Mbps +Total data: 0.0953865 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Client Idx: 0 +offlineTime: 2736 ms +hashingTime: 4 ms +getOPRFTime: 26 ms +ssClientTime: 66 ms +ssLeaderTime: 22 ms +intersection: 0 ms +onlineTime: 118 ms +Total time: 2.854 s +data/second: 0.267376 Mbps +Total data: 0.0953865 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Client Idx: 0 +offlineTime: 2581 ms +hashingTime: 8 ms +getOPRFTime: 81 ms +ssClientTime: 71 ms +ssLeaderTime: 27 ms +intersection: 0 ms +onlineTime: 187 ms +Total time: 2.768 s +data/second: 0.275684 Mbps +Total data: 0.0953865 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Client Idx: 0 +offlineTime: 2704 ms +hashingTime: 4 ms +getOPRFTime: 87 ms +ssClientTime: 63 ms +ssLeaderTime: 24 ms +intersection: 0 ms +onlineTime: 178 ms +Total time: 2.882 s +data/second: 0.264779 Mbps +Total data: 0.0953865 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Client Idx: 0 +offlineTime: 2668 ms +hashingTime: 3 ms +getOPRFTime: 18 ms +ssClientTime: 69 ms +ssLeaderTime: 22 ms +intersection: 0 ms +onlineTime: 112 ms +Total time: 2.78 s +data/second: 0.274494 Mbps +Total data: 0.0953865 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Client Idx: 0 +offlineTime: 2691 ms +hashingTime: 4 ms +getOPRFTime: 22 ms +ssClientTime: 68 ms +ssLeaderTime: 21 ms +intersection: 0 ms +onlineTime: 115 ms +Total time: 2.806 s +data/second: 0.27195 Mbps +Total data: 0.0953865 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Client Idx: 0 +offlineTime: 2750 ms +hashingTime: 4 ms +getOPRFTime: 21 ms +ssClientTime: 72 ms +ssLeaderTime: 22 ms +intersection: 0 ms +onlineTime: 119 ms +Total time: 2.869 s +data/second: 0.265978 Mbps +Total data: 0.0953865 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Client Idx: 0 +offlineTime: 2668 ms +hashingTime: 3 ms +getOPRFTime: 26 ms +ssClientTime: 64 ms +ssLeaderTime: 20 ms +intersection: 0 ms +onlineTime: 113 ms +Total time: 2.781 s +data/second: 0.274395 Mbps +Total data: 0.0953865 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Client Idx: 0 +offlineTime: 2586 ms +hashingTime: 4 ms +getOPRFTime: 40 ms +ssClientTime: 81 ms +ssLeaderTime: 24 ms +intersection: 0 ms +onlineTime: 149 ms +Total time: 2.735 s +data/second: 0.27901 Mbps +Total data: 0.0953865 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Client Idx: 0 +offlineTime: 2580 ms +hashingTime: 4 ms +getOPRFTime: 61 ms +ssClientTime: 67 ms +ssLeaderTime: 28 ms +intersection: 0 ms +onlineTime: 160 ms +Total time: 2.74 s +data/second: 0.278501 Mbps +Total data: 0.0953865 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Client Idx: 0 +offlineTime: 2892 ms +hashingTime: 4 ms +getOPRFTime: 25 ms +ssClientTime: 66 ms +ssLeaderTime: 23 ms +intersection: 0 ms +onlineTime: 118 ms +data/second: 0.253519 Mbps +Total time: 3.01 s Total data: 0.0953865 MB +------------------ diff --git a/frontend/runtime_leader.txt b/frontend/runtime_leader.txt new file mode 100644 index 0000000000..ad83143ba3 --- /dev/null +++ b/frontend/runtime_leader.txt @@ -0,0 +1,184 @@ +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Leader Idx: 4 +mIntersection.size(): 0 +offlineTime: 2409 ms +hashingTime: 8 ms +getOPRFTime: 14 ms +ssClientTime: 0 ms +ssLeaderTime: 377 ms +intersection: 0 ms +onlineTime: 399 ms +Total time: 2.808 s +data/second: 0.0804999 Mbps +Total data: 0.0282555 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Leader Idx: 4 +mIntersection.size(): 0 +offlineTime: 2527 ms +hashingTime: 4 ms +getOPRFTime: 5 ms +ssClientTime: 0 ms +ssLeaderTime: 313 ms +intersection: 0 ms +onlineTime: 322 ms +Total time: 2.849 s +data/second: 0.0793414 Mbps +Total data: 0.0282555 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Leader Idx: 4 +mIntersection.size(): 0 +offlineTime: 2589 ms +hashingTime: 9 ms +getOPRFTime: 5 ms +ssClientTime: 0 ms +ssLeaderTime: 232 ms +intersection: 0 ms +onlineTime: 246 ms +Total time: 2.835 s +data/second: 0.0797332 Mbps +Total data: 0.0282555 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Leader Idx: 4 +mIntersection.size(): 0 +offlineTime: 2371 ms +hashingTime: 9 ms +getOPRFTime: 20 ms +ssClientTime: 0 ms +ssLeaderTime: 452 ms +intersection: 0 ms +onlineTime: 481 ms +Total time: 2.852 s +data/second: 0.079258 Mbps +Total data: 0.0282555 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Leader Idx: 4 +mIntersection.size(): 31 +offlineTime: 2623 ms +hashingTime: 4 ms +getOPRFTime: 6 ms +ssClientTime: 0 ms +ssLeaderTime: 136 ms +intersection: 7 ms +onlineTime: 153 ms +Total time: 2.776 s +data/second: 0.0814278 Mbps +Total data: 0.0282555 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Leader Idx: 4 +mIntersection.size(): 31 +offlineTime: 2465 ms +hashingTime: 8 ms +getOPRFTime: 7 ms +ssClientTime: 0 ms +ssLeaderTime: 415 ms +intersection: 1 ms +onlineTime: 431 ms +Total time: 2.896 s +data/second: 0.0780538 Mbps +Total data: 0.0282555 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Leader Idx: 4 +mIntersection.size(): 32 +offlineTime: 2623 ms +hashingTime: 4 ms +getOPRFTime: 5 ms +ssClientTime: 0 ms +ssLeaderTime: 154 ms +intersection: 0 ms +onlineTime: 163 ms +Total time: 2.786 s +data/second: 0.0811356 Mbps +Total data: 0.0282555 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Leader Idx: 4 +mIntersection.size(): 32 +offlineTime: 2558 ms +hashingTime: 4 ms +getOPRFTime: 6 ms +ssClientTime: 0 ms +ssLeaderTime: 239 ms +intersection: 0 ms +onlineTime: 249 ms +Total time: 2.807 s +data/second: 0.0805286 Mbps +Total data: 0.0282555 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Leader Idx: 4 +mIntersection.size(): 16 +offlineTime: 2476 ms +hashingTime: 6 ms +getOPRFTime: 5 ms +ssClientTime: 0 ms +ssLeaderTime: 382 ms +intersection: 0 ms +onlineTime: 393 ms +Total time: 2.869 s +data/second: 0.0787883 Mbps +Total data: 0.0282555 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Leader Idx: 4 +mIntersection.size(): 16 +offlineTime: 2469 ms +hashingTime: 5 ms +getOPRFTime: 6 ms +ssClientTime: 0 ms +ssLeaderTime: 304 ms +intersection: 0 ms +onlineTime: 315 ms +Total time: 2.784 s +data/second: 0.0811939 Mbps +Total data: 0.0282555 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Leader Idx: 4 +mIntersection.size(): 16 +offlineTime: 2530 ms +hashingTime: 3 ms +getOPRFTime: 5 ms +ssClientTime: 0 ms +ssLeaderTime: 197 ms +intersection: 0 ms +onlineTime: 205 ms +Total time: 2.735 s +data/second: 0.0826485 Mbps +Total data: 0.0282555 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Leader Idx: 4 +#Output Intersection: 16 +#Expected Intersection: 16 +offlineTime: 2533 ms +hashingTime: 7 ms +getOPRFTime: 5 ms +ssClientTime: 0 ms +ssLeaderTime: 196 ms +intersection: 0 ms +onlineTime: 208 ms +Total time: 2.741 s +data/second: 0.0824676 Mbps +Total data: 0.0282555 MB +=========avg========== +numParty: 5 numCorrupted: 2 setSize: 32 nTrials:1 +Leader Idx: 4 +#Output Intersection: 16 +#Expected Intersection: 16 +offlineTime: 2604 ms +hashingTime: 7 ms +getOPRFTime: 65 ms +ssClientTime: 0 ms +ssLeaderTime: 337 ms +intersection: 0 ms +onlineTime: 409 ms +data/second: 0.0750228 Mbps +Total time: 3.013 s Total data: 0.0282555 MB +------------------ diff --git a/libOPRF/Hashing/CuckooHasher1.cpp b/libOPRF/Hashing/CuckooHasher1.cpp index be594f89ed..a0b2a5bb18 100644 --- a/libOPRF/Hashing/CuckooHasher1.cpp +++ b/libOPRF/Hashing/CuckooHasher1.cpp @@ -315,7 +315,7 @@ namespace osuCrypto u64 width = mStashHashesView.size()[1]; u64 remaining = inputIdxs.size(); - std::cout << "inputStashIdxs.size(): " << inputIdxs.size() << std::endl; + // std::cout << "inputStashIdxs.size(): " << inputIdxs.size() << std::endl; u64 tryCount = 0; diff --git a/libOPRF/OPPRF/OPPRFSender.cpp b/libOPRF/OPPRF/OPPRFSender.cpp index 987f8815c2..2826b8961b 100644 --- a/libOPRF/OPPRF/OPPRFSender.cpp +++ b/libOPRF/OPPRF/OPPRFSender.cpp @@ -430,7 +430,12 @@ namespace osuCrypto if (tIdx == 0) gTimer.setTimePoint("online.send.otSend.finalOPRF"); + +#ifdef PRINT std::cout << "getPosTime" << IdxP << ": " << mPosBitsTime / pow(10, 6) << std::endl; +#endif // PRINT + + #pragma endregion #endif diff --git a/libOTe/.gitignore b/libOTe/.gitignore new file mode 100644 index 0000000000..1eb70dcf20 --- /dev/null +++ b/libOTe/.gitignore @@ -0,0 +1,218 @@ +## Ignore Visual Studio temporary files, build results, and +## files generated by popular Visual Studio add-ons. + +# User-specific files +*.suo +*.user +*.sln.docstates + +CMakeFiles/* +*/CMakeFiles/* +*cmake_install.cmake + +CMakeCache.txt +*/CMakeCache.txt + +*.a + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +x64/ +build/ +bld/ +[Bb]in/ +[Oo]bj/ + +# Roslyn cache directories +*.ide/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +#NUNIT +*.VisualState.xml +TestResult.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +*_i.c +*_p.c +*_i.h +*.ilk +*.meta +*.obj +*.pch +*.pdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*.log +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opensdf +*.sdf +*.cachefile + +# Visual Studio profiler +*.psess +*.vsp +*.vspx + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# JustCode is a .NET coding addin-in +.JustCode + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# NCrunch +_NCrunch_* +.*crunch*.local.xml + +# MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +## TODO: Comment the next line if you want to checkin your +## web deploy settings but do note that will include unencrypted +## passwords +#*.pubxml + +# NuGet Packages Directory +packages/* +## TODO: If the tool you use requires repositories.config +## uncomment the next line +#!packages/repositories.config + +# Enable "build/" folder in the NuGet Packages folder since +# NuGet packages use it for MSBuild targets. +# This line needs to be after the ignore of the build folder +# (and the packages folder if the line above has been uncommented) +!packages/build/ + +# Windows Azure Build Output +csx/ +*.build.csdef + +# Windows Store app package directory +AppPackages/ + +# Others +sql/ +*.Cache +ClientBin/ +[Ss]tyle[Cc]op.* +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.pfx +*.publishsettings +node_modules/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm + +# SQL Server files +*.mdf +*.ldf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings + +# Microsoft Fakes +FakesAssemblies/ + +# LightSwitch generated files +GeneratedArtifacts/ +_Pvt_Extensions/ +ModelManifest.xml +/WeGarbleTests__ +/thirdparty +kProbe_* + +CodeDB +LinuxFrontEnd/VisualGDBCache +*.opendb +*.pdf +*.db +*.sln + +mpsi.VC* + +/psir_8s.txt +/psis_8s.txt + +testout.txt +online.txt +offline.txt +Makefile \ No newline at end of file diff --git a/libOTe/.gitmodules b/libOTe/.gitmodules new file mode 100644 index 0000000000..4f8b1a05eb --- /dev/null +++ b/libOTe/.gitmodules @@ -0,0 +1,3 @@ +[submodule "cryptoTools"] + path = cryptoTools + url = https://github.com/ladnir/cryptoTools.git diff --git a/libOTe/README.md b/libOTe/README.md new file mode 100644 index 0000000000..40eacbe649 --- /dev/null +++ b/libOTe/README.md @@ -0,0 +1,156 @@ +# libOTe +A fast and portable C++11 library for Oblivious Transfer extension (OTe). The primary design goal of this library to obtain *high performance* while being *easy to use*. This library currently implements: + +* The semi-honest 1-out-of-2 OT [IKNP03]. +* The semi-honest 1-out-of-N OT [[KKRT16]](https://eprint.iacr.org/2016/799). +* The malicious secure 1-out-of-2 OT [[KOS15]](https://eprint.iacr.org/2015/546). +* The malicious secure 1-out-of-2 Delta-OT [[KOS15]](https://eprint.iacr.org/2015/546),[[BLNNOOSS15]](https://eprint.iacr.org/2015/472.pdf). +* The malicious secure 1-out-of-N OT [[OOS16]](http://eprint.iacr.org/2016/933). +* The malicious secure approximate K-out-of-N OT [[RR16]](https://eprint.iacr.org/2016/746). +* The malicious secure 1-out-of-2 base OT [NP00]. + +## Introduction + +This library provides several different classes of OT protocols. First is the base OT protocol of Naor Prinkas [NP00]. This protocol bootstraps all the other OT extension protocols. Within the OT extension protocols, we have 1-out-of-2, 1-out-of-N and ~K-out-of-N, both in the semi-honest and malicious settings. + +All implementations are highly optimized using fast SSE instructions and vectorization to obtain optimal performance both in the single and multi-threaded setting. See the **Performance** section for a comparison between protocols and to other libraries. + + + +## Performance + +The running time in seconds for computing n=224 OTs on a single Intel Xeon server (`2 36-cores Intel Xeon CPU E5-2699 v3 @ 2.30GHz and 256GB of RAM`) as of 11/16/2016. All timings shown reflect a "single" thread per party, with the expection that network IO in libOTe is performed in the background by a separate thread. + + +| *Type* | *Security* | *Protocol* | libOTe (SHA1/AES) | [Encrypto Group](https://github.com/encryptogroup/OTExtension) (SHA256) | [Apricot](https://github.com/bristolcrypto/apricot) (AES-hash) | OOS16 (blake2) | [emp-toolkit](https://github.com/emp-toolkit) (AES-hash) | +|--------------------- |----------- |-------------- |---------------- |---------------- |--------- |--------- |------------ | +| 1-out-of-N (N=276) | malicious | OOS16 | **11.7 / 9.2** | ~ | ~ | 24** | ~ | +| 1-out-of-N (N=2128)| passive| KKRT16 | **9.2 / 6.7** | ~ | ~ | ~ | ~ | +| 1-out-of-2 Delta-OT | malicious | KOS15 | **1.9*** | ~ | ~ | ~ | ~ | +| 1-out-of-2 | malicious | ALSZ15 | ~ | 17.3 | ~ | ~ | 10 | +| 1-out-of-2 | malicious | KOS15 | **3.9 / 0.7** | ~ | 1.1 | ~ | 2.9 | +| 1-out-of-2 | passive | IKNP03 | **3.7 / 0.6** | 11.3 | **0.6** | ~ | 2.7 | + + +\* Delta-OT does not use the SHA1 or AES hash function. + +\** This timing was taken from the [[OOS16]](http://eprint.iacr.org/2016/933) paper and their implementation used multiple threads. The number was not specified. When using the libOTe implementation with multiple threads, a timing of 2.6 seconds was obtained with the SHA1 hash function. + +It should be noted that the libOTe implementation uses the Boost ASIO library to perform more efficient asynchronous network IO. This involves using a background thread to help process network data. As such, this is not a completely fair comparison to the Apricot implementation but we don't expect it to have a large impact. It also appears that the Encrypto Group implementation uses asynchronous network IO. + + + The above timings were obtained with the follwoing options: + + 1-out-of-2 malicious: + * Apricot: `./ot.x -n 16777216 -p 0 -m a -l 100 & ./ot.x -p 1 -m a -n 16777216 -l 100` + * Encrypto Group: ` ./ot.exe -r 0 -n 16777216 -o 1 & ./ot.exe -r 1 -n 16777216 -o 1` + * emp-toolkit: 2x 223 `./mot 0 1212 & ./mot 1 1212` + +1-out-of-2 semi-honest: + * Apricot: `./ot.x -n 16777216 -p 0 -m a -l 100 -pas & ./ot.x -p 1 -m a -n 16777216 -l 100 -pas` + * Encrypto Group: ` ./ot.exe -r 0 -n 16777216 -o 0 & ./ot.exe -r 1 -n 16777216 -o 0` + * emp-toolkit: 2*223 `./shot 0 1212 & ./shot 1 1212` + + +## License + +This project has been placed in the public domain. As such, you are unrestricted in how you use it, commercial or otherwise. However, no warranty of fitness is provided. If you found this project helpful, feel free to spread the word and cite us. + + + + +## Install + +The library is *cross platform* and has been tested on both Windows and Linux. The library should work on MAC but it has not been tested. There are two library dependencies including [Boost](http://www.boost.org/) (networking), and [Miracl](https://www.miracl.com/index) (Base OT). For each, we provide a script that automates the download and build steps. The version of Miracl used by this library requires specific configuration and therefore we advise using the coned repository that we provide. + +### Windows + +In `Powershell`, this will set up the project + +``` +git clone --recursive https://github.com/osu-crypto/libOTe.git +cd libOTe/thirdparty/win +getBoost.ps1; getMiracl.ps1 +cd ../.. +libOTe.sln +``` + +Requirements: `Powershell`, Powershell `Set-ExecutionPolicy Unrestricted`, `Visual Studio 2015`, CPU supporting `PCLMUL`, `AES-NI`, and `SSE4.1`. +Optional: `nasm` for improved SHA1 performance. + +Build the solution within visual studio or with `MSBuild`. To see all the command line options, execute the program + +`frontend.exe` + +If the cryptoTools directory is empty `git submodule update --init --recursive`. + +IMPORTANT: By default, the build system needs the NASM compiler to be located at `C:\NASM\nasm.exe`. In the event that it isn't, there are two options, install it, or enable the pure c++ implementation. The latter option is done by excluding `libOTe/Crypto/asm/sha_win64.asm` from the build system and undefining `INTEL_ASM_SHA1` on line 28 of `libOTe/Crypto/sha1.cpp`. + + + + +### Linux + + In short, this will build the project + +``` +git clone --recursive https://github.com/osu-crypto/libOTe.git +cd libOTe/thirdparty/linux +bash all.get +cd ../.. +CMake -G "Unix Makefiles" +make +``` + +Requirements: `CMake`, `Make`, `g++` or similar, CPU supporting `PCLMUL`, `AES-NI`, and `SSE4.1`. Optional: `nasm` for improved SHA1 performance. + +The libraries will be placed in `libOTe/lib` and the binary `frontend.exe` will be placed in `libOTe/bin` To see all the command line options, execute the program + +`./bin/frontend.exe` + +Note: In the case that miracl or boost is already installed, the steps `cd libOTe/thirdparty/linux; bash all.get` can be skipped and CMake will attempt to find them instead. Boost is found with the CMake findBoost package and miracl is found with the `find_library(miracl)` command. + + If the cryptoTools directory is empty `git submodule update --init --recursive`. + +## Citing + + Spread the word! + +``` +@misc{libOTe, + author = {Peter Rindal}, + title = {{libOTe: an efficient, portable, and easy to use Oblivious Transfer Library}}, + howpublished = {\url{https://github.com/osu-crypto/libOTe}}, +} +``` +## Protocol Details +The 1-out-of-N [OOS16] protocol currently is set to work forn N=276 but is capable of supporting arbitrary codes given the generator matrix in text format. See `./libOTe/Tools/Bch511.txt` for an example. + +The 1-out-of-N [KKRT16] for arbitrary N is also implemented and slightly faster than [OOS16]. However, [KKRT16] is in the semi-honest setting. + +The approximate K-out-of-N OT [RR16] protocol is also implemented. This protocol allows for a rough bound on the value K with a very light weight cut and choose technique. It was introduced for a PSI protocol that builds on a Garbled Bloom Filter. + +## Help + +Contact Peter Rindal rindalp@oregonstate.edu for any assistance on building or running the library. + + + +## Citation + +[IKNP03] - Yuval Ishai and Joe Kilian and Kobbi Nissim and Erez Petrank, _Extending Oblivious Transfers Efficiently_. + +[KOS15] - Marcel Keller and Emmanuela Orsini and Peter Scholl, _Actively Secure OT Extension with Optimal Overhead_. [eprint/2015/546](https://eprint.iacr.org/2015/546) + +[OOS16] - Michele Orrù and Emmanuela Orsini and Peter Scholl, _Actively Secure 1-out-of-N OT Extension with Application to Private Set Intersection_. [eprint/2016/933](http://eprint.iacr.org/2016/933) + +[KKRT16] - Vladimir Kolesnikov and Ranjit Kumaresan and Mike Rosulek and Ni Trieu, _Efficient Batched Oblivious PRF with Applications to Private Set Intersection_. [eprint/2016/799](https://eprint.iacr.org/2016/799) + +[RR16] - Peter Rindal and Mike Rosulek, _Improved Private Set Intersection against Malicious Adversaries_. [eprint/2016/746](https://eprint.iacr.org/2016/746) + +[BLNNOOSS15] - Sai Sheshank Burra and Enrique Larraia and Jesper Buus Nielsen and Peter Sebastian Nordholt and Claudio Orlandi and Emmanuela Orsini and Peter Scholl and Nigel P. Smart, _High Performance Multi-Party Computation for Binary Circuits Based on Oblivious Transfe_. [eprint/2015/472](https://eprint.iacr.org/2015/472.pdf) + +[ALSZ15] - Gilad Asharov and Yehuda Lindell and Thomas Schneider and Michael Zohner, _More Efficient Oblivious Transfer Extensions with Security for Malicious Adversaries_. [eprint/2015/061](https://eprint.iacr.org/2015/061) + +[NP00] - Moni Naor, Benny Pinkas, _Efficient Oblivious Transfer Protocols_. + diff --git a/libOTe/buildAll.ps1 b/libOTe/buildAll.ps1 new file mode 100644 index 0000000000..bdc64faa0f --- /dev/null +++ b/libOTe/buildAll.ps1 @@ -0,0 +1,25 @@ + +# Update this if needed +$MSBuild = 'C:\Program Files (x86)\MSBuild\14.0\Bin\MSBuild.exe' + +if(!(Test-Path $MSBuild)) +{ + Write-Host "Could not find MSBuild as" + Write-Host " $MSBuild" + Write-Host "" + Write-Host "Please update its lication in the script" + + exit +} + +cd ./thirdparty/win + +& ./getBoost.ps1 +& ./getMiracl.ps1 + +cd ../.. + +& $MSBuild libOTe.sln /p:Configuration=Release /p:Platform=x64 +& $MSBuild libOTe.sln /p:Configuration=Debug /p:Platform=x64 + + diff --git a/libPSI_Tests/OPPRF_Tests.cpp b/libPSI_Tests/OPPRF_Tests.cpp index 6dc594dd65..b936895ddc 100644 --- a/libPSI_Tests/OPPRF_Tests.cpp +++ b/libPSI_Tests/OPPRF_Tests.cpp @@ -24,17 +24,18 @@ using namespace osuCrypto; #define PRINT +//#define BIN_PRINT void testPointer(std::vector* test) { //int length = test->size(); //std::cout << length << std::endl; - + AES ncoInputHasher; - - ncoInputHasher.setKey(_mm_set1_epi64x(112434)); - ncoInputHasher.ecbEncBlocks((*test).data() , test->size() - 1, (*test).data() ); + + ncoInputHasher.setKey(_mm_set1_epi64x(112434)); + ncoInputHasher.ecbEncBlocks((*test).data(), test->size() - 1, (*test).data()); //Log::out << "mHashingSeed: " << mHashingSeed << Log::endl; @@ -47,7 +48,7 @@ void testPointer2(std::vector& test) ncoInputHasher.setKey(_mm_set1_epi64x(112434)); ncoInputHasher.ecbEncBlocks(test.data(), test.size() - 1, test.data()); - + } @@ -64,7 +65,7 @@ void Bit_Position_Test_Impl() myvector.resize(0); #if 0 - u64 setSize = 1<<4; + u64 setSize = 1 << 4; std::vector testSet(setSize); PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); @@ -76,7 +77,7 @@ void Bit_Position_Test_Impl() testPointer2(testSet); for (u64 i = 0; i < setSize; ++i) { - std::cout < masks; -// b2.findPos(testSet, masks); - //std::cout << "\nmNumTrial: " << b2.mNumTrial << std::endl; - + // b2.findPos(testSet, masks); + //std::cout << "\nmNumTrial: " << b2.mNumTrial << std::endl; + for (u8 i = 0; i < masks.size(); i++) { @@ -110,7 +111,7 @@ double maxprob1(u64 balls, u64 bins, u64 k) { return std::log(bins * std::pow(balls * exp(1) / (bins * k), k)) / std::log(2); } -u64 findMaxBinSize(u64 n, u64 numBins, u64 numHash=2) +u64 findMaxBinSize(u64 n, u64 numBins, u64 numHash = 2) { u64 balls = numHash*n; u64 maxBin; @@ -130,7 +131,7 @@ double findScaleNumBins(u64 n, u64 maxBin, u64 numHash = 2) { u64 balls = numHash*n; double scale; - for (scale = 0.01; scale < 1; scale+= 0.01) + for (scale = 0.01; scale < 1; scale += 0.01) { // finds the min number of bins needed to get max occ. to be maxBin if (-maxprob1(balls, scale*n, maxBin) < 40) @@ -175,13 +176,13 @@ void hashing2Bins_Test_Impl() for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) { pThrds[pIdx] = std::thread([&, pIdx]() { - + bins[pIdx].init(pIdx, 2, setSize, psiSecParam); bins[pIdx].hashing2Bins(set, 2); }); } - + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) pThrds[pIdx].join(); @@ -209,15 +210,15 @@ void Bit_Position_Recursive_Test_Impl() //{ // testSet[i].m128i_u16[2] = 1 << i; //} - + BitPosition b; #if 0 block test = ZeroBlock; //test.m128i_i8[0] = 126; - + BitPosition b; - b.init(5,5); + b.init(5, 5); for (size_t i = 0; i < 8; i++) { b.setBit(test, i); @@ -257,7 +258,7 @@ void Bit_Position_Recursive_Test_Impl() Log::out << " getPos="; for (u64 j = 0; j < b.mPos.size(); ++j) { - Log::out << static_cast(b.mPos[j])<<" "; + Log::out << static_cast(b.mPos[j]) << " "; } Log::out << Log::endl; @@ -272,10 +273,10 @@ void Bit_Position_Recursive_Test_Impl() void Bit_Position_Random_Test_Impl() { u64 power = 5; - u64 setSize = 1<< power; + u64 setSize = 1 << power; + + PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); - PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); - SimpleHasher1 mSimpleBins; mSimpleBins.init(setSize); @@ -285,7 +286,7 @@ void Bit_Position_Random_Test_Impl() for (u64 j = 0; j < setSize; ++j) { tempIdxBuff[j] = j; - for (u64 k = 0; k (); hashes[j][k] = *(u64*)&a; @@ -319,7 +320,7 @@ void Bit_Position_Random_Test_Impl() { auto& bin = mSimpleBins.mBins[bIdx]; if (bin.mIdx.size() > 0) - { + { bin.mBits[0].init(mSimpleBins.mNumBits[0]); bin.mBits[0].getPos1(bin.mValOPRF[0], 128); } @@ -329,7 +330,7 @@ void Bit_Position_Random_Test_Impl() for (u64 bIdx = 0; bIdx < mSimpleBins.mBinCount[1]; ++bIdx) { - auto& bin = mSimpleBins.mBins[mSimpleBins.mBinCount[0]+bIdx]; + auto& bin = mSimpleBins.mBins[mSimpleBins.mBinCount[0] + bIdx]; if (bin.mIdx.size() > 0) { bin.mBits[0].init(mSimpleBins.mNumBits[1]); @@ -377,56 +378,56 @@ void Bit_Position_Random_Test_Impl() void OPPRF_CuckooHasher_Test_Impl() { #if 0 - u64 setSize = 10000; + u64 setSize = 10000; - u64 h = 2; - std::vector _hashes(setSize * h + 1); - MatrixView hashes(_hashes.begin(), _hashes.end(), h); - PRNG prng(ZeroBlock); + u64 h = 2; + std::vector _hashes(setSize * h + 1); + MatrixView hashes(_hashes.begin(), _hashes.end(), h); + PRNG prng(ZeroBlock); - for (u64 i = 0; i < hashes.size()[0]; ++i) - { - for (u64 j = 0; j < h; ++j) - { - hashes[i][j] = prng.get(); - } - } + for (u64 i = 0; i < hashes.size()[0]; ++i) + { + for (u64 j = 0; j < h; ++j) + { + hashes[i][j] = prng.get(); + } + } - CuckooHasher hashMap0; - CuckooHasher hashMap1; - CuckooHasher::Workspace w(1); + CuckooHasher hashMap0; + CuckooHasher hashMap1; + CuckooHasher::Workspace w(1); - hashMap0.init(setSize, 40,1, true); - hashMap1.init(setSize, 40,1, true); + hashMap0.init(setSize, 40, 1, true); + hashMap1.init(setSize, 40, 1, true); - for (u64 i = 0; i < setSize; ++i) - { - //if (i == 6) hashMap0.print(); + for (u64 i = 0; i < setSize; ++i) + { + //if (i == 6) hashMap0.print(); - hashMap0.insert(i, hashes[i]); + hashMap0.insert(i, hashes[i]); - std::vector tt{ i }; - MatrixView mm(hashes[i].data(), 1, 2, false); - hashMap1.insertBatch(tt, mm, w); + std::vector tt{ i }; + MatrixView mm(hashes[i].data(), 1, 2, false); + hashMap1.insertBatch(tt, mm, w); - //if (i == 6) hashMap0.print(); - //if (i == 6) hashMap1.print(); + //if (i == 6) hashMap0.print(); + //if (i == 6) hashMap1.print(); - //if (hashMap0 != hashMap1) - //{ - // std::cout << i << std::endl; + //if (hashMap0 != hashMap1) + //{ + // std::cout << i << std::endl; - // throw UnitTestFail(); - //} + // throw UnitTestFail(); + //} - } + } - if (hashMap0 != hashMap1) - { - throw UnitTestFail(); - } + if (hashMap0 != hashMap1) + { + throw UnitTestFail(); + } #endif } @@ -515,8 +516,8 @@ void Channel_Test_Impl() { { for (u64 j = 0; j < numParties; ++j) { - if(i!=j) - ep[i*numParties+j].stop(); + if (i != j) + ep[i*numParties + j].stop(); } } @@ -524,108 +525,108 @@ void Channel_Test_Impl() { } void OPPRF2_EmptrySet_Test_Impl() { - u64 setSize = 1<<5, psiSecParam = 40, bitSize = 128 , numParties=2; - PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); + u64 setSize = 1 << 5, psiSecParam = 40, bitSize = 128, numParties = 2; + PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); - std::vector sendSet(setSize), recvSet(setSize); + std::vector sendSet(setSize), recvSet(setSize); std::vector sendPayLoads(setSize), recvPayLoads(setSize); - for (u64 i = 0; i < setSize; ++i) - { - sendSet[i] = prng.get(); - sendPayLoads[i]= prng.get(); + for (u64 i = 0; i < setSize; ++i) + { + sendSet[i] = prng.get(); + sendPayLoads[i] = prng.get(); recvSet[i] = prng.get(); recvSet[i] = sendSet[i]; - } + } for (u64 i = 1; i < 3; ++i) { recvSet[i] = sendSet[i]; } - std::string name("psi"); + std::string name("psi"); - BtIOService ios(0); - BtEndpoint ep0(ios, "localhost", 1212, true, name); - BtEndpoint ep1(ios, "localhost", 1212, false, name); + BtIOService ios(0); + BtEndpoint ep0(ios, "localhost", 1212, true, name); + BtEndpoint ep1(ios, "localhost", 1212, false, name); - std::vector recvChl{ &ep1.addChannel(name, name) }; - std::vector sendChl{ &ep0.addChannel(name, name) }; + std::vector recvChl{ &ep1.addChannel(name, name) }; + std::vector sendChl{ &ep0.addChannel(name, name) }; KkrtNcoOtReceiver otRecv0, otRecv1; KkrtNcoOtSender otSend0, otSend1; - - OPPRFSender send; - OPPRFReceiver recv; - // std::thread thrd([&]() { - - // send.init(numParties,setSize, psiSecParam, bitSize, sendChl, otSend0, otRecv1, prng.get()); - // send.hash2Bins(sendSet, sendChl); - //send.getOPRFKeys(1,sendChl); - //send.sendSecretSharing(1, sendPayLoads, sendChl); - // send.revSecretSharing(1, recvPayLoads, sendChl); - //Log::out << "send.mSimpleBins.print(true, false, false,false);" << Log::endl; - // send.mSimpleBins.print(1,true, true, true, true); - //Log::out << "send.mCuckooBins.print(true, false, false);" << Log::endl; - //send.mCuckooBins.print(1,true, true, false); - // }); -// recv.init(numParties,setSize, psiSecParam, bitSize, recvChl, otRecv0, otSend1, ZeroBlock); - //recv.hash2Bins(recvSet, recvChl); -// recv.getOPRFkeys(0, recvChl); - //recv.revSecretSharing(0, recvPayLoads, recvChl); -// recv.sendSecretSharing(0, sendPayLoads, recvChl); + OPPRFSender send; + OPPRFReceiver recv; + // std::thread thrd([&]() { + + + // send.init(numParties,setSize, psiSecParam, bitSize, sendChl, otSend0, otRecv1, prng.get()); + // send.hash2Bins(sendSet, sendChl); + //send.getOPRFKeys(1,sendChl); + //send.sendSecretSharing(1, sendPayLoads, sendChl); + // send.revSecretSharing(1, recvPayLoads, sendChl); + //Log::out << "send.mSimpleBins.print(true, false, false,false);" << Log::endl; + // send.mSimpleBins.print(1,true, true, true, true); + //Log::out << "send.mCuckooBins.print(true, false, false);" << Log::endl; + //send.mCuckooBins.print(1,true, true, false); + // }); + // recv.init(numParties,setSize, psiSecParam, bitSize, recvChl, otRecv0, otSend1, ZeroBlock); + //recv.hash2Bins(recvSet, recvChl); + // recv.getOPRFkeys(0, recvChl); + //recv.revSecretSharing(0, recvPayLoads, recvChl); + // recv.sendSecretSharing(0, sendPayLoads, recvChl); Log::out << "recv.mCuckooBins.print(true, false, false);" << Log::endl; -// recv.mCuckooBins.print(0,true, true, false); - - //Log::out << "recv.mSimpleBins.print(true, false, false,false);" << Log::endl; - //recv.mSimpleBins.print(0,true, true, true, true); + // recv.mCuckooBins.print(0,true, true, false); + //Log::out << "recv.mSimpleBins.print(true, false, false,false);" << Log::endl; + //recv.mSimpleBins.print(0,true, true, true, true); - //std::vector pThrds(numParties); - //for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) - //{ - // pThrds[pIdx] = std::thread([&, pIdx]() { - // if (pIdx == 0) - // { - // send.init(numParties, setSize, psiSecParam, bitSize, sendChl, otSend0, otRecv1, prng.get()); + //std::vector pThrds(numParties); - // } - // else if (pIdx == 1) { - // recv.init(numParties, setSize, psiSecParam, bitSize, recvChl, otRecv0, otSend1, ZeroBlock); - // } - // }); - //} + //for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + //{ + // pThrds[pIdx] = std::thread([&, pIdx]() { + // if (pIdx == 0) + // { + // send.init(numParties, setSize, psiSecParam, bitSize, sendChl, otSend0, otRecv1, prng.get()); + + // } + // else if (pIdx == 1) { + // recv.init(numParties, setSize, psiSecParam, bitSize, recvChl, otRecv0, otSend1, ZeroBlock); + // } + // }); + //} + + //for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + // pThrds[pIdx].join(); - //for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) - // pThrds[pIdx].join(); - #ifdef PRINT -std::cout << IoStream::lock; -for (u64 i = 1; i < recvPayLoads.size(); ++i) -{ + std::cout << IoStream::lock; + for (u64 i = 1; i < recvPayLoads.size(); ++i) + { Log::out << recvPayLoads[i] << Log::endl; Log::out << sendPayLoads[i] << Log::endl; if (memcmp((u8*)&recvPayLoads[i], &sendPayLoads[i], sizeof(block))) { Log::out << "recvPayLoads[i] != sendPayLoads[i]" << Log::endl; - Log::out << recvSet[i] << Log::endl; - Log::out << sendSet[i] << Log::endl; + Log::out << recvSet[i] << Log::endl; + Log::out << sendSet[i] << Log::endl; Log::out << i << Log::endl; } } -std::cout << IoStream::unlock; + std::cout << IoStream::unlock; std::cout << IoStream::lock; Log::out << otSend0.mT.size()[1] << Log::endl; @@ -641,20 +642,20 @@ std::cout << IoStream::unlock; #endif -// thrd.join(); + // thrd.join(); + - - - sendChl[0]->close(); - recvChl[0]->close(); - ep0.stop(); - ep1.stop(); - ios.stop(); + sendChl[0]->close(); + recvChl[0]->close(); + + ep0.stop(); + ep1.stop(); + ios.stop(); } void OPPRF_EmptrySet_Test_Impl() { @@ -697,7 +698,7 @@ void OPPRF_EmptrySet_Test_Impl() KkrtNcoOtReceiver otRecv0, otRecv1; KkrtNcoOtSender otSend0, otSend1; OPPRFSender send; - OPPRFReceiver recv; + OPPRFReceiver recv; KkrtNcoOtReceiver otRecv02, otRecv12; KkrtNcoOtSender otSend02, otSend12; @@ -729,30 +730,30 @@ void OPPRF_EmptrySet_Test_Impl() pThrds[pIdx].join();*/ - /*std::thread thrd([&]() { - send.init(numParties, setSize, psiSecParam, bitSize, sendChl, otSend0, otRecv1, prng.get()); - }); - recv.init(numParties, setSize, psiSecParam, bitSize, recvChl, otRecv0, otSend1, ZeroBlock); + /*std::thread thrd([&]() { + send.init(numParties, setSize, psiSecParam, bitSize, sendChl, otSend0, otRecv1, prng.get()); + }); + recv.init(numParties, setSize, psiSecParam, bitSize, recvChl, otRecv0, otSend1, ZeroBlock); - thrd.join();*/ + thrd.join();*/ #ifdef PRINT - //std::cout << IoStream::lock; - //for (u64 i = 1; i < recvPayLoads.size(); ++i) - //{ - // Log::out << recvPayLoads[i] << Log::endl; - // Log::out << sendPayLoads[i] << Log::endl; - // if (memcmp((u8*)&recvPayLoads[i], &sendPayLoads[i], sizeof(block))) - // { - // Log::out << "recvPayLoads[i] != sendPayLoads[i]" << Log::endl; - // Log::out << recvSet[i] << Log::endl; - // Log::out << sendSet[i] << Log::endl; - // Log::out << i << Log::endl; - // } - - //} - - //std::cout << IoStream::unlock; + //std::cout << IoStream::lock; + //for (u64 i = 1; i < recvPayLoads.size(); ++i) + //{ + // Log::out << recvPayLoads[i] << Log::endl; + // Log::out << sendPayLoads[i] << Log::endl; + // if (memcmp((u8*)&recvPayLoads[i], &sendPayLoads[i], sizeof(block))) + // { + // Log::out << "recvPayLoads[i] != sendPayLoads[i]" << Log::endl; + // Log::out << recvSet[i] << Log::endl; + // Log::out << sendSet[i] << Log::endl; + // Log::out << i << Log::endl; + // } + + //} + + //std::cout << IoStream::unlock; std::cout << IoStream::lock; Log::out << otSend0.mT.size()[1] << Log::endl; @@ -822,7 +823,7 @@ void OPPRF3_EmptrySet_Test_Impl_draft() BtEndpoint ep10(ios, "localhost", 1212, false, name); std::vector recvChl10{ &ep10.addChannel(name, name) }; std::vector sendChl01{ &ep01.addChannel(name, name) }; - + BtEndpoint ep02(ios, "localhost", 1213, true, name); BtEndpoint ep20(ios, "localhost", 1213, false, name); std::vector recvChl20{ &ep20.addChannel(name, name) }; @@ -846,7 +847,7 @@ void OPPRF3_EmptrySet_Test_Impl_draft() std::vector recv(2); std::vector bins(3); - + std::vector pThrds(numParties); //for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) @@ -916,7 +917,7 @@ void OPPRF3_EmptrySet_Test_Impl_draft() #endif -// thrd.join(); + // thrd.join(); @@ -924,11 +925,11 @@ void OPPRF3_EmptrySet_Test_Impl_draft() sendChl01[0]->close(); - sendChl02[0]->close(); - recvChl10[0]->close(); - recvChl20[0]->close(); + sendChl02[0]->close(); + recvChl10[0]->close(); + recvChl20[0]->close(); - ep10.stop(); ep01.stop(); ep20.stop(); ep02.stop(); + ep10.stop(); ep01.stop(); ep20.stop(); ep02.stop(); ios.stop(); } @@ -975,12 +976,12 @@ void OPPRF_EmptrySet_hashing_Test_Impl() std::thread thrd([&]() { bins[0].init(0, numParties, setSize, psiSecParam); u64 otCountSend = bins[0].mSimpleBins.mBins.size(); - send[0].init(numParties, setSize, psiSecParam, bitSize, sendChl, otCountSend,otSend[0], otRecv[0], prng.get()); + send[0].init(numParties, setSize, psiSecParam, bitSize, sendChl, otCountSend, otSend[0], otRecv[0], prng.get()); bins[0].hashing2Bins(sendSet, 2); //send.hash2Bins(sendSet, sendChl); - send[0].getOPRFKeys(1, bins[0], sendChl,true); + send[0].getOPRFKeys(1, bins[0], sendChl, true); send[0].sendSecretSharing(1, bins[0], sendPayLoads, sendChl); //send.revSecretSharing(1, recvPayLoads, sendChl); //Log::out << "send.mSimpleBins.print(true, false, false,false);" << Log::endl; @@ -1010,7 +1011,7 @@ void OPPRF_EmptrySet_hashing_Test_Impl() bins[1].init(1, numParties, setSize, psiSecParam); u64 otCountRecv = bins[1].mCuckooBins.mBins.size(); - recv[0].init(numParties, setSize, psiSecParam, bitSize, recvChl, otCountRecv,otRecv[1], otSend[1], ZeroBlock); + recv[0].init(numParties, setSize, psiSecParam, bitSize, recvChl, otCountRecv, otRecv[1], otSend[1], ZeroBlock); bins[1].hashing2Bins(recvSet, 2); @@ -1094,33 +1095,33 @@ void testShareValue() u64 nextNeighbor = 2;// (myIdx + 1) % nParties; u64 prevNeighbor = (myIdx - 1 + nParties) % nParties; //sum share of other party =0 => compute the share to his neighbor = sum of other shares - for (u64 i = 0; i < setSize; ++i) + for (u64 i = 0; i < setSize; ++i) + { + block sum = ZeroBlock; + //sendPayLoads[nextNeighbor][i] = ZeroBlock; + for (u64 idxP = 0; idxP < nParties; ++idxP) { - block sum = ZeroBlock; - //sendPayLoads[nextNeighbor][i] = ZeroBlock; - for (u64 idxP = 0; idxP < nParties; ++idxP) - { - if((idxP!= myIdx && idxP!= nextNeighbor)) - sum = sum ^ sendPayLoads[idxP][i]; - } - std::cout << "sum: " << sum << std::endl; - sendPayLoads[nextNeighbor][i] = sum; + if ((idxP != myIdx && idxP != nextNeighbor)) + sum = sum ^ sendPayLoads[idxP][i]; + } + std::cout << "sum: " << sum << std::endl; + sendPayLoads[nextNeighbor][i] = sum; - block check = ZeroBlock; + block check = ZeroBlock; - for (u64 idxP = 0; idxP < nParties; ++idxP) - { - if (idxP != myIdx) + for (u64 idxP = 0; idxP < nParties; ++idxP) + { + if (idxP != myIdx) check = check ^ sendPayLoads[idxP][i]; - } - std::cout << "check: " << check << std::endl; + } + std::cout << "check: " << check << std::endl; - block check2 = ZeroBlock; - check2 = sendPayLoads[0][i] ^ sendPayLoads[3][i]; - std::cout << "check2: " << check2 << std::endl; + block check2 = ZeroBlock; + check2 = sendPayLoads[0][i] ^ sendPayLoads[3][i]; + std::cout << "check2: " << check2 << std::endl; + + } - } - } void party(u64 myIdx, u64 setSize, std::vector& mSet) @@ -1137,7 +1138,7 @@ void party(u64 myIdx, u64 setSize, std::vector& mSet) set[i] = mSet[i]; } PRNG prng1(_mm_set_epi32(4253465, 3434565, 234435, myIdx)); - set[0]= prng1.get();; + set[0] = prng1.get();; for (u64 idxP = 0; idxP < nParties; ++idxP) { sendPayLoads[idxP].resize(setSize); @@ -1168,7 +1169,7 @@ void party(u64 myIdx, u64 setSize, std::vector& mSet) for (u64 idxP = 0; idxP < nParties; ++idxP) { if (idxP != myIdx) - sendPayLoads[myIdx][i] = sendPayLoads[myIdx][i] ^ sendPayLoads[idxP][i]; + sendPayLoads[myIdx][i] = sendPayLoads[myIdx][i] ^ sendPayLoads[idxP][i]; } } @@ -1183,9 +1184,9 @@ void party(u64 myIdx, u64 setSize, std::vector& mSet) if (idxP != myIdx) check = check ^ sendPayLoads[idxP][i]; } - if (memcmp((u8*)&check, &ZeroBlock, sizeof(block))) - std::cout << "Error ss values: myIdx: " << myIdx - << " value: "<< check << std::endl; + if (memcmp((u8*)&check, &ZeroBlock, sizeof(block))) + std::cout << "Error ss values: myIdx: " << myIdx + << " value: " << check << std::endl; } } else @@ -1216,7 +1217,7 @@ void party(u64 myIdx, u64 setSize, std::vector& mSet) u32 port = i * 10 + myIdx;//get the same port; i=1 & pIdx=2 =>port=102 ep[i].start(ios, "localhost", port, false, name); //channel bwt i and pIdx, where i is sender } - else if (i >myIdx) + else if (i > myIdx) { u32 port = myIdx * 10 + i;//get the same port; i=2 & pIdx=1 =>port=102 ep[i].start(ios, "localhost", port, true, name); //channel bwt i and pIdx, where i is receiver @@ -1241,7 +1242,7 @@ void party(u64 myIdx, u64 setSize, std::vector& mSet) std::vector otRecv(nParties); std::vector otSend(nParties); - std::vector send(nParties - myIdx-1); + std::vector send(nParties - myIdx - 1); std::vector recv(myIdx); binSet bins; @@ -1261,10 +1262,10 @@ void party(u64 myIdx, u64 setSize, std::vector& mSet) pThrds[pIdx] = std::thread([&, pIdx]() { if (pIdx < myIdx) { //I am a receiver if other party idx < mine - recv[pIdx].init(nParties, setSize, psiSecParam, bitSize, chls[pIdx], otCountRecv,otRecv[pIdx], otSend[pIdx], ZeroBlock, true); + recv[pIdx].init(nParties, setSize, psiSecParam, bitSize, chls[pIdx], otCountRecv, otRecv[pIdx], otSend[pIdx], ZeroBlock, true); } else if (pIdx > myIdx) { - send[pIdx- myIdx-1].init(nParties, setSize, psiSecParam, bitSize, chls[pIdx], otCountSend, otSend[pIdx], otRecv[pIdx], prng.get(), true); + send[pIdx - myIdx - 1].init(nParties, setSize, psiSecParam, bitSize, chls[pIdx], otCountSend, otSend[pIdx], otRecv[pIdx], prng.get(), true); } }); } @@ -1319,10 +1320,10 @@ void party(u64 myIdx, u64 setSize, std::vector& mSet) pThrds[pIdx] = std::thread([&, pIdx]() { if (pIdx < myIdx) { //I am a receiver if other party idx < mine - recv[pIdx].getOPRFkeys(pIdx, bins,chls[pIdx], true); - } + recv[pIdx].getOPRFkeys(pIdx, bins, chls[pIdx], true); + } else if (pIdx > myIdx) { - send[pIdx - myIdx - 1].getOPRFKeys(pIdx, bins,chls[pIdx], true); + send[pIdx - myIdx - 1].getOPRFKeys(pIdx, bins, chls[pIdx], true); } }); } @@ -1353,7 +1354,7 @@ void party(u64 myIdx, u64 setSize, std::vector& mSet) for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) { pThrds[pIdx] = std::thread([&, pIdx]() { - if ((pIdx < myIdx && pIdx!= prevNeighbor)) { + if ((pIdx < myIdx && pIdx != prevNeighbor)) { //I am a receiver if other party idx < mine recv[pIdx].revSecretSharing(pIdx, bins, recvPayLoads[pIdx], chls[pIdx]); recv[pIdx].sendSecretSharing(pIdx, bins, sendPayLoads[pIdx], chls[pIdx]); @@ -1363,7 +1364,7 @@ void party(u64 myIdx, u64 setSize, std::vector& mSet) send[pIdx - myIdx - 1].revSecretSharing(pIdx, bins, recvPayLoads[pIdx], chls[pIdx]); } - else if (pIdx == prevNeighbor && myIdx !=0) { + else if (pIdx == prevNeighbor && myIdx != 0) { recv[pIdx].sendSecretSharing(pIdx, bins, sendPayLoads[pIdx], chls[pIdx]); } else if (pIdx == nextNeighbor && myIdx != nParties - 1) @@ -1374,12 +1375,12 @@ void party(u64 myIdx, u64 setSize, std::vector& mSet) else if (pIdx == nParties - 1 && myIdx == 0) { send[pIdx - myIdx - 1].sendSecretSharing(pIdx, bins, sendPayLoads[pIdx], chls[pIdx]); } - + else if (pIdx == 0 && myIdx == nParties - 1) { recv[pIdx].revSecretSharing(pIdx, bins, recvPayLoads[pIdx], chls[pIdx]); } - + }); } @@ -1394,13 +1395,13 @@ void party(u64 myIdx, u64 setSize, std::vector& mSet) { for (int i = 0; i < 3; i++) { - block temp=ZeroBlock; - memcpy((u8*)&temp,(u8*)&sendPayLoads[2][i],maskSize); + block temp = ZeroBlock; + memcpy((u8*)&temp, (u8*)&sendPayLoads[2][i], maskSize); Log::out << "s " << myIdx << " - 2: Idx" << i << " - " << temp << Log::endl; block temp1 = ZeroBlock; memcpy((u8*)&temp1, (u8*)&recvPayLoads[2][i], maskSize); - Log::out << "r " << myIdx << " - 2: Idx" << i << " - " << temp1 << Log::endl; + Log::out << "r " << myIdx << " - 2: Idx" << i << " - " << temp1 << Log::endl; } Log::out << "------------" << Log::endl; } @@ -1410,7 +1411,7 @@ void party(u64 myIdx, u64 setSize, std::vector& mSet) { block temp = ZeroBlock; memcpy((u8*)&temp, (u8*)&recvPayLoads[0][i], maskSize); - Log::out <<"r " << myIdx << " - 0: Idx" << i << " - " << temp << Log::endl; + Log::out << "r " << myIdx << " - 0: Idx" << i << " - " << temp << Log::endl; block temp1 = ZeroBlock; memcpy((u8*)&temp1, (u8*)&sendPayLoads[0][i], maskSize); @@ -1427,51 +1428,51 @@ void party(u64 myIdx, u64 setSize, std::vector& mSet) if (myIdx == 0) { // Xor the received shares - for (u64 i = 0; i < setSize; ++i) + for (u64 i = 0; i < setSize; ++i) + { + for (u64 idxP = 0; idxP < nParties; ++idxP) { - for (u64 idxP = 0; idxP < nParties ; ++idxP) - { - if(idxP != myIdx && idxP != prevNeighbor) + if (idxP != myIdx && idxP != prevNeighbor) sendPayLoads[nextNeighbor][i] = sendPayLoads[nextNeighbor][i] ^ recvPayLoads[idxP][i]; - } } + } send[nextNeighbor].sendSecretSharing(nextNeighbor, bins, sendPayLoads[nextNeighbor], chls[nextNeighbor]); send[nextNeighbor - myIdx - 1].revSecretSharing(prevNeighbor, bins, recvPayLoads[prevNeighbor], chls[prevNeighbor]); } - else if(myIdx == nParties - 1) + else if (myIdx == nParties - 1) { - recv[prevNeighbor].revSecretSharing(prevNeighbor, bins, recvPayLoads[prevNeighbor], chls[prevNeighbor]); - - //Xor the received shares - for (u64 i = 0; i < setSize; ++i) + recv[prevNeighbor].revSecretSharing(prevNeighbor, bins, recvPayLoads[prevNeighbor], chls[prevNeighbor]); + + //Xor the received shares + for (u64 i = 0; i < setSize; ++i) + { + sendPayLoads[nextNeighbor][i] = sendPayLoads[nextNeighbor][i] ^ recvPayLoads[prevNeighbor][i]; + for (u64 idxP = 0; idxP < nParties; ++idxP) { - sendPayLoads[nextNeighbor][i] = sendPayLoads[nextNeighbor][i] ^ recvPayLoads[prevNeighbor][i]; - for (u64 idxP = 0; idxP < nParties; ++idxP) - { - if(idxP != myIdx && idxP != prevNeighbor) + if (idxP != myIdx && idxP != prevNeighbor) sendPayLoads[nextNeighbor][i] = sendPayLoads[nextNeighbor][i] ^ recvPayLoads[idxP][i]; - } } - - recv[nextNeighbor].sendSecretSharing(nextNeighbor, bins, sendPayLoads[nextNeighbor], chls[nextNeighbor]); + } + + recv[nextNeighbor].sendSecretSharing(nextNeighbor, bins, sendPayLoads[nextNeighbor], chls[nextNeighbor]); } else { - recv[prevNeighbor].revSecretSharing(prevNeighbor, bins, recvPayLoads[prevNeighbor], chls[prevNeighbor]); - //Xor the received shares - for (u64 i = 0; i < setSize; ++i) - { - sendPayLoads[nextNeighbor][i] = sendPayLoads[nextNeighbor][i] ^ recvPayLoads[prevNeighbor][i]; - for (u64 idxP = 0; idxP < nParties; ++idxP) + recv[prevNeighbor].revSecretSharing(prevNeighbor, bins, recvPayLoads[prevNeighbor], chls[prevNeighbor]); + //Xor the received shares + for (u64 i = 0; i < setSize; ++i) { - if(idxP != myIdx && idxP != prevNeighbor) - sendPayLoads[nextNeighbor][i] = sendPayLoads[nextNeighbor][i] ^ recvPayLoads[idxP][i]; + sendPayLoads[nextNeighbor][i] = sendPayLoads[nextNeighbor][i] ^ recvPayLoads[prevNeighbor][i]; + for (u64 idxP = 0; idxP < nParties; ++idxP) + { + if (idxP != myIdx && idxP != prevNeighbor) + sendPayLoads[nextNeighbor][i] = sendPayLoads[nextNeighbor][i] ^ recvPayLoads[idxP][i]; + } } - } - send[nextNeighbor - myIdx - 1].sendSecretSharing(nextNeighbor, bins, sendPayLoads[nextNeighbor], chls[nextNeighbor]); + send[nextNeighbor - myIdx - 1].sendSecretSharing(nextNeighbor, bins, sendPayLoads[nextNeighbor], chls[nextNeighbor]); } auto getSSDoneRound = timer.setTimePoint("getSSDoneRound"); @@ -1530,7 +1531,7 @@ void party(u64 myIdx, u64 setSize, std::vector& mSet) auto ssRoundTime = std::chrono::duration_cast(getSSDoneRound - getSSDone2Dir).count(); auto intersectionTime = std::chrono::duration_cast(getIntersection - getSSDoneRound).count(); - double onlineTime = hashingTime + getOPRFTime + ss2DirTime+ ssRoundTime + intersectionTime; + double onlineTime = hashingTime + getOPRFTime + ss2DirTime + ssRoundTime + intersectionTime; double time = offlineTime + onlineTime; time /= 1000; @@ -1586,21 +1587,21 @@ void party3(u64 myIdx, u64 setSize, std::vector& mSet) std::vector set(setSize); for (u64 i = 0; i < setSize; ++i) set[i] = mSet[i]; - + PRNG prng1(_mm_set_epi32(4253465, 3434565, 234435, myIdx)); //for test set[0] = prng1.get();; std::vector sendPayLoads(setSize); std::vector recvPayLoads(setSize); - + //only P0 genaretes secret sharing if (myIdx == 0) { for (u64 i = 0; i < setSize; ++i) sendPayLoads[i] = prng.get(); } - - + + std::string name("psi"); BtIOService ios(0); @@ -1614,7 +1615,7 @@ void party3(u64 myIdx, u64 setSize, std::vector& mSet) u32 port = i * 10 + myIdx;//get the same port; i=1 & pIdx=2 =>port=102 ep[i].start(ios, "localhost", port, false, name); //channel bwt i and pIdx, where i is sender } - else if (i >myIdx) + else if (i > myIdx) { u32 port = myIdx * 10 + i;//get the same port; i=2 & pIdx=1 =>port=102 ep[i].start(ios, "localhost", port, true, name); //channel bwt i and pIdx, where i is receiver @@ -1654,7 +1655,7 @@ void party3(u64 myIdx, u64 setSize, std::vector& mSet) u64 otCountRecv = bins.mCuckooBins.mBins.size(); u64 nextNeibough = (myIdx + 1) % nParties; - u64 prevNeibough = (myIdx - 1+ nParties) % nParties; + u64 prevNeibough = (myIdx - 1 + nParties) % nParties; for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) { @@ -1667,7 +1668,7 @@ void party3(u64 myIdx, u64 setSize, std::vector& mSet) else if (pIdx == prevNeibough) { //I am a recv to my previous neigbour recv.init(nParties, setSize, psiSecParam, bitSize, chls[pIdx], otCountRecv, otRecv[pIdx], otSend[pIdx], ZeroBlock, false); - } + } }); } @@ -1750,7 +1751,7 @@ void party3(u64 myIdx, u64 setSize, std::vector& mSet) //### online phasing - secretsharing //########################## if (myIdx == 0) - { + { send.sendSecretSharing(nextNeibough, bins, sendPayLoads, chls[nextNeibough]); recv.revSecretSharing(prevNeibough, bins, recvPayLoads, chls[prevNeibough]); @@ -1797,14 +1798,14 @@ void party3(u64 myIdx, u64 setSize, std::vector& mSet) u64 maskSize = roundUpTo(psiSecParam + 2 * std::log(setSize) - 1, 8) / 8; for (u64 i = 0; i < setSize; ++i) { - // if (sendPayLoads[i]== recvPayLoads[i]) + // if (sendPayLoads[i]== recvPayLoads[i]) if (!memcmp((u8*)&sendPayLoads[i], &recvPayLoads[i], maskSize)) { mIntersection.push_back(i); } } Log::out << mIntersection.size() << Log::endl; - } + } for (u64 i = 0; i < nParties; ++i) @@ -1828,149 +1829,840 @@ void party3(u64 myIdx, u64 setSize, std::vector& mSet) ios.stop(); } -void Channel_party_test(u64 myIdx) +bool is_in_dual_area(u64 startIdx, u64 endIdx, u64 numIdx, u64 checkIdx) { + bool res = false; + if (startIdx <= endIdx) + { + if (startIdx <= checkIdx && checkIdx <= endIdx) + res = true; + } + else //crosing 0, e.i, areas: startIdx....n-1, 0...endIdx + { + if ((0 <= checkIdx && checkIdx <= endIdx) //0...endIdx + || (startIdx <= checkIdx && checkIdx <= numIdx)) + //startIdx...n-1 + res = true; + } + return res; +} + +//leader is n-1 +void tparty(u64 myIdx, u64 nParties, u64 tParties, u64 setSize, std::vector& mSet, u64 nTrials) { - u64 setSize = 1 << 5, psiSecParam = 40, bitSize = 128, numThreads = 1; - PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); +#pragma region setup - std::vector dummy(nParties); - std::vector revDummy(nParties); + //nParties = 4; + /*std::fstream runtime; + if (myIdx == 0) + runtime.open("./runtime" + nParties, runtime.trunc | runtime.out);*/ + + u64 leaderIdx = nParties - 1; //leader party + u64 nSS = nParties - 1; //n-2 parties joinly operated secrete sharing + int tSS = tParties; //ss with t next parties, and last for leader => t+1 + + + u64 offlineAvgTime(0), hashingAvgTime(0), getOPRFAvgTime(0), + ss2DirAvgTime(0), ssRoundAvgTime(0), intersectionAvgTime(0), onlineAvgTime(0); + + u64 psiSecParam = 40, bitSize = 128, numThreads = 1; + PRNG prng(_mm_set_epi32(4253465, 3434565, myIdx, myIdx)); + std::string name("psi"); BtIOService ios(0); - int btCount = nParties; + std::vector ep(nParties); for (u64 i = 0; i < nParties; ++i) { - dummy[i] = myIdx * 10 + i; if (i < myIdx) { - u32 port = i * 10 + myIdx;//get the same port; i=1 & pIdx=2 =>port=102 + u32 port = 1120 + i * 100 + myIdx;;//get the same port; i=1 & pIdx=2 =>port=102 ep[i].start(ios, "localhost", port, false, name); //channel bwt i and pIdx, where i is sender } - else if (i >myIdx) + else if (i > myIdx) { - u32 port = myIdx * 10 + i;//get the same port; i=2 & pIdx=1 =>port=102 + u32 port = 1120 + myIdx * 100 + i;//get the same port; i=2 & pIdx=1 =>port=102 ep[i].start(ios, "localhost", port, true, name); //channel bwt i and pIdx, where i is receiver } } - std::vector> chls(nParties); + std::vector dummy(nParties); + std::vector revDummy(nParties); for (u64 i = 0; i < nParties; ++i) { + dummy[i] = myIdx * 10 + i; + if (i != myIdx) { chls[i].resize(numThreads); for (u64 j = 0; j < numThreads; ++j) { //chls[i][j] = &ep[i].addChannel("chl" + std::to_string(j), "chl" + std::to_string(j)); chls[i][j] = &ep[i].addChannel(name, name); + //chls[i][j].mEndpoint; + + + } } } + u64 maskSize = roundUpTo(psiSecParam + 2 * std::log(setSize) - 1, 8) / 8; + u64 nextNeighbor = (myIdx + 1) % nParties; + u64 prevNeighbor = (myIdx - 1 + nParties) % nParties; + +#pragma endregion - std::mutex printMtx1, printMtx2; - std::vector pThrds(nParties); - for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + for (u64 idxTrial = 0; idxTrial < nTrials; idxTrial++) { - pThrds[pIdx] = std::thread([&, pIdx]() { - if (pIdx < myIdx) { +#pragma region input + std::vector set(setSize); + std::vector> + sendPayLoads(tParties + 1), //include the last PayLoads to leader + recvPayLoads(tParties); //received form clients - chls[pIdx][0]->asyncSend(&dummy[pIdx], 1); - //std::lock_guard lock(printMtx1); - // std::cout << "s: " << myIdx << " -> " << pIdx << " : " << static_cast(dummy[pIdx]) << std::endl; - - } - else if (pIdx > myIdx) { + for (u64 i = 0; i < setSize; ++i) + { + set[i] = mSet[i]; + } + PRNG prng1(_mm_set_epi32(4253465, 3434565, 234435, myIdx)); + set[0] = prng1.get();; - chls[pIdx][0]->recv(&revDummy[pIdx], 1); - std::lock_guard lock(printMtx2); - std::cout << "r: " << myIdx << " <- " << pIdx << " : " << static_cast(revDummy[pIdx]) << std::endl; + if (myIdx != leaderIdx) {//generate share of zero for leader myIDx!=n-1 + for (u64 idxP = 0; idxP < tParties; ++idxP) + { + sendPayLoads[idxP].resize(setSize); + for (u64 i = 0; i < setSize; ++i) + { + sendPayLoads[idxP][i] = prng.get(); + } } - }); - } - - for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) - { - // if(pIdx!=myIdx) - pThrds[pIdx].join(); - } + sendPayLoads[tParties].resize(setSize); //share to leader at second phase + for (u64 i = 0; i < setSize; ++i) + { + sendPayLoads[tParties][i] = ZeroBlock; + for (u64 idxP = 0; idxP < tParties; ++idxP) + { + sendPayLoads[tParties][i] = + sendPayLoads[tParties][i] ^ sendPayLoads[idxP][i]; + } + } + for (u64 idxP = 0; idxP < recvPayLoads.size(); ++idxP) + { + recvPayLoads[idxP].resize(setSize); + } + } + else + { + //leader: dont send; only receive ss from clients + sendPayLoads.resize(0);// + recvPayLoads.resize(nParties - 1); + for (u64 idxP = 0; idxP < recvPayLoads.size(); ++idxP) + { + recvPayLoads[idxP].resize(setSize); + } + } - for (u64 i = 0; i < nParties; ++i) - { - if (i != myIdx) - { - for (u64 j = 0; j < numThreads; ++j) +#ifdef PRINT + std::cout << IoStream::lock; + if (myIdx != leaderIdx) { + for (u64 i = 0; i < setSize; ++i) { - chls[i][j]->close(); + block check = ZeroBlock; + for (u64 idxP = 0; idxP < tParties + 1; ++idxP) + { + //if (idxP != myIdx) + check = check ^ sendPayLoads[idxP][i]; + } + if (memcmp((u8*)&check, &ZeroBlock, sizeof(block))) + std::cout << "Error ss values: myIdx: " << myIdx + << " value: " << check << std::endl; } } - } + std::cout << IoStream::unlock; +#endif +#pragma endregion + u64 num_threads = nParties - 1; //except P0, and my + bool isDual = true; + u64 idx_start_dual = 0; + u64 idx_end_dual = 0; + u64 t_prev_shift = tSS; + + if (myIdx != leaderIdx) { + if (2 * tSS < nSS) + { + num_threads = 2 * tSS + 1; + isDual = false; + } + else { + idx_start_dual = (myIdx - tSS + nSS) % nSS; + idx_end_dual = (myIdx + tSS) % nSS; + } - for (u64 i = 0; i < nParties; ++i) - { - if (i != myIdx) - ep[i].stop(); - } + std::cout << IoStream::lock; + std::cout << myIdx << "| " << idx_start_dual << " " << idx_end_dual << "\n"; + std::cout << IoStream::unlock; + } + std::vector pThrds(num_threads); + std::vector otRecv(nParties); + std::vector otSend(nParties); + std::vector send(nParties); + std::vector recv(nParties); - ios.stop(); -} + if (myIdx == leaderIdx) + { + /*otRecv.resize(nParties - 1); + otSend.resize(nParties - 1); + send.resize(nParties - 1); + recv.resize(nParties - 1);*/ + pThrds.resize(nParties - 1); + } -void OPPRFn_EmptrySet_Test_Impl() -{ - u64 setSize = 1 << 5, psiSecParam = 40, bitSize = 128; - PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); - mSet.resize(setSize); - for (u64 i = 0; i < setSize; ++i) - { - mSet[i] = prng.get(); + + + binSet bins; + + //########################## + //### Offline Phasing + //########################## + Timer timer; + auto start = timer.setTimePoint("start"); + bins.init(myIdx, nParties, setSize, psiSecParam); + u64 otCountSend = bins.mSimpleBins.mBins.size(); + u64 otCountRecv = bins.mCuckooBins.mBins.size(); + + +#pragma region base OT + //########################## + //### Base OT + //########################## + + if (myIdx != leaderIdx) + { + for (u64 pIdx = 0; pIdx < tSS; ++pIdx) + { + u64 prevIdx = (myIdx - pIdx - 1 + nSS) % nSS; + + if (!(isDual && is_in_dual_area(idx_start_dual, idx_end_dual, nSS, prevIdx))) + { + u64 thr = t_prev_shift + pIdx; + + pThrds[thr] = std::thread([&, prevIdx, thr]() { + + chls[prevIdx][0]->recv(&revDummy[prevIdx], 1); + + std::cout << IoStream::lock; + std::cout << myIdx << "| : " << "| thr[" << thr << "]:" << prevIdx << " --> " << myIdx << ": " << static_cast(revDummy[prevIdx]) << "\n"; + + std::cout << IoStream::unlock; + + + //prevIdx << " --> " << myIdx + recv[prevIdx].init(nParties, setSize, psiSecParam, bitSize, chls[prevIdx], otCountRecv, otRecv[prevIdx], otSend[prevIdx], ZeroBlock, false); + + }); + + + + } + } + + for (u64 pIdx = 0; pIdx < tSS; ++pIdx) + { + u64 nextIdx = (myIdx + pIdx + 1) % nSS; + + if ((isDual && is_in_dual_area(idx_start_dual, idx_end_dual, nSS, nextIdx))) { + + pThrds[pIdx] = std::thread([&, nextIdx, pIdx]() { + + + //dual myIdx << " <-> " << nextIdx + if (myIdx < nextIdx) + { + chls[nextIdx][0]->asyncSend(&dummy[nextIdx], 1); + std::cout << IoStream::lock; + std::cout << myIdx << "| d: " << "| thr[" << pIdx << "]:" << myIdx << " <->> " << nextIdx << ": " << static_cast(dummy[nextIdx]) << "\n"; + std::cout << IoStream::unlock; + + send[nextIdx].init(nParties, setSize, psiSecParam, bitSize, chls[nextIdx], otCountSend, otSend[nextIdx], otRecv[nextIdx], prng.get(), true); + } + else if (myIdx > nextIdx) //by index + { + chls[nextIdx][0]->recv(&revDummy[nextIdx], 1); + + std::cout << IoStream::lock; + std::cout << myIdx << "| d: " << "| thr[" << pIdx << "]:" << myIdx << " <<-> " << nextIdx << ": " << static_cast(revDummy[nextIdx]) << "\n"; + std::cout << IoStream::unlock; + + recv[nextIdx].init(nParties, setSize, psiSecParam, bitSize, chls[nextIdx], otCountRecv, otRecv[nextIdx], otSend[nextIdx], ZeroBlock, true); + } + }); + + } + else + { + pThrds[pIdx] = std::thread([&, nextIdx, pIdx]() { + + chls[nextIdx][0]->asyncSend(&dummy[nextIdx], 1); + std::cout << IoStream::lock; + std::cout << myIdx << "| : " << "| thr[" << pIdx << "]:" << myIdx << " -> " << nextIdx << ": " << static_cast(dummy[nextIdx]) << "\n"; + std::cout << IoStream::unlock; + send[nextIdx].init(nParties, setSize, psiSecParam, bitSize, chls[nextIdx], otCountSend, otSend[nextIdx], otRecv[nextIdx], prng.get(), false); + }); + } + } + + //last thread for connecting with leader + u64 tLeaderIdx = pThrds.size() - 1; + pThrds[pThrds.size() - 1] = std::thread([&, leaderIdx]() { + + chls[leaderIdx][0]->asyncSend(&dummy[leaderIdx], 1); + + std::cout << IoStream::lock; + std::cout << myIdx << "| : " << "| thr[" << pThrds.size() - 1 << "]:" << myIdx << " --> " << leaderIdx << ": " << static_cast(dummy[leaderIdx]) << "\n"; + std::cout << IoStream::unlock; + + send[leaderIdx].init(nParties, setSize, psiSecParam, bitSize, chls[leaderIdx], otCountSend, otSend[leaderIdx], otRecv[leaderIdx], prng.get(), false); + }); + + } + else + { //leader party + + for (u64 pIdx = 0; pIdx < nSS; ++pIdx) + { + pThrds[pIdx] = std::thread([&, pIdx]() { + chls[pIdx][0]->recv(&revDummy[pIdx], 1); + std::cout << IoStream::lock; + std::cout << myIdx << "| : " << "| thr[" << pIdx << "]:" << pIdx << " --> " << myIdx << ": " << static_cast(revDummy[pIdx]) << "\n"; + std::cout << IoStream::unlock; + + recv[pIdx].init(nParties, setSize, psiSecParam, bitSize, chls[pIdx], otCountRecv, otRecv[pIdx], otSend[pIdx], ZeroBlock, false); + }); + + } + } + + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + pThrds[pIdx].join(); + + auto initDone = timer.setTimePoint("initDone"); + + +#ifdef PRINT + std::cout << IoStream::lock; + if (myIdx == 0) + { + Log::out << myIdx << "| -> " << otSend[1].mGens[0].get() << Log::endl; + if (otRecv[1].hasBaseOts()) + { + Log::out << myIdx << "| <- " << otRecv[1].mGens[0][0].get() << Log::endl; + Log::out << myIdx << "| <- " << otRecv[1].mGens[0][1].get() << Log::endl; + } + Log::out << "------------" << Log::endl; + } + if (myIdx == 1) + { + if (otSend[0].hasBaseOts()) + Log::out << myIdx << "| -> " << otSend[0].mGens[0].get() << Log::endl; + + Log::out << myIdx << "| <- " << otRecv[0].mGens[0][0].get() << Log::endl; + Log::out << myIdx << "| <- " << otRecv[0].mGens[0][1].get() << Log::endl; + } + + if (isDual) + { + if (myIdx == 0) + { + Log::out << myIdx << "| <->> " << otSend[tSS].mGens[0].get() << Log::endl; + if (otRecv[tSS].hasBaseOts()) + { + Log::out << myIdx << "| <<-> " << otRecv[tSS].mGens[0][0].get() << Log::endl; + Log::out << myIdx << "| <<-> " << otRecv[tSS].mGens[0][1].get() << Log::endl; + } + Log::out << "------------" << Log::endl; + } + if (myIdx == tSS) + { + if (otSend[0].hasBaseOts()) + Log::out << myIdx << "| <->> " << otSend[0].mGens[0].get() << Log::endl; + + Log::out << myIdx << "| <<-> " << otRecv[0].mGens[0][0].get() << Log::endl; + Log::out << myIdx << "| <<-> " << otRecv[0].mGens[0][1].get() << Log::endl; + } + } + std::cout << IoStream::unlock; +#endif + +#pragma endregion + + + //########################## + //### Hashing + //########################## + bins.hashing2Bins(set, 1); + + /*if(myIdx==0) + bins.mSimpleBins.print(myIdx, true, false, false, false); + if (myIdx == 1) + bins.mCuckooBins.print(myIdx, true, false, false);*/ + + auto hashingDone = timer.setTimePoint("hashingDone"); + +#pragma region compute OPRF + + //########################## + //### Online Phasing - compute OPRF + //########################## + + pThrds.clear(); + pThrds.resize(num_threads); + if (myIdx == leaderIdx) + { + pThrds.resize(nParties - 1); + } + + if (myIdx != leaderIdx) + { + for (u64 pIdx = 0; pIdx < tSS; ++pIdx) + { + u64 prevIdx = (myIdx - pIdx - 1 + nSS) % nSS; + + if (!(isDual && is_in_dual_area(idx_start_dual, idx_end_dual, nSS, prevIdx))) + { + u64 thr = t_prev_shift + pIdx; + + pThrds[thr] = std::thread([&, prevIdx]() { + + //prevIdx << " --> " << myIdx + recv[prevIdx].getOPRFkeys(prevIdx, bins, chls[prevIdx], false); + + }); + } + } + + for (u64 pIdx = 0; pIdx < tSS; ++pIdx) + { + u64 nextIdx = (myIdx + pIdx + 1) % nSS; + + if ((isDual && is_in_dual_area(idx_start_dual, idx_end_dual, nSS, nextIdx))) { + + pThrds[pIdx] = std::thread([&, nextIdx]() { + //dual myIdx << " <-> " << nextIdx + if (myIdx < nextIdx) + { + send[nextIdx].getOPRFKeys(nextIdx, bins, chls[nextIdx], true); + } + else if (myIdx > nextIdx) //by index + { + recv[nextIdx].getOPRFkeys(nextIdx, bins, chls[nextIdx], true); + } + }); + + } + else + { + pThrds[pIdx] = std::thread([&, nextIdx]() { + send[nextIdx].getOPRFKeys(nextIdx, bins, chls[nextIdx], false); + }); + } + } + + //last thread for connecting with leader + pThrds[pThrds.size() - 1] = std::thread([&, leaderIdx]() { + send[leaderIdx].getOPRFKeys(leaderIdx, bins, chls[leaderIdx], false); + }); + + } + else + { //leader party + for (u64 pIdx = 0; pIdx < nSS; ++pIdx) + { + pThrds[pIdx] = std::thread([&, pIdx]() { + recv[pIdx].getOPRFkeys(pIdx, bins, chls[pIdx], false); + }); + } + } + + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + pThrds[pIdx].join(); + + auto getOPRFDone = timer.setTimePoint("getOPRFDone"); + + +#ifdef BIN_PRINT + + if (myIdx == 0) + { + bins.mSimpleBins.print(1, true, true, false, false); + } + if (myIdx == 1) + { + bins.mCuckooBins.print(0, true, true, false); + } + + if (isDual) + { + if (myIdx == 0) + { + bins.mCuckooBins.print(tSS, true, true, false); + } + if (myIdx == tSS) + { + bins.mSimpleBins.print(0, true, true, false, false); + } + } + +#endif +#pragma endregion + +#pragma region SS + + //########################## + //### online phasing - secretsharing + //########################## + + pThrds.clear(); + + if (myIdx != leaderIdx) + { + pThrds.resize(num_threads); + for (u64 pIdx = 0; pIdx < tSS; ++pIdx) + { + u64 prevIdx = (myIdx - pIdx - 1 + nSS) % nSS; + + if (!(isDual && is_in_dual_area(idx_start_dual, idx_end_dual, nSS, prevIdx))) + { + u64 thr = t_prev_shift + pIdx; + + pThrds[thr] = std::thread([&, prevIdx, pIdx]() { + + //prevIdx << " --> " << myIdx + recv[prevIdx].revSecretSharing(prevIdx, bins, recvPayLoads[pIdx], chls[prevIdx]); + + }); + } + } + + for (u64 pIdx = 0; pIdx < tSS; ++pIdx) + { + u64 nextIdx = (myIdx + pIdx + 1) % nSS; + + if ((isDual && is_in_dual_area(idx_start_dual, idx_end_dual, nSS, nextIdx))) { + + pThrds[pIdx] = std::thread([&, nextIdx, pIdx]() { + //dual myIdx << " <-> " << nextIdx + //send OPRF can receive payload + if (myIdx < nextIdx) + { + send[nextIdx].sendSecretSharing(nextIdx, bins, sendPayLoads[pIdx], chls[nextIdx]); + + send[nextIdx].revSecretSharing(nextIdx, bins, recvPayLoads[pIdx], chls[nextIdx]); + } + else if (myIdx > nextIdx) //by index + { + recv[nextIdx].revSecretSharing(nextIdx, bins, recvPayLoads[pIdx], chls[nextIdx]); + + recv[nextIdx].sendSecretSharing(nextIdx, bins, sendPayLoads[pIdx], chls[nextIdx]); + + } + }); + + } + else + { + pThrds[pIdx] = std::thread([&, nextIdx, pIdx]() { + send[nextIdx].sendSecretSharing(nextIdx, bins, sendPayLoads[pIdx], chls[nextIdx]); + }); + } + } + + //last thread for connecting with leader + pThrds[pThrds.size() - 1] = std::thread([&, leaderIdx]() { + //send[leaderIdx].getOPRFKeys(leaderIdx, bins, chls[leaderIdx], false); + }); + + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + pThrds[pIdx].join(); + } + + auto getSSDone2Dir = timer.setTimePoint("secretsharingDone"); + + +#ifdef PRINT + std::cout << IoStream::lock; + if (myIdx == 0) + { + for (int i = 0; i < 3; i++) + { + block temp = ZeroBlock; + memcpy((u8*)&temp, (u8*)&sendPayLoads[0][i], maskSize); + Log::out << myIdx << "| -> 1: (" << i << ", " << temp << ")" << Log::endl; + } + Log::out << "------------" << Log::endl; + } + if (myIdx == 1) + { + for (int i = 0; i < 3; i++) + { + block temp = ZeroBlock; + memcpy((u8*)&temp, (u8*)&recvPayLoads[0][i], maskSize); + Log::out << myIdx << "| <- 0: (" << i << ", " << temp << ")" << Log::endl; + } + Log::out << "------------" << Log::endl; + } + + if (isDual) + { + /*if (myIdx == 0) + { + for (int i = 0; i < 3; i++) + { + block temp = ZeroBlock; + memcpy((u8*)&temp, (u8*)&recvPayLoads[tSS][i], maskSize); + Log::out << myIdx << "| <- "<< tSS<<": (" << i << ", " << temp << ")" << Log::endl; + } + Log::out << "------------" << Log::endl; + } + if (myIdx == tSS) + { + for (int i = 0; i < 3; i++) + { + block temp = ZeroBlock; + memcpy((u8*)&temp, (u8*)&sendPayLoads[0][i], maskSize); + Log::out << myIdx << "| -> 0: (" << i << ", " << temp << ")" << Log::endl; + } + Log::out << "------------" << Log::endl; + }*/ + } + + std::cout << IoStream::unlock; +#endif +#pragma endregion + + //########################## + //### online phasing - send XOR of zero share to leader + //########################## + pThrds.clear(); + + if (myIdx != leaderIdx) + { + + for (u64 i = 0; i < setSize; ++i) + { + //xor all received share + for (u64 idxP = 0; idxP < tParties; ++idxP) + { + sendPayLoads[tParties][i] = sendPayLoads[tParties][i] ^ recvPayLoads[idxP][i]; + } + } + //send to leader + send[leaderIdx].sendSecretSharing(leaderIdx, bins, sendPayLoads[tParties], chls[leaderIdx]); + } + else + { + pThrds.resize(nParties - 1); + + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) { + pThrds[pIdx] = std::thread([&, pIdx]() { + recv[pIdx].revSecretSharing(pIdx, bins, recvPayLoads[pIdx], chls[pIdx]); + }); + } + + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + pThrds[pIdx].join(); + } + + + auto getSSDoneRound = timer.setTimePoint("leaderGetXorDone"); + + + //########################## + //### online phasing - compute intersection + //########################## + + if (myIdx == leaderIdx) { + std::vector mIntersection; + u64 maskSize = roundUpTo(psiSecParam + 2 * std::log(setSize) - 1, 8) / 8; + + for (u64 i = 0; i < setSize; ++i) + { + + //xor all received share + block sum = ZeroBlock; + for (u64 idxP = 0; idxP < nParties - 1; ++idxP) + { + sum = sum ^ recvPayLoads[idxP][i]; + } + + if (!memcmp((u8*)&ZeroBlock, &sum, maskSize)) + { + mIntersection.push_back(i); + } + } + Log::out << "mIntersection.size(): " << mIntersection.size() << Log::endl; + } + auto getIntersection = timer.setTimePoint("getIntersection"); + + + + + //auto Mbps = dataSent * 8 / time / (1 << 20); + + + if (myIdx == 0 || myIdx==leaderIdx) { + auto offlineTime = std::chrono::duration_cast(initDone - start).count(); + auto hashingTime = std::chrono::duration_cast(hashingDone - initDone).count(); + auto getOPRFTime = std::chrono::duration_cast(getOPRFDone - hashingDone).count(); + auto ss2DirTime = std::chrono::duration_cast(getSSDone2Dir - getOPRFDone).count(); + auto ssRoundTime = std::chrono::duration_cast(getSSDoneRound - getSSDone2Dir).count(); + auto intersectionTime = std::chrono::duration_cast(getIntersection - getSSDoneRound).count(); + + double onlineTime = hashingTime + getOPRFTime + ss2DirTime + ssRoundTime + intersectionTime; + + double time = offlineTime + onlineTime; + time /= 1000; + + u64 dataSent = 0; + + for (u64 i = 0; i < nParties; ++i) + { + if (i != myIdx) { + chls[i].resize(numThreads); + for (u64 j = 0; j < numThreads; ++j) + { + dataSent += chls[i][j]->getTotalDataSent(); + } + } + } + auto Mbps = dataSent * 8 / time / (1 << 20); + + std::cout << setSize << " " << offlineTime << " " << onlineTime << " " << Mbps << " Mbps " << (dataSent / std::pow(2.0, 20)) << " MB" << std::endl; + + for (u64 i = 0; i < nParties; ++i) + { + if (i != myIdx) { + chls[i].resize(numThreads); + for (u64 j = 0; j < numThreads; ++j) + { + chls[i][j]->resetStats(); + } + } + } + + std::cout << "setSize: " << setSize << "\n" + << "offlineTime: " << offlineTime << " ms\n" + << "hashingTime: " << hashingTime << " ms\n" + << "getOPRFTime: " << getOPRFTime << " ms\n" + << "ss2DirTime: " << ss2DirTime << " ms\n" + << "ssRoundTime: " << ssRoundTime << " ms\n" + << "intersection: " << intersectionTime << " ms\n" + << "onlineTime: " << onlineTime << " ms\n" + << "Total time: " << time << " s\n" + << "------------------\n"; + + + offlineAvgTime += offlineTime; + hashingAvgTime += hashingTime; + getOPRFAvgTime += getOPRFTime; + ss2DirAvgTime += ss2DirTime; + ssRoundAvgTime += ssRoundTime; + intersectionAvgTime += intersectionTime; + onlineAvgTime += onlineTime; + + } + + } + + + /*if (myIdx == 0) { + double avgTime = (offlineAvgTime + onlineAvgTime); + avgTime /= 1000; + std::cout << "=========avg==========\n" + << "setSize: " << setSize << "\n" + << "offlineTime: " << offlineAvgTime / numTrial << " ms\n" + << "hashingTime: " << hashingAvgTime / numTrial << " ms\n" + << "getOPRFTime: " << getOPRFAvgTime / numTrial << " ms\n" + << "ss2DirTime: " << ss2DirAvgTime << " ms\n" + << "ssRoundTime: " << ssRoundAvgTime << " ms\n" + << "intersection: " << intersectionAvgTime / numTrial << " ms\n" + << "onlineTime: " << onlineAvgTime / numTrial << " ms\n" + << "Total time: " << avgTime / numTrial << " s\n"; + runtime << "setSize: " << setSize << "\n" + << "offlineTime: " << offlineAvgTime / numTrial << " ms\n" + << "hashingTime: " << hashingAvgTime / numTrial << " ms\n" + << "getOPRFTime: " << getOPRFAvgTime / numTrial << " ms\n" + << "ss2DirTime: " << ss2DirAvgTime << " ms\n" + << "ssRoundTime: " << ssRoundAvgTime << " ms\n" + << "intersection: " << intersectionAvgTime / numTrial << " ms\n" + << "onlineTime: " << onlineAvgTime / numTrial << " ms\n" + << "Total time: " << avgTime / numTrial << " s\n"; + runtime.close(); } - std::vector pThrds(nParties); - for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + */ + for (u64 i = 0; i < nParties; ++i) { - pThrds[pIdx] = std::thread([&, pIdx]() { - // Channel_party_test(pIdx); - party(pIdx, setSize, mSet); - }); + if (i != myIdx) + { + for (u64 j = 0; j < numThreads; ++j) + { + chls[i][j]->close(); + } + } } - for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) - pThrds[pIdx].join(); + for (u64 i = 0; i < nParties; ++i) + { + if (i != myIdx) + ep[i].stop(); + } -} -void OPPRF3_EmptrySet_Test_Impl() + ios.stop(); + } +void OPPRFnt_EmptrySet_Test_Impl() { - nParties = 3; u64 setSize = 1 << 5, psiSecParam = 40, bitSize = 128; - PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); mSet.resize(setSize); for (u64 i = 0; i < setSize; ++i) { mSet[i] = prng.get(); } + nParties = 5; + u64 tParties = 2; + + if (tParties == nParties - 1)//max ss = n-1 + tParties--; + else if (tParties < 1) //make sure to do ss with at least one client + tParties = 1; std::vector pThrds(nParties); for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) { - pThrds[pIdx] = std::thread([&, pIdx]() { - // Channel_party_test(pIdx); - party3(pIdx, setSize, mSet); - }); + //if (pIdx == 0) + //{ + // //tparty0(pIdx, nParties, 1, setSize, mSet); + //} + //else + { + pThrds[pIdx] = std::thread([&, pIdx]() { + // Channel_party_test(pIdx); + tparty(pIdx, nParties, tParties, mSet.size(), mSet, 2); + }); + } } for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) pThrds[pIdx].join(); @@ -1978,126 +2670,152 @@ void OPPRF3_EmptrySet_Test_Impl() } - -#if 0 -void OPPRF_FullSet_Test_Impl() +void Channel_party_test(u64 myIdx) { - setThreadName("CP_Test_Thread"); - u64 setSize = 8, psiSecParam = 40, numThreads(1), bitSize = 128; - PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); - - - std::vector sendSet(setSize), recvSet(setSize); - for (u64 i = 0; i < setSize; ++i) - { - sendSet[i] = recvSet[i] = prng.get(); - } - - std::shuffle(sendSet.begin(), sendSet.end(), prng); - - - std::string name("psi"); + u64 setSize = 1 << 5, psiSecParam = 40, bitSize = 128, numThreads = 1; + PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); - BtIOService ios(0); - BtEndpoint ep0(ios, "localhost", 1212, true, name); - BtEndpoint ep1(ios, "localhost", 1212, false, name); + std::vector dummy(nParties); + std::vector revDummy(nParties); - std::vector sendChls(numThreads), recvChls(numThreads); - for (u64 i = 0; i < numThreads; ++i) - { - sendChls[i] = &ep1.addChannel("chl" + std::to_string(i), "chl" + std::to_string(i)); - recvChls[i] = &ep0.addChannel("chl" + std::to_string(i), "chl" + std::to_string(i)); - } + std::string name("psi"); + BtIOService ios(0); - KkrtNcoOtReceiver otRecv; - KkrtNcoOtSender otSend; + int btCount = nParties; + std::vector ep(nParties); - OPPRFSender send; - OPPRFReceiver recv; - std::thread thrd([&]() { + for (u64 i = 0; i < nParties; ++i) + { + dummy[i] = myIdx * 10 + i; + if (i < myIdx) + { + u32 port = i * 10 + myIdx;//get the same port; i=1 & pIdx=2 =>port=102 + ep[i].start(ios, "localhost", port, false, name); //channel bwt i and pIdx, where i is sender + } + else if (i > myIdx) + { + u32 port = myIdx * 10 + i;//get the same port; i=2 & pIdx=1 =>port=102 + ep[i].start(ios, "localhost", port, true, name); //channel bwt i and pIdx, where i is receiver + } + } - send.init(setSize, psiSecParam, bitSize, sendChls, otSend, prng.get()); - // send.sendInput(sendSet, sendChls); - }); + std::vector> chls(nParties); - recv.init(setSize, psiSecParam, bitSize, recvChls, otRecv, ZeroBlock); - // recv.sendInput(recvSet, recvChls); + for (u64 i = 0; i < nParties; ++i) + { + if (i != myIdx) { + chls[i].resize(numThreads); + for (u64 j = 0; j < numThreads; ++j) + { + //chls[i][j] = &ep[i].addChannel("chl" + std::to_string(j), "chl" + std::to_string(j)); + chls[i][j] = &ep[i].addChannel(name, name); + } + } + } - /* if (recv.mIntersection.size() != setSize) - throw UnitTestFail();*/ - thrd.join(); + std::mutex printMtx1, printMtx2; + std::vector pThrds(nParties); + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + { + pThrds[pIdx] = std::thread([&, pIdx]() { + if (pIdx < myIdx) { - for (u64 i = 0; i < numThreads; ++i) - { - sendChls[i]->close(); - recvChls[i]->close(); - } - ep0.stop(); - ep1.stop(); - ios.stop(); + chls[pIdx][0]->asyncSend(&dummy[pIdx], 1); + //std::lock_guard lock(printMtx1); + // std::cout << "s: " << myIdx << " -> " << pIdx << " : " << static_cast(dummy[pIdx]) << std::endl; -} + } + else if (pIdx > myIdx) { -void OPPRF_SingltonSet_Test_Impl() -{ - setThreadName("Sender"); - u64 setSize = 128, psiSecParam = 40, bitSize = 128; + chls[pIdx][0]->recv(&revDummy[pIdx], 1); + std::lock_guard lock(printMtx2); + std::cout << "r: " << myIdx << " <- " << pIdx << " : " << static_cast(revDummy[pIdx]) << std::endl; - PRNG prng(_mm_set_epi32(4253465, 34354565, 234435, 23987045)); + } + }); + } - std::vector sendSet(setSize), recvSet(setSize); - for (u64 i = 0; i < setSize; ++i) - { - sendSet[i] = prng.get(); - recvSet[i] = prng.get(); - } - sendSet[0] = recvSet[0]; + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + { + // if(pIdx!=myIdx) + pThrds[pIdx].join(); + } - std::string name("psi"); - BtIOService ios(0); - BtEndpoint ep0(ios, "localhost", 1212, true, name); - BtEndpoint ep1(ios, "localhost", 1212, false, name); - Channel& recvChl = ep1.addChannel(name, name); - Channel& sendChl = ep0.addChannel(name, name); - KkrtNcoOtReceiver otRecv; - KkrtNcoOtSender otSend; + for (u64 i = 0; i < nParties; ++i) + { + if (i != myIdx) + { + for (u64 j = 0; j < numThreads; ++j) + { + chls[i][j]->close(); + } + } + } - OPPRFSender send; - OPPRFReceiver recv; - std::thread thrd([&]() { + for (u64 i = 0; i < nParties; ++i) + { + if (i != myIdx) + ep[i].stop(); + } - send.init(setSize, psiSecParam, bitSize, sendChl, otSend, prng.get()); - //send.sendInput(sendSet, sendChl); - }); + ios.stop(); +} - recv.init(setSize, psiSecParam, bitSize, recvChl, otRecv, ZeroBlock); - // recv.sendInput(recvSet, recvChl); +void OPPRFn_EmptrySet_Test_Impl() +{ + u64 setSize = 1 << 5, psiSecParam = 40, bitSize = 128; + PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); + mSet.resize(setSize); + for (u64 i = 0; i < setSize; ++i) + { + mSet[i] = prng.get(); + } + std::vector pThrds(nParties); + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + { + pThrds[pIdx] = std::thread([&, pIdx]() { + // Channel_party_test(pIdx); + party(pIdx, setSize, mSet); + }); + } + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + pThrds[pIdx].join(); - thrd.join(); - /*if (recv.mIntersection.size() != 1 || - recv.mIntersection[0] != 0) - throw UnitTestFail();*/ +} +void OPPRF3_EmptrySet_Test_Impl() +{ + nParties = 3; + u64 setSize = 1 << 5, psiSecParam = 40, bitSize = 128; - //std::cout << gTimer << std::endl; + PRNG prng(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); + mSet.resize(setSize); + for (u64 i = 0; i < setSize; ++i) + { + mSet[i] = prng.get(); + } + std::vector pThrds(nParties); + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + { + pThrds[pIdx] = std::thread([&, pIdx]() { + // Channel_party_test(pIdx); + party3(pIdx, setSize, mSet); + }); + } + for (u64 pIdx = 0; pIdx < pThrds.size(); ++pIdx) + pThrds[pIdx].join(); - sendChl.close(); - recvChl.close(); - ep0.stop(); - ep1.stop(); - ios.stop(); } -#endif \ No newline at end of file diff --git a/libPSI_Tests/OPPRF_Tests.h b/libPSI_Tests/OPPRF_Tests.h index f2dc6ddc55..a3cfd6194e 100644 --- a/libPSI_Tests/OPPRF_Tests.h +++ b/libPSI_Tests/OPPRF_Tests.h @@ -13,6 +13,7 @@ void findMaxBinSize_Test_Impl(); void findScaleNumBins_Test_Impl(); void Bit_Position_Random_Test_Impl(); void testShareValue(); +void OPPRFnt_EmptrySet_Test_Impl(); //void OPPRF_FullSet_Test_Impl (); //void OPPRF_SingltonSet_Test_Impl(); diff --git a/libPSI_TestsVS/OPPRF_TestsVS.cpp b/libPSI_TestsVS/OPPRF_TestsVS.cpp index ca8fc12cf9..c39b30baa0 100644 --- a/libPSI_TestsVS/OPPRF_TestsVS.cpp +++ b/libPSI_TestsVS/OPPRF_TestsVS.cpp @@ -98,6 +98,15 @@ namespace WeGarbleTests InitDebugPrinting(); testShareValue(); } + + TEST_METHOD(OPPRFnt_EmptrySet_Test) + { + InitDebugPrinting(); + OPPRFnt_EmptrySet_Test_Impl(); + } + + + /*TEST_METHOD(Channel_Test) { InitDebugPrinting();