From af527f906e9e0cbeb0282ae0e77e53909775e23b Mon Sep 17 00:00:00 2001 From: Oleh Lomaka Date: Tue, 10 Dec 2024 06:54:42 -0500 Subject: [PATCH] Improve zkey loading (#27) Co-authored-by: nixw <> --- depends/ffiasm | 2 +- src/binfile_utils.cpp | 131 +++++------ src/binfile_utils.hpp | 8 +- src/fileloader.cpp | 28 ++- src/fileloader.hpp | 3 + src/groth16.cpp | 15 +- src/main_prover.cpp | 133 +++++------ src/prover.cpp | 488 +++++++++++++++++++++++++++++++---------- src/prover.h | 113 ++++++++-- src/test_public_size.c | 14 +- src/wtns_utils.cpp | 9 +- src/zkey_utils.cpp | 9 +- 12 files changed, 630 insertions(+), 323 deletions(-) diff --git a/depends/ffiasm b/depends/ffiasm index b9dce0a..fe3772e 160000 --- a/depends/ffiasm +++ b/depends/ffiasm @@ -1 +1 @@ -Subproject commit b9dce0a78a0ee172c23cee2465904383ea43f0ec +Subproject commit fe3772e8e62f2235d308d445d09379a0fca8f5a9 diff --git a/src/binfile_utils.cpp b/src/binfile_utils.cpp index 821a302..53cc5d9 100644 --- a/src/binfile_utils.cpp +++ b/src/binfile_utils.cpp @@ -12,92 +12,49 @@ namespace BinFileUtils { -BinFile::BinFile(const std::string& fileName, const std::string& _type, uint32_t maxVersion) { +BinFile::BinFile(const std::string& fileName, const std::string& _type, uint32_t maxVersion) + : fileLoader(fileName) +{ + addr = fileLoader.dataBuffer(); + size = fileLoader.dataSize(); - is_fd = true; - struct stat sb; - - fd = open(fileName.c_str(), O_RDONLY); - if (fd == -1) - throw std::system_error(errno, std::generic_category(), "open"); - - - if (fstat(fd, &sb) == -1) { /* To obtain file size */ - close(fd); - throw std::system_error(errno, std::generic_category(), "fstat"); - } - - size = sb.st_size; - - addr = mmap(nullptr, sb.st_size, PROT_READ, MAP_PRIVATE, fd, 0); - if (addr == MAP_FAILED) { - close(fd); - throw std::system_error(errno, std::generic_category(), "mmap failed"); - } - madvise(addr, size, MADV_SEQUENTIAL); - - type.assign((const char *)addr, 4); - pos = 4; - - if (type != _type) { - munmap(addr, size); - close(fd); - throw std::invalid_argument("Invalid file type. It should be " + _type + " and it is " + type + " filename: " + fileName); - } - - version = readU32LE(); - if (version > maxVersion) { - munmap(addr, size); - close(fd); - throw std::invalid_argument("Invalid version. It should be <=" + std::to_string(maxVersion) + " and it is " + std::to_string(version)); - } - - u_int32_t nSections = readU32LE(); - - - for (u_int32_t i=0; i())); - } + readFileData(_type, maxVersion); +} - sections[sType].push_back(Section( (void *)((u_int64_t)addr + pos), sSize)); +BinFile::BinFile(const void *fileData, size_t fileSize, std::string _type, uint32_t maxVersion) { - pos += sSize; - } + addr = fileData; + size = fileSize; - pos = 0; - readingSection = nullptr; + readFileData(_type, maxVersion); } +void BinFile::readFileData(std::string _type, uint32_t maxVersion) { -BinFile::BinFile(const void *fileData, size_t fileSize, std::string _type, uint32_t maxVersion) { + const u_int64_t headerSize = 12; + const u_int64_t minSectionSize = 12; - is_fd = false; - fd = -1; - - size = fileSize; - addr = malloc(size); - memcpy(addr, fileData, size); + if (size < headerSize) { + throw std::range_error("File is too short."); + } type.assign((const char *)addr, 4); pos = 4; if (type != _type) { - free(addr); throw std::invalid_argument("Invalid file type. It should be " + _type + " and it is " + type); } version = readU32LE(); if (version > maxVersion) { - free(addr); throw std::invalid_argument("Invalid version. It should be <=" + std::to_string(maxVersion) + " and it is " + std::to_string(version)); } u_int32_t nSections = readU32LE(); + if (size < headerSize + nSections * minSectionSize) { + throw std::range_error("File is too short to contain " + std::to_string(nSections) + " sections."); + } for (u_int32_t i=0; i size) { + throw std::range_error("Section #" + std::to_string(i) + " is invalid." + ". It ends at pos " + std::to_string(pos) + + " but should end before " + std::to_string(size) + "."); + } } pos = 0; - readingSection = NULL; + readingSection = nullptr; } -BinFile::~BinFile() { - if (is_fd) { - munmap(addr, size); - close(fd); - } else { - free(addr); - } -} void BinFile::startReadSection(u_int32_t sectionId, u_int32_t sectionPos) { @@ -135,7 +90,7 @@ void BinFile::startReadSection(u_int32_t sectionId, u_int32_t sectionPos) { throw std::range_error("Section pos too big. There are " + std::to_string(sections[sectionId].size()) + " and it's trying to access section: " + std::to_string(sectionPos)); } - if (readingSection != NULL) { + if (readingSection != nullptr) { throw std::range_error("Already reading a section"); } @@ -150,7 +105,7 @@ void BinFile::endReadSection(bool check) { throw std::range_error("Invalid section size"); } } - readingSection = NULL; + readingSection = nullptr; } void *BinFile::getSectionData(u_int32_t sectionId, u_int32_t sectionPos) { @@ -169,31 +124,49 @@ void *BinFile::getSectionData(u_int32_t sectionId, u_int32_t sectionPos) { u_int64_t BinFile::getSectionSize(u_int32_t sectionId, u_int32_t sectionPos) { if (sections.find(sectionId) == sections.end()) { - throw new std::range_error("Section does not exist: " + std::to_string(sectionId)); + throw std::range_error("Section does not exist: " + std::to_string(sectionId)); } if (sectionPos >= sections[sectionId].size()) { - throw new std::range_error("Section pos too big. There are " + std::to_string(sections[sectionId].size()) + " and it's trying to access section: " + std::to_string(sectionPos)); + throw std::range_error("Section pos too big. There are " + std::to_string(sections[sectionId].size()) + " and it's trying to access section: " + std::to_string(sectionPos)); } return sections[sectionId][sectionPos].size; } u_int32_t BinFile::readU32LE() { + const u_int64_t new_pos = pos + 4; + + if (new_pos > size) { + throw std::range_error("File pos is too big. There are " + std::to_string(size) + " bytes and it's trying to access byte " + std::to_string(new_pos)); + } + u_int32_t res = *((u_int32_t *)((u_int64_t)addr + pos)); - pos += 4; + pos = new_pos; return res; } u_int64_t BinFile::readU64LE() { + const u_int64_t new_pos = pos + 8; + + if (new_pos > size) { + throw std::range_error("File pos is too big. There are " + std::to_string(size) + " bytes and it's trying to access byte " + std::to_string(new_pos)); + } + u_int64_t res = *((u_int64_t *)((u_int64_t)addr + pos)); - pos += 8; + pos = new_pos; return res; } void *BinFile::read(u_int64_t len) { + const u_int64_t new_pos = pos + len; + + if (new_pos > size) { + throw std::range_error("File pos is too big. There are " + std::to_string(size) + " bytes and it's trying to access byte " + std::to_string(new_pos)); + } + void *res = (void *)((u_int64_t)addr + pos); - pos += len; + pos = new_pos; return res; } diff --git a/src/binfile_utils.hpp b/src/binfile_utils.hpp index 69c2201..e2a9d23 100644 --- a/src/binfile_utils.hpp +++ b/src/binfile_utils.hpp @@ -4,15 +4,15 @@ #include #include #include +#include "fileloader.hpp" namespace BinFileUtils { class BinFile { - bool is_fd; - int fd; + FileLoader fileLoader; - void *addr; + const void *addr; u_int64_t size; u_int64_t pos; @@ -32,6 +32,7 @@ namespace BinFileUtils { Section *readingSection; + void readFileData(std::string _type, uint32_t maxVersion); public: @@ -39,7 +40,6 @@ namespace BinFileUtils { BinFile(const std::string& fileName, const std::string& _type, uint32_t maxVersion); BinFile(const BinFile&) = delete; BinFile& operator=(const BinFile&) = delete; - ~BinFile(); void startReadSection(u_int32_t sectionId, u_int32_t setionPos = 0); void endReadSection(bool check = true); diff --git a/src/fileloader.cpp b/src/fileloader.cpp index efc1b6c..4a64928 100644 --- a/src/fileloader.cpp +++ b/src/fileloader.cpp @@ -9,8 +9,23 @@ namespace BinFileUtils { +FileLoader::FileLoader() + : fd(-1) +{ +} + FileLoader::FileLoader(const std::string& fileName) + : fd(-1) { + load(fileName); +} + +void FileLoader::load(const std::string& fileName) +{ + if (fd != -1) { + throw std::invalid_argument("file already loaded"); + } + struct stat sb; fd = open(fileName.c_str(), O_RDONLY); @@ -26,12 +41,21 @@ FileLoader::FileLoader(const std::string& fileName) size = sb.st_size; addr = mmap(nullptr, size, PROT_READ, MAP_PRIVATE, fd, 0); + + if (addr == MAP_FAILED) { + close(fd); + throw std::system_error(errno, std::generic_category(), "mmap failed"); + } + + madvise(addr, size, MADV_SEQUENTIAL); } FileLoader::~FileLoader() { - munmap(addr, size); - close(fd); + if (fd != -1) { + munmap(addr, size); + close(fd); + } } } // Namespace diff --git a/src/fileloader.hpp b/src/fileloader.hpp index 831d61b..ad4756e 100644 --- a/src/fileloader.hpp +++ b/src/fileloader.hpp @@ -9,9 +9,12 @@ namespace BinFileUtils { class FileLoader { public: + FileLoader(); FileLoader(const std::string& fileName); ~FileLoader(); + void load(const std::string& fileName); + void* dataBuffer() { return addr; } size_t dataSize() const { return size; } diff --git a/src/groth16.cpp b/src/groth16.cpp index 761933e..a4acc14 100644 --- a/src/groth16.cpp +++ b/src/groth16.cpp @@ -1,6 +1,7 @@ #include "random_generator.hpp" #include "logging.hpp" #include "misc.hpp" +#include #include #include @@ -84,7 +85,7 @@ std::unique_ptr> Prover::prove(typename Engine::FrElement auto b = new typename Engine::FrElement[domainSize]; auto c = new typename Engine::FrElement[domainSize]; - threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) { + threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) { for (u_int32_t i=begin; i> Prover::prove(typename Engine::FrElement #define NLOCKS 1024 std::vector locks(NLOCKS); - threadPool.parallelFor(0, nCoefs, [&] (int begin, int end, int numThread) { + threadPool.parallelFor(0, nCoefs, [&] (int64_t begin, int64_t end, uint64_t idThread) { for (u_int64_t i=begin; i> Prover::prove(typename Engine::FrElement } }); LOG_TRACE("Calculating c"); - threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) { + threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) { for (u_int64_t i=begin; i> Prover::prove(typename Engine::FrElement LOG_DEBUG(E.fr.toString(a[1]).c_str()); LOG_TRACE("Start Shift A"); - threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) { + threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) { for (u_int64_t i=begin; iroot(domainPower+1, i)); } @@ -157,7 +158,7 @@ std::unique_ptr> Prover::prove(typename Engine::FrElement LOG_DEBUG(E.fr.toString(b[0]).c_str()); LOG_DEBUG(E.fr.toString(b[1]).c_str()); LOG_TRACE("Start Shift B"); - threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) { + threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) { for (u_int64_t i=begin; iroot(domainPower+1, i)); } @@ -177,7 +178,7 @@ std::unique_ptr> Prover::prove(typename Engine::FrElement LOG_DEBUG(E.fr.toString(c[0]).c_str()); LOG_DEBUG(E.fr.toString(c[1]).c_str()); LOG_TRACE("Start Shift C"); - threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) { + threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) { for (u_int64_t i=begin; iroot(domainPower+1, i)); } @@ -192,7 +193,7 @@ std::unique_ptr> Prover::prove(typename Engine::FrElement LOG_DEBUG(E.fr.toString(c[1]).c_str()); LOG_TRACE("Start ABC"); - threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) { + threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) { for (u_int64_t i=begin; i #include -#include -#include +#include +#include #include -#include - -#include -#include "binfile_utils.hpp" -#include "zkey_utils.hpp" -#include "wtns_utils.hpp" -#include "groth16.hpp" - -using json = nlohmann::json; - -#define handle_error(msg) \ - do { perror(msg); exit(EXIT_FAILURE); } while (0) +#include "prover.h" +#include "fileloader.hpp" int main(int argc, char **argv) { if (argc != 5) { - std::cerr << "Invalid number of parameters:\n"; - std::cerr << "Usage: prover \n"; + std::cerr << "Invalid number of parameters" << std::endl; + std::cerr << "Usage: prover " << std::endl; return EXIT_FAILURE; } - mpz_t altBbn128r; - - mpz_init(altBbn128r); - mpz_set_str(altBbn128r, "21888242871839275222246405745257275088548364400416034343698204186575808495617", 10); - try { - std::string zkeyFilename = argv[1]; - std::string wtnsFilename = argv[2]; - std::string proofFilename = argv[3]; - std::string publicFilename = argv[4]; - - auto zkey = BinFileUtils::openExisting(zkeyFilename, "zkey", 1); - auto zkeyHeader = ZKeyUtils::loadHeader(zkey.get()); - - std::string proofStr; - if (mpz_cmp(zkeyHeader->rPrime, altBbn128r) != 0) { - throw std::invalid_argument( "zkey curve not supported" ); + const std::string zkeyFilename = argv[1]; + const std::string wtnsFilename = argv[2]; + const std::string proofFilename = argv[3]; + const std::string publicFilename = argv[4]; + + BinFileUtils::FileLoader zkeyFile(zkeyFilename); + BinFileUtils::FileLoader wtnsFile(wtnsFilename); + std::vector publicBuffer; + std::vector proofBuffer; + unsigned long long publicSize = 0; + unsigned long long proofSize = 0; + char errorMsg[1024]; + + int error = groth16_public_size_for_zkey_buf( + zkeyFile.dataBuffer(), + zkeyFile.dataSize(), + &publicSize, + errorMsg, + sizeof(errorMsg)); + + if (error != PROVER_OK) { + throw std::runtime_error(errorMsg); } - auto wtns = BinFileUtils::openExisting(wtnsFilename, "wtns", 2); - auto wtnsHeader = WtnsUtils::loadHeader(wtns.get()); - - if (mpz_cmp(wtnsHeader->prime, altBbn128r) != 0) { - throw std::invalid_argument( "different wtns curve" ); + groth16_proof_size(&proofSize); + + publicBuffer.resize(publicSize); + proofBuffer.resize(proofSize); + + error = groth16_prover( + zkeyFile.dataBuffer(), + zkeyFile.dataSize(), + wtnsFile.dataBuffer(), + wtnsFile.dataSize(), + proofBuffer.data(), + &proofSize, + publicBuffer.data(), + &publicSize, + errorMsg, + sizeof(errorMsg)); + + if (error != PROVER_OK) { + throw std::runtime_error(errorMsg); } - auto prover = Groth16::makeProver( - zkeyHeader->nVars, - zkeyHeader->nPublic, - zkeyHeader->domainSize, - zkeyHeader->nCoefs, - zkeyHeader->vk_alpha1, - zkeyHeader->vk_beta1, - zkeyHeader->vk_beta2, - zkeyHeader->vk_delta1, - zkeyHeader->vk_delta2, - zkey->getSectionData(4), // Coefs - zkey->getSectionData(5), // pointsA - zkey->getSectionData(6), // pointsB1 - zkey->getSectionData(7), // pointsB2 - zkey->getSectionData(8), // pointsC - zkey->getSectionData(9) // pointsH1 - ); - AltBn128::FrElement *wtnsData = (AltBn128::FrElement *)wtns->getSectionData(2); - auto proof = prover->prove(wtnsData); - - std::ofstream proofFile; - proofFile.open (proofFilename); - proofFile << proof->toJson(); - proofFile.close(); + std::ofstream proofFile(proofFilename); + proofFile.write(proofBuffer.data(), proofSize); - std::ofstream publicFile; - publicFile.open (publicFilename); - - json jsonPublic; - AltBn128::FrElement aux; - for (int i=1; i<=zkeyHeader->nPublic; i++) { - AltBn128::Fr.toMontgomery(aux, wtnsData[i]); - jsonPublic.push_back(AltBn128::Fr.toString(aux)); - } + std::ofstream publicFile(publicFilename); + publicFile.write(publicBuffer.data(), publicSize); - publicFile << jsonPublic; - publicFile.close(); - - } catch (std::exception* e) { - mpz_clear(altBbn128r); - std::cerr << e->what() << '\n'; - return EXIT_FAILURE; } catch (std::exception& e) { - mpz_clear(altBbn128r); - std::cerr << e.what() << '\n'; + std::cerr << "Error: " << e.what() << std::endl; return EXIT_FAILURE; + } - mpz_clear(altBbn128r); exit(EXIT_SUCCESS); } diff --git a/src/prover.cpp b/src/prover.cpp index 59ba672..6f79147 100644 --- a/src/prover.cpp +++ b/src/prover.cpp @@ -1,49 +1,84 @@ #include -#include #include #include #include #include #include - #include "prover.h" +#include "groth16.hpp" #include "zkey_utils.hpp" #include "wtns_utils.hpp" -#include "groth16.hpp" #include "binfile_utils.hpp" #include "fileloader.hpp" using json = nlohmann::json; -static size_t ProofBufferMinSize() + +class ShortBufferException : public std::invalid_argument +{ +public: + explicit ShortBufferException(const std::string &msg) + : std::invalid_argument(msg) {} +}; + +class InvalidWitnessLengthException : public std::invalid_argument +{ +public: + explicit InvalidWitnessLengthException(const std::string &msg) + : std::invalid_argument(msg) {} +}; + +static void +CopyError( + char *error_msg, + unsigned long long error_msg_maxsize, + const std::exception &e) +{ + if (error_msg) { + strncpy(error_msg, e.what(), error_msg_maxsize); + } +} + +static void +CopyError( + char *error_msg, + unsigned long long error_msg_maxsize, + const char *str) +{ + if (error_msg) { + strncpy(error_msg, str, error_msg_maxsize); + } +} + +static unsigned long long +ProofBufferMinSize() { return 810; } -static size_t PublicBufferMinSize(size_t count) +static unsigned long long +PublicBufferMinSize(unsigned long long count) { return count * 82 + 4; } -static void VerifyPrimes(mpz_srcptr zkey_prime, mpz_srcptr wtns_prime) +static bool +PrimeIsValid(mpz_srcptr prime) { mpz_t altBbn128r; mpz_init(altBbn128r); mpz_set_str(altBbn128r, "21888242871839275222246405745257275088548364400416034343698204186575808495617", 10); - if (mpz_cmp(zkey_prime, altBbn128r) != 0) { - throw std::invalid_argument( "zkey curve not supported" ); - } - - if (mpz_cmp(wtns_prime, altBbn128r) != 0) { - throw std::invalid_argument( "different wtns curve" ); - } + const bool is_valid = (mpz_cmp(prime, altBbn128r) == 0); mpz_clear(altBbn128r); + + return is_valid; } -std::string BuildPublicString(AltBn128::FrElement *wtnsData, size_t nPublic) +static std::string +BuildPublicString(AltBn128::FrElement *wtnsData, uint32_t nPublic) { json jsonPublic; AltBn128::FrElement aux; @@ -55,170 +90,391 @@ std::string BuildPublicString(AltBn128::FrElement *wtnsData, size_t nPublic) return jsonPublic.dump(); } +static void +CheckAndUpdateBufferSizes( + unsigned long long proofCalcSize, + unsigned long long *proofSize, + unsigned long long publicCalcSize, + unsigned long long *publicSize, + const std::string &type) +{ + if (*proofSize < proofCalcSize || *publicSize < publicCalcSize) { + + *proofSize = proofCalcSize; + *publicSize = publicCalcSize; + + if (*proofSize < proofCalcSize) { + throw ShortBufferException("Proof buffer is too short. " + type + " size: " + + std::to_string(proofCalcSize) + + ", actual size: " + + std::to_string(*proofSize)); + } else { + throw ShortBufferException("Public buffer is too short. " + type + " size: " + + std::to_string(proofCalcSize) + + ", actual size: " + + std::to_string(*proofSize)); + } + } +} + +class Groth16Prover +{ + BinFileUtils::BinFile zkey; + std::unique_ptr zkeyHeader; + std::unique_ptr> prover; + +public: + Groth16Prover(const void *zkey_buffer, + unsigned long long zkey_size) + + : zkey(zkey_buffer, zkey_size, "zkey", 1), + zkeyHeader(ZKeyUtils::loadHeader(&zkey)) + { + if (!PrimeIsValid(zkeyHeader->rPrime)) { + throw std::invalid_argument("zkey curve not supported"); + } + + prover = Groth16::makeProver( + zkeyHeader->nVars, + zkeyHeader->nPublic, + zkeyHeader->domainSize, + zkeyHeader->nCoefs, + zkeyHeader->vk_alpha1, + zkeyHeader->vk_beta1, + zkeyHeader->vk_beta2, + zkeyHeader->vk_delta1, + zkeyHeader->vk_delta2, + zkey.getSectionData(4), // Coefs + zkey.getSectionData(5), // pointsA + zkey.getSectionData(6), // pointsB1 + zkey.getSectionData(7), // pointsB2 + zkey.getSectionData(8), // pointsC + zkey.getSectionData(9) // pointsH1 + ); + } + + void prove(const void *wtns_buffer, + unsigned long long wtns_size, + std::string &stringProof, + std::string &stringPublic) + { + BinFileUtils::BinFile wtns(wtns_buffer, wtns_size, "wtns", 2); + auto wtnsHeader = WtnsUtils::loadHeader(&wtns); + + if (zkeyHeader->nVars != wtnsHeader->nVars) { + throw InvalidWitnessLengthException("Invalid witness length. Circuit: " + + std::to_string(zkeyHeader->nVars) + + ", witness: " + + std::to_string(wtnsHeader->nVars)); + } + + if (!PrimeIsValid(wtnsHeader->prime)) { + throw std::invalid_argument("different wtns curve"); + } + + AltBn128::FrElement *wtnsData = (AltBn128::FrElement *)wtns.getSectionData(2); + + auto proof = prover->prove(wtnsData); + + stringProof = proof->toJson().dump(); + stringPublic = BuildPublicString(wtnsData, zkeyHeader->nPublic); + } + + unsigned long long proofBufferMinSize() const + { + return ProofBufferMinSize(); + } + + unsigned long long publicBufferMinSize() const + { + return PublicBufferMinSize(zkeyHeader->nPublic); + } +}; + int -groth16_public_size_for_zkey_buf(const void *zkey_buffer, unsigned long zkey_size, - size_t *public_size, - char *error_msg, unsigned long error_msg_maxsize) { +groth16_public_size_for_zkey_buf( + const void *zkey_buffer, + unsigned long long zkey_size, + unsigned long long *public_size, + char *error_msg, + unsigned long long error_msg_maxsize) +{ try { BinFileUtils::BinFile zkey(zkey_buffer, zkey_size, "zkey", 1); auto zkeyHeader = ZKeyUtils::loadHeader(&zkey); + *public_size = PublicBufferMinSize(zkeyHeader->nPublic); - return PROVER_OK; + } catch (std::exception& e) { - if (error_msg) { - strncpy(error_msg, e.what(), error_msg_maxsize); - } + CopyError(error_msg, error_msg_maxsize, e); return PROVER_ERROR; + } catch (...) { - if (error_msg) { - strncpy(error_msg, "unknown error", error_msg_maxsize); - } + CopyError(error_msg, error_msg_maxsize, "unknown error"); return PROVER_ERROR; } + + return PROVER_OK; } int -groth16_public_size_for_zkey_file(const char *zkey_fname, - unsigned long *public_size, - char *error_msg, unsigned long error_msg_maxsize) { +groth16_public_size_for_zkey_file( + const char *zkey_fname, + unsigned long long *public_size, + char *error_msg, + unsigned long long error_msg_maxsize) +{ try { auto zkey = BinFileUtils::openExisting(zkey_fname, "zkey", 1); auto zkeyHeader = ZKeyUtils::loadHeader(zkey.get()); + *public_size = PublicBufferMinSize(zkeyHeader->nPublic); - return PROVER_OK; + } catch (std::exception& e) { - if (error_msg) { - strncpy(error_msg, e.what(), error_msg_maxsize); - } + CopyError(error_msg, error_msg_maxsize, e); return PROVER_ERROR; + } catch (...) { - if (error_msg) { - strncpy(error_msg, "unknown error", error_msg_maxsize); + CopyError(error_msg, error_msg_maxsize, "unknown error"); + return PROVER_ERROR; + } + + return PROVER_OK; +} + +void +groth16_proof_size( + unsigned long long *proof_size) +{ + *proof_size = ProofBufferMinSize(); +} + +int +groth16_prover_create( + void **prover_object, + const void *zkey_buffer, + unsigned long long zkey_size, + char *error_msg, + unsigned long long error_msg_maxsize) +{ + try { + if (prover_object == NULL) { + throw std::invalid_argument("Null prover object"); } + + if (zkey_buffer == NULL) { + throw std::invalid_argument("Null zkey buffer"); + } + + Groth16Prover *prover = new Groth16Prover(zkey_buffer, zkey_size); + + *prover_object = prover; + + } catch (std::exception& e) { + CopyError(error_msg, error_msg_maxsize, e); + return PROVER_ERROR; + + } catch (std::exception *e) { + CopyError(error_msg, error_msg_maxsize, *e); + delete e; + return PROVER_ERROR; + + } catch (...) { + CopyError(error_msg, error_msg_maxsize, "unknown error"); return PROVER_ERROR; } + + return PROVER_OK; } int -groth16_prover(const void *zkey_buffer, unsigned long zkey_size, - const void *wtns_buffer, unsigned long wtns_size, - char *proof_buffer, unsigned long *proof_size, - char *public_buffer, unsigned long *public_size, - char *error_msg, unsigned long error_msg_maxsize) +groth16_prover_create_zkey_file( + void **prover_object, + const char *zkey_file_path, + char *error_msg, + unsigned long long error_msg_maxsize) { + BinFileUtils::FileLoader fileLoader; + try { - BinFileUtils::BinFile zkey(zkey_buffer, zkey_size, "zkey", 1); - auto zkeyHeader = ZKeyUtils::loadHeader(&zkey); + fileLoader.load(zkey_file_path); - BinFileUtils::BinFile wtns(wtns_buffer, wtns_size, "wtns", 2); - auto wtnsHeader = WtnsUtils::loadHeader(&wtns); + } catch (std::exception& e) { + CopyError(error_msg, error_msg_maxsize, e); + return PROVER_ERROR; + } - if (zkeyHeader->nVars != wtnsHeader->nVars) { - snprintf(error_msg, error_msg_maxsize, - "Invalid witness length. Circuit: %u, witness: %u", - zkeyHeader->nVars, wtnsHeader->nVars); - return PROVER_INVALID_WITNESS_LENGTH; + return groth16_prover_create( + prover_object, + fileLoader.dataBuffer(), + fileLoader.dataSize(), + error_msg, + error_msg_maxsize); +} + +int +groth16_prover_prove( + void *prover_object, + const void *wtns_buffer, + unsigned long long wtns_size, + char *proof_buffer, + unsigned long long *proof_size, + char *public_buffer, + unsigned long long *public_size, + char *error_msg, + unsigned long long error_msg_maxsize) +{ + try { + if (prover_object == NULL) { + throw std::invalid_argument("Null prover object"); } - size_t proofMinSize = ProofBufferMinSize(); - size_t publicMinSize = PublicBufferMinSize(zkeyHeader->nPublic); + if (wtns_buffer == NULL) { + throw std::invalid_argument("Null witness buffer"); + } - if (*proof_size < proofMinSize || *public_size < publicMinSize) { + if (proof_buffer == NULL) { + throw std::invalid_argument("Null proof buffer"); + } - if (*proof_size < proofMinSize) { - snprintf(error_msg, error_msg_maxsize, - "Proof buffer is too short. Minimum size: %lu, actual size: %lu", - proofMinSize, *proof_size); - } else { - snprintf(error_msg, error_msg_maxsize, - "Public buffer is too short. Minimum size: %lu, actual size: %lu", - publicMinSize, *public_size); - } + if (proof_size == NULL) { + throw std::invalid_argument("Null proof size"); + } - *proof_size = proofMinSize; - *public_size = publicMinSize; + if (public_buffer == NULL) { + throw std::invalid_argument("Null public buffer"); + } - return PROVER_ERROR_SHORT_BUFFER; + if (public_size == NULL) { + throw std::invalid_argument("Null public size"); } - VerifyPrimes(zkeyHeader->rPrime, wtnsHeader->prime); + Groth16Prover *prover = static_cast(prover_object); - auto prover = Groth16::makeProver( - zkeyHeader->nVars, - zkeyHeader->nPublic, - zkeyHeader->domainSize, - zkeyHeader->nCoefs, - zkeyHeader->vk_alpha1, - zkeyHeader->vk_beta1, - zkeyHeader->vk_beta2, - zkeyHeader->vk_delta1, - zkeyHeader->vk_delta2, - zkey.getSectionData(4), // Coefs - zkey.getSectionData(5), // pointsA - zkey.getSectionData(6), // pointsB1 - zkey.getSectionData(7), // pointsB2 - zkey.getSectionData(8), // pointsC - zkey.getSectionData(9) // pointsH1 - ); - AltBn128::FrElement *wtnsData = (AltBn128::FrElement *)wtns.getSectionData(2); - auto proof = prover->prove(wtnsData); + CheckAndUpdateBufferSizes(prover->proofBufferMinSize(), proof_size, + prover->publicBufferMinSize(), public_size, + "Minimum"); - std::string stringProof = proof->toJson().dump(); - std::string stringPublic = BuildPublicString(wtnsData, zkeyHeader->nPublic); + std::string stringProof; + std::string stringPublic; - size_t stringProofSize = stringProof.length(); - size_t stringPublicSize = stringPublic.length(); + prover->prove(wtns_buffer, wtns_size, stringProof, stringPublic); - if (*proof_size < stringProofSize || *public_size < stringPublicSize) { + CheckAndUpdateBufferSizes(stringProof.length(), proof_size, + stringPublic.length(), public_size, + "Required"); - *proof_size = stringProofSize; - *public_size = stringPublicSize; + std::strncpy(proof_buffer, stringProof.c_str(), *proof_size); + std::strncpy(public_buffer, stringPublic.c_str(), *public_size); - return PROVER_ERROR_SHORT_BUFFER; - } + } catch(InvalidWitnessLengthException& e) { + CopyError(error_msg, error_msg_maxsize, e); + return PROVER_INVALID_WITNESS_LENGTH; - std::strncpy(proof_buffer, stringProof.data(), *proof_size); - std::strncpy(public_buffer, stringPublic.data(), *public_size); + } catch(ShortBufferException& e) { + CopyError(error_msg, error_msg_maxsize, e); + return PROVER_ERROR_SHORT_BUFFER; } catch (std::exception& e) { - - if (error_msg) { - strncpy(error_msg, e.what(), error_msg_maxsize); - } + CopyError(error_msg, error_msg_maxsize, e); return PROVER_ERROR; } catch (std::exception *e) { - - if (error_msg) { - strncpy(error_msg, e->what(), error_msg_maxsize); - } + CopyError(error_msg, error_msg_maxsize, *e); delete e; return PROVER_ERROR; } catch (...) { - if (error_msg) { - strncpy(error_msg, "unknown error", error_msg_maxsize); - } + CopyError(error_msg, error_msg_maxsize, "unknown error"); return PROVER_ERROR; } return PROVER_OK; } +void +groth16_prover_destroy(void *prover_object) +{ + if (prover_object != NULL) { + Groth16Prover *prover = static_cast(prover_object); + + delete prover; + } +} + +int +groth16_prover( + const void *zkey_buffer, + unsigned long long zkey_size, + const void *wtns_buffer, + unsigned long long wtns_size, + char *proof_buffer, + unsigned long long *proof_size, + char *public_buffer, + unsigned long long *public_size, + char *error_msg, + unsigned long long error_msg_maxsize) +{ + void *prover = NULL; + + int error = groth16_prover_create( + &prover, + zkey_buffer, + zkey_size, + error_msg, + error_msg_maxsize); + + if (error != PROVER_OK) { + return error; + } + + error = groth16_prover_prove( + prover, + wtns_buffer, + wtns_size, + proof_buffer, + proof_size, + public_buffer, + public_size, + error_msg, + error_msg_maxsize); + + groth16_prover_destroy(prover); + + return error; +} + int -groth16_prover_zkey_file(const char *zkey_file_path, - const void *wtns_buffer, unsigned long wtns_size, - char *proof_buffer, unsigned long *proof_size, - char *public_buffer, unsigned long *public_size, - char *error_msg, unsigned long error_msg_maxsize) { +groth16_prover_zkey_file( + const char *zkey_file_path, + const void *wtns_buffer, + unsigned long long wtns_size, + char *proof_buffer, + unsigned long long *proof_size, + char *public_buffer, + unsigned long long *public_size, + char *error_msg, + unsigned long long error_msg_maxsize) +{ + BinFileUtils::FileLoader fileLoader; - std::string zkey_filename(zkey_file_path); + try { + fileLoader.load(zkey_file_path); - BinFileUtils::FileLoader fileLoader(zkey_filename); + } catch (std::exception& e) { + CopyError(error_msg, error_msg_maxsize, e); + return PROVER_ERROR; + } - return groth16_prover(fileLoader.dataBuffer(), fileLoader.dataSize(), - wtns_buffer, wtns_size, - proof_buffer, proof_size, - public_buffer, public_size, - error_msg, error_msg_maxsize); + return groth16_prover( + fileLoader.dataBuffer(), + fileLoader.dataSize(), + wtns_buffer, + wtns_size, + proof_buffer, + proof_size, + public_buffer, + public_size, + error_msg, + error_msg_maxsize); } diff --git a/src/prover.h b/src/prover.h index 4257699..8baaa12 100644 --- a/src/prover.h +++ b/src/prover.h @@ -16,9 +16,12 @@ extern "C" { * @returns PROVER_OK in case of success, and the size of public buffer is written to public_size */ int -groth16_public_size_for_zkey_buf(const void *zkey_buffer, unsigned long zkey_size, - size_t *public_size, - char *error_msg, unsigned long error_msg_maxsize); +groth16_public_size_for_zkey_buf( + const void *zkey_buffer, + unsigned long long zkey_size, + unsigned long long *public_size, + char *error_msg, + unsigned long long error_msg_maxsize); /** * groth16_public_size_for_zkey_file calculates minimum buffer size for @@ -30,37 +33,109 @@ groth16_public_size_for_zkey_buf(const void *zkey_buffer, unsigned long zkey_siz * PROVER_ERROR - in case of an error, error_msg contains the error message */ int -groth16_public_size_for_zkey_file(const char *zkey_fname, - unsigned long *public_size, - char *error_msg, unsigned long error_msg_maxsize); +groth16_public_size_for_zkey_file( + const char *zkey_fname, + unsigned long long *public_size, + char *error_msg, + unsigned long long error_msg_maxsize); /** - * groth16_prover + * Returns buffer size to output proof as json string + */ +void +groth16_proof_size( + unsigned long long *proof_size); + +/** + * Initializes 'prover_object' with a pointer to a new prover object. + * @return error code: + * PROVER_OK - in case of success + * PPOVER_ERROR - in case of an error + */ +int +groth16_prover_create( + void **prover_object, + const void *zkey_buffer, + unsigned long long zkey_size, + char *error_msg, + unsigned long long error_msg_maxsize); + +/** + * Initializes 'prover_object' with a pointer to a new prover object. + * @return error code: + * PROVER_OK - in case of success + * PPOVER_ERROR - in case of an error + */ +int +groth16_prover_create_zkey_file( + void **prover_object, + const char *zkey_file_path, + char *error_msg, + unsigned long long error_msg_maxsize); + +/** + * Proves 'wtns_buffer' and saves results to 'proof_buffer' and 'public_buffer'. * @return error code: * PROVER_OK - in case of success * PPOVER_ERROR - in case of an error - * PROVER_ERROR_SHORT_BUFFER - in case of a short buffer error, also updates proof_size and public_size with actual proof and public sizess + * PROVER_ERROR_SHORT_BUFFER - in case of a short buffer error, also updates proof_size and public_size with actual proof and public sizes */ int -groth16_prover(const void *zkey_buffer, unsigned long zkey_size, - const void *wtns_buffer, unsigned long wtns_size, - char *proof_buffer, unsigned long *proof_size, - char *public_buffer, unsigned long *public_size, - char *error_msg, unsigned long error_msg_maxsize); +groth16_prover_prove( + void *prover_object, + const void *wtns_buffer, + unsigned long long wtns_size, + char *proof_buffer, + unsigned long long *proof_size, + char *public_buffer, + unsigned long long *public_size, + char *error_msg, + unsigned long long error_msg_maxsize); + +/** + * Destroys 'prover_object'. + */ +void +groth16_prover_destroy(void *prover_object); /** * groth16_prover * @return error code: * PROVER_OK - in case of success * PPOVER_ERROR - in case of an error - * PROVER_ERROR_SHORT_BUFFER - in case of a short buffer error, also updates proof_size and public_size with actual proof and public sizess + * PROVER_ERROR_SHORT_BUFFER - in case of a short buffer error, also updates proof_size and public_size with actual proof and public sizes + */ +int +groth16_prover( + const void *zkey_buffer, + unsigned long long zkey_size, + const void *wtns_buffer, + unsigned long long wtns_size, + char *proof_buffer, + unsigned long long *proof_size, + char *public_buffer, + unsigned long long *public_size, + char *error_msg, + unsigned long long error_msg_maxsize); + +/** + * groth16_prover_zkey_file + * @return error code: + * PROVER_OK - in case of success + * PPOVER_ERROR - in case of an error + * PROVER_ERROR_SHORT_BUFFER - in case of a short buffer error, also updates proof_size and public_size with actual proof and public sizes */ int -groth16_prover_zkey_file(const char *zkey_file_path, - const void *wtns_buffer, unsigned long wtns_size, - char *proof_buffer, unsigned long *proof_size, - char *public_buffer, unsigned long *public_size, - char *error_msg, unsigned long error_msg_maxsize); +groth16_prover_zkey_file( + const char *zkey_file_path, + const void *wtns_buffer, + unsigned long long wtns_size, + char *proof_buffer, + unsigned long long *proof_size, + char *public_buffer, + unsigned long long *public_size, + char *error_msg, + unsigned long long error_msg_maxsize); #ifdef __cplusplus } diff --git a/src/test_public_size.c b/src/test_public_size.c index d5aebe4..73b7f19 100644 --- a/src/test_public_size.c +++ b/src/test_public_size.c @@ -19,7 +19,7 @@ #include "prover.h" int -test_groth16_public_size(const char *zkey_fname, size_t *public_size) { +test_groth16_public_size(const char *zkey_fname, unsigned long long *public_size) { int ret_val = 0; const int error_sz = 256; char error_msg[error_sz]; @@ -54,7 +54,7 @@ test_groth16_public_size(const char *zkey_fname, size_t *public_size) { int ok = groth16_public_size_for_zkey_buf(buf, sb.st_size, public_size, error_msg, error_sz); if (ok == 0) { - printf("Public size: %lu\n", *public_size); + printf("Public size: %llu\n", *public_size); } else { printf("Error: %s\n", error_msg); ret_val = 1; @@ -72,13 +72,13 @@ test_groth16_public_size(const char *zkey_fname, size_t *public_size) { int test_groth16_public_size_for_zkey_file(const char *zkey_fname, - size_t *public_size) { + unsigned long long *public_size) { const int err_ln = 256; char error_msg[err_ln]; int ret = groth16_public_size_for_zkey_file(zkey_fname, public_size, error_msg, err_ln); if (ret == 0) { - printf("Public size: %lu\n", *public_size); + printf("Public size: %llu\n", *public_size); } else { printf("Error: %s\n", error_msg); } @@ -98,7 +98,7 @@ main(int argc, char *argv[]) { int ret_val = 0; clock_t start = clock(); - size_t public_size = 0; + unsigned long long public_size = 0; int test_groth16_public_size_ok = test_groth16_public_size(argv[1], &public_size); @@ -114,7 +114,7 @@ main(int argc, char *argv[]) { if (public_size != want_pub_size) { printf("test_groth16_public_size expected public signals buf size: %ld\n", want_pub_size); - printf("test_groth16_public_size actual public signals buf size: %lu\n", + printf("test_groth16_public_size actual public signals buf size: %llu\n", public_size); ret_val = 1; } @@ -135,7 +135,7 @@ main(int argc, char *argv[]) { if (public_size != want_pub_size) { printf("test_groth16_public_size_for_zkey_file expected public signals buf size: %ld\n", want_pub_size); - printf("test_groth16_public_size_for_zkey_file actual public signals buf size: %lu\n", + printf("test_groth16_public_size_for_zkey_file actual public signals buf size: %llu\n", public_size); ret_val = 1; } diff --git a/src/wtns_utils.cpp b/src/wtns_utils.cpp index 2301c3a..3d24aef 100644 --- a/src/wtns_utils.cpp +++ b/src/wtns_utils.cpp @@ -3,6 +3,7 @@ namespace WtnsUtils { Header::Header() { + mpz_init(prime); } Header::~Header() { @@ -10,18 +11,18 @@ Header::~Header() { } std::unique_ptr
loadHeader(BinFileUtils::BinFile *f) { - Header *h = new Header(); + std::unique_ptr
h(new Header()); + f->startReadSection(1); h->n8 = f->readU32LE(); - mpz_init(h->prime); mpz_import(h->prime, h->n8, -1, 1, -1, 0, f->read(h->n8)); h->nVars = f->readU32LE(); f->endReadSection(); - return std::unique_ptr
(h); + return h; } -} // NAMESPACE \ No newline at end of file +} // NAMESPACE diff --git a/src/zkey_utils.cpp b/src/zkey_utils.cpp index a0477fa..10d9acd 100644 --- a/src/zkey_utils.cpp +++ b/src/zkey_utils.cpp @@ -6,6 +6,8 @@ namespace ZKeyUtils { Header::Header() { + mpz_init(qPrime); + mpz_init(rPrime); } Header::~Header() { @@ -15,7 +17,8 @@ Header::~Header() { std::unique_ptr
loadHeader(BinFileUtils::BinFile *f) { - auto h = new Header(); + + std::unique_ptr
h(new Header()); f->startReadSection(1); uint32_t protocol = f->readU32LE(); @@ -27,11 +30,9 @@ std::unique_ptr
loadHeader(BinFileUtils::BinFile *f) { f->startReadSection(2); h->n8q = f->readU32LE(); - mpz_init(h->qPrime); mpz_import(h->qPrime, h->n8q, -1, 1, -1, 0, f->read(h->n8q)); h->n8r = f->readU32LE(); - mpz_init(h->rPrime); mpz_import(h->rPrime, h->n8r , -1, 1, -1, 0, f->read(h->n8r)); h->nVars = f->readU32LE(); @@ -48,7 +49,7 @@ std::unique_ptr
loadHeader(BinFileUtils::BinFile *f) { h->nCoefs = f->getSectionSize(4) / (12 + h->n8r); - return std::unique_ptr
(h); + return h; } } // namespace