Skip to content

Commit

Permalink
Improve zkey loading (#27)
Browse files Browse the repository at this point in the history
Co-authored-by: nixw <>
  • Loading branch information
olomix authored Dec 10, 2024
1 parent 38c832a commit af527f9
Show file tree
Hide file tree
Showing 12 changed files with 630 additions and 323 deletions.
2 changes: 1 addition & 1 deletion depends/ffiasm
Submodule ffiasm updated 5 files
+5 −4 c/fft.cpp
+2 −0 c/fft.hpp
+31 −24 c/misc.hpp
+37 −36 c/msm.cpp
+3 −1 c/msm.hpp
131 changes: 52 additions & 79 deletions src/binfile_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,92 +12,49 @@

namespace BinFileUtils {

BinFile::BinFile(const std::string& fileName, const std::string& _type, uint32_t maxVersion) {
BinFile::BinFile(const std::string& fileName, const std::string& _type, uint32_t maxVersion)
: fileLoader(fileName)
{
addr = fileLoader.dataBuffer();
size = fileLoader.dataSize();

is_fd = true;
struct stat sb;

fd = open(fileName.c_str(), O_RDONLY);
if (fd == -1)
throw std::system_error(errno, std::generic_category(), "open");


if (fstat(fd, &sb) == -1) { /* To obtain file size */
close(fd);
throw std::system_error(errno, std::generic_category(), "fstat");
}

size = sb.st_size;

addr = mmap(nullptr, sb.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
if (addr == MAP_FAILED) {
close(fd);
throw std::system_error(errno, std::generic_category(), "mmap failed");
}
madvise(addr, size, MADV_SEQUENTIAL);

type.assign((const char *)addr, 4);
pos = 4;

if (type != _type) {
munmap(addr, size);
close(fd);
throw std::invalid_argument("Invalid file type. It should be " + _type + " and it is " + type + " filename: " + fileName);
}

version = readU32LE();
if (version > maxVersion) {
munmap(addr, size);
close(fd);
throw std::invalid_argument("Invalid version. It should be <=" + std::to_string(maxVersion) + " and it is " + std::to_string(version));
}

u_int32_t nSections = readU32LE();


for (u_int32_t i=0; i<nSections; i++) {
u_int32_t sType=readU32LE();
u_int64_t sSize=readU64LE();

if (sections.find(sType) == sections.end()) {
sections.insert(std::make_pair(sType, std::vector<Section>()));
}
readFileData(_type, maxVersion);
}

sections[sType].push_back(Section( (void *)((u_int64_t)addr + pos), sSize));
BinFile::BinFile(const void *fileData, size_t fileSize, std::string _type, uint32_t maxVersion) {

pos += sSize;
}
addr = fileData;
size = fileSize;

pos = 0;
readingSection = nullptr;
readFileData(_type, maxVersion);
}

void BinFile::readFileData(std::string _type, uint32_t maxVersion) {

BinFile::BinFile(const void *fileData, size_t fileSize, std::string _type, uint32_t maxVersion) {
const u_int64_t headerSize = 12;
const u_int64_t minSectionSize = 12;

is_fd = false;
fd = -1;

size = fileSize;
addr = malloc(size);
memcpy(addr, fileData, size);
if (size < headerSize) {
throw std::range_error("File is too short.");
}

type.assign((const char *)addr, 4);
pos = 4;

if (type != _type) {
free(addr);
throw std::invalid_argument("Invalid file type. It should be " + _type + " and it is " + type);
}

version = readU32LE();
if (version > maxVersion) {
free(addr);
throw std::invalid_argument("Invalid version. It should be <=" + std::to_string(maxVersion) + " and it is " + std::to_string(version));
}

u_int32_t nSections = readU32LE();

if (size < headerSize + nSections * minSectionSize) {
throw std::range_error("File is too short to contain " + std::to_string(nSections) + " sections.");
}

for (u_int32_t i=0; i<nSections; i++) {
u_int32_t sType=readU32LE();
Expand All @@ -110,20 +67,18 @@ BinFile::BinFile(const void *fileData, size_t fileSize, std::string _type, uint3
sections[sType].push_back(Section( (void *)((u_int64_t)addr + pos), sSize));

pos += sSize;

if (pos > size) {
throw std::range_error("Section #" + std::to_string(i) + " is invalid."
". It ends at pos " + std::to_string(pos) +
" but should end before " + std::to_string(size) + ".");
}
}

pos = 0;
readingSection = NULL;
readingSection = nullptr;
}

BinFile::~BinFile() {
if (is_fd) {
munmap(addr, size);
close(fd);
} else {
free(addr);
}
}

void BinFile::startReadSection(u_int32_t sectionId, u_int32_t sectionPos) {

Expand All @@ -135,7 +90,7 @@ void BinFile::startReadSection(u_int32_t sectionId, u_int32_t sectionPos) {
throw std::range_error("Section pos too big. There are " + std::to_string(sections[sectionId].size()) + " and it's trying to access section: " + std::to_string(sectionPos));
}

if (readingSection != NULL) {
if (readingSection != nullptr) {
throw std::range_error("Already reading a section");
}

Expand All @@ -150,7 +105,7 @@ void BinFile::endReadSection(bool check) {
throw std::range_error("Invalid section size");
}
}
readingSection = NULL;
readingSection = nullptr;
}

void *BinFile::getSectionData(u_int32_t sectionId, u_int32_t sectionPos) {
Expand All @@ -169,31 +124,49 @@ void *BinFile::getSectionData(u_int32_t sectionId, u_int32_t sectionPos) {
u_int64_t BinFile::getSectionSize(u_int32_t sectionId, u_int32_t sectionPos) {

if (sections.find(sectionId) == sections.end()) {
throw new std::range_error("Section does not exist: " + std::to_string(sectionId));
throw std::range_error("Section does not exist: " + std::to_string(sectionId));
}

if (sectionPos >= sections[sectionId].size()) {
throw new std::range_error("Section pos too big. There are " + std::to_string(sections[sectionId].size()) + " and it's trying to access section: " + std::to_string(sectionPos));
throw std::range_error("Section pos too big. There are " + std::to_string(sections[sectionId].size()) + " and it's trying to access section: " + std::to_string(sectionPos));
}

return sections[sectionId][sectionPos].size;
}

u_int32_t BinFile::readU32LE() {
const u_int64_t new_pos = pos + 4;

if (new_pos > size) {
throw std::range_error("File pos is too big. There are " + std::to_string(size) + " bytes and it's trying to access byte " + std::to_string(new_pos));
}

u_int32_t res = *((u_int32_t *)((u_int64_t)addr + pos));
pos += 4;
pos = new_pos;
return res;
}

u_int64_t BinFile::readU64LE() {
const u_int64_t new_pos = pos + 8;

if (new_pos > size) {
throw std::range_error("File pos is too big. There are " + std::to_string(size) + " bytes and it's trying to access byte " + std::to_string(new_pos));
}

u_int64_t res = *((u_int64_t *)((u_int64_t)addr + pos));
pos += 8;
pos = new_pos;
return res;
}

void *BinFile::read(u_int64_t len) {
const u_int64_t new_pos = pos + len;

if (new_pos > size) {
throw std::range_error("File pos is too big. There are " + std::to_string(size) + " bytes and it's trying to access byte " + std::to_string(new_pos));
}

void *res = (void *)((u_int64_t)addr + pos);
pos += len;
pos = new_pos;
return res;
}

Expand Down
8 changes: 4 additions & 4 deletions src/binfile_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
#include <map>
#include <vector>
#include <memory>
#include "fileloader.hpp"

namespace BinFileUtils {

class BinFile {

bool is_fd;
int fd;
FileLoader fileLoader;

void *addr;
const void *addr;
u_int64_t size;
u_int64_t pos;

Expand All @@ -32,14 +32,14 @@ namespace BinFileUtils {

Section *readingSection;

void readFileData(std::string _type, uint32_t maxVersion);

public:

BinFile(const void *fileData, size_t fileSize, std::string _type, uint32_t maxVersion);
BinFile(const std::string& fileName, const std::string& _type, uint32_t maxVersion);
BinFile(const BinFile&) = delete;
BinFile& operator=(const BinFile&) = delete;
~BinFile();

void startReadSection(u_int32_t sectionId, u_int32_t setionPos = 0);
void endReadSection(bool check = true);
Expand Down
28 changes: 26 additions & 2 deletions src/fileloader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,23 @@

namespace BinFileUtils {

FileLoader::FileLoader()
: fd(-1)
{
}

FileLoader::FileLoader(const std::string& fileName)
: fd(-1)
{
load(fileName);
}

void FileLoader::load(const std::string& fileName)
{
if (fd != -1) {
throw std::invalid_argument("file already loaded");
}

struct stat sb;

fd = open(fileName.c_str(), O_RDONLY);
Expand All @@ -26,12 +41,21 @@ FileLoader::FileLoader(const std::string& fileName)
size = sb.st_size;

addr = mmap(nullptr, size, PROT_READ, MAP_PRIVATE, fd, 0);

if (addr == MAP_FAILED) {
close(fd);
throw std::system_error(errno, std::generic_category(), "mmap failed");
}

madvise(addr, size, MADV_SEQUENTIAL);
}

FileLoader::~FileLoader()
{
munmap(addr, size);
close(fd);
if (fd != -1) {
munmap(addr, size);
close(fd);
}
}

} // Namespace
3 changes: 3 additions & 0 deletions src/fileloader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ namespace BinFileUtils {
class FileLoader
{
public:
FileLoader();
FileLoader(const std::string& fileName);
~FileLoader();

void load(const std::string& fileName);

void* dataBuffer() { return addr; }
size_t dataSize() const { return size; }

Expand Down
15 changes: 8 additions & 7 deletions src/groth16.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "random_generator.hpp"
#include "logging.hpp"
#include "misc.hpp"
#include <sstream>
#include <vector>
#include <mutex>

Expand Down Expand Up @@ -84,7 +85,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
auto b = new typename Engine::FrElement[domainSize];
auto c = new typename Engine::FrElement[domainSize];

threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) {
for (u_int32_t i=begin; i<end; i++) {
E.fr.copy(a[i], E.fr.zero());
E.fr.copy(b[i], E.fr.zero());
Expand All @@ -96,7 +97,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
#define NLOCKS 1024
std::vector<std::mutex> locks(NLOCKS);

threadPool.parallelFor(0, nCoefs, [&] (int begin, int end, int numThread) {
threadPool.parallelFor(0, nCoefs, [&] (int64_t begin, int64_t end, uint64_t idThread) {
for (u_int64_t i=begin; i<end; i++) {
typename Engine::FrElement *ab = (coefs[i].m == 0) ? a : b;
typename Engine::FrElement aux;
Expand All @@ -117,7 +118,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
}
});
LOG_TRACE("Calculating c");
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(
c[i],
Expand All @@ -137,7 +138,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
LOG_DEBUG(E.fr.toString(a[1]).c_str());
LOG_TRACE("Start Shift A");

threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(a[i], a[i], fft->root(domainPower+1, i));
}
Expand All @@ -157,7 +158,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
LOG_DEBUG(E.fr.toString(b[0]).c_str());
LOG_DEBUG(E.fr.toString(b[1]).c_str());
LOG_TRACE("Start Shift B");
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(b[i], b[i], fft->root(domainPower+1, i));
}
Expand All @@ -177,7 +178,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
LOG_DEBUG(E.fr.toString(c[0]).c_str());
LOG_DEBUG(E.fr.toString(c[1]).c_str());
LOG_TRACE("Start Shift C");
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(c[i], c[i], fft->root(domainPower+1, i));
}
Expand All @@ -192,7 +193,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
LOG_DEBUG(E.fr.toString(c[1]).c_str());

LOG_TRACE("Start ABC");
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(a[i], a[i], b[i]);
E.fr.sub(a[i], a[i], c[i]);
Expand Down
Loading

0 comments on commit af527f9

Please sign in to comment.