Skip to content

Commit f40f02a

Browse files
author
nixw
committed
Check groth16_prover args
1 parent 3e78e1d commit f40f02a

File tree

7 files changed

+116
-39
lines changed

7 files changed

+116
-39
lines changed

build/fq.o

-48.9 KB
Binary file not shown.

build/fr.o

-48.9 KB
Binary file not shown.

build_gmp.sh

+3-3
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ build_host()
4343
mkdir "$BUILD_DIR"
4444
cd "$BUILD_DIR"
4545

46-
../configure ABI=64 --prefix="$PACKAGE_DIR" --disable-shared &&
46+
../configure --prefix="$PACKAGE_DIR" --disable-shared &&
4747
make -j$(nproc) &&
4848
make install
4949

@@ -88,7 +88,7 @@ build_android()
8888
mkdir "$BUILD_DIR"
8989
cd "$BUILD_DIR"
9090

91-
../configure ABI=64 --host $TARGET --prefix="$PACKAGE_DIR" --disable-shared --disable-fft &&
91+
../configure --host $TARGET --prefix="$PACKAGE_DIR" --disable-shared --disable-fft &&
9292
make -j$(nproc) &&
9393
make install
9494

@@ -126,7 +126,7 @@ build_ios()
126126
mkdir "$BUILD_DIR"
127127
cd "$BUILD_DIR"
128128

129-
../configure ABI=64 --host $TARGET --prefix="$PACKAGE_DIR" --disable-shared --disable-fft --disable-assembly &&
129+
../configure --host $TARGET --prefix="$PACKAGE_DIR" --disable-shared --disable-fft --disable-assembly &&
130130
make -j$(nproc) &&
131131
make install
132132

src/CMakeLists.txt

+14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
if(USE_ASM)
22
add_definitions(-DUSE_ASM)
3+
4+
if (CMAKE_HOST_SYSTEM_NAME MATCHES "Darwin")
5+
set(NASM_FLAGS "-fmacho64 --prefix _")
6+
else()
7+
set(NASM_FLAGS -felf64)
8+
endif()
9+
10+
add_custom_command(OUTPUT ${CMAKE_SOURCE_DIR}/build/fq.o
11+
COMMAND nasm ${NASM_FLAGS} fq.asm -o fq.o
12+
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/build)
13+
14+
add_custom_command(OUTPUT ${CMAKE_SOURCE_DIR}/build/fr.o
15+
COMMAND nasm ${NASM_FLAGS} fr.asm -o fr.o
16+
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/build)
317
endif()
418

519
if(OpenMP_CXX_FOUND)

src/main_prover.cpp

+15-3
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ int main(int argc, char **argv) {
2828

2929
char proofBuffer[BufferSize];
3030
char publicBuffer[BufferSize];
31+
size_t proofSize = sizeof(proofBuffer);
32+
size_t publicSize = sizeof(publicBuffer);
3133
char errorMessage[256];
3234
int error = 0;
3335

@@ -36,11 +38,17 @@ int main(int argc, char **argv) {
3638

3739
error = groth16_prover(zkeyFileLoader.dataBuffer(), zkeyFileLoader.dataSize(),
3840
wtnsFileLoader.dataBuffer(), wtnsFileLoader.dataSize(),
39-
proofBuffer, sizeof(proofBuffer),
40-
publicBuffer, sizeof(publicBuffer),
41+
proofBuffer, &proofSize,
42+
publicBuffer, &publicSize,
4143
errorMessage, sizeof(errorMessage));
4244

43-
if (error) {
45+
if (error == PPROVER_ERROR_SHORT_BUFFER) {
46+
47+
std::cerr << "Error: Short buffer for proof or public" << '\n';
48+
return EXIT_FAILURE;
49+
}
50+
51+
else if (error) {
4452
std::cerr << errorMessage << '\n';
4553
return EXIT_FAILURE;
4654
}
@@ -55,6 +63,10 @@ int main(int argc, char **argv) {
5563
publicFile << publicBuffer;
5664
publicFile.close();
5765

66+
} catch (std::exception* e) {
67+
std::cerr << e->what() << '\n';
68+
return EXIT_FAILURE;
69+
5870
} catch (std::exception& e) {
5971
std::cerr << e.what() << '\n';
6072
return EXIT_FAILURE;

src/prover.cpp

+76-26
Original file line numberDiff line numberDiff line change
@@ -14,36 +14,73 @@
1414

1515
using json = nlohmann::json;
1616

17+
static size_t ProofBufferMinSize()
18+
{
19+
return 726;
20+
}
1721

18-
int
19-
groth16_prover(const void *zkey_buffer, unsigned long zkey_size,
20-
const void *wtns_buffer, unsigned long wtns_size,
21-
char *proof_buffer, unsigned long proof_size,
22-
char *public_buffer, unsigned long public_size,
23-
char *error_msg, unsigned long error_msg_maxsize)
22+
static size_t PublicBufferMinSize(size_t count)
23+
{
24+
return count * 81 + 3;
25+
}
26+
27+
static void VerifyPrimes(mpz_srcptr zkey_prime, mpz_srcptr wtns_prime)
2428
{
2529
mpz_t altBbn128r;
2630

2731
mpz_init(altBbn128r);
2832
mpz_set_str(altBbn128r, "21888242871839275222246405745257275088548364400416034343698204186575808495617", 10);
2933

34+
if (mpz_cmp(zkey_prime, altBbn128r) != 0) {
35+
throw std::invalid_argument( "zkey curve not supported" );
36+
}
37+
38+
if (mpz_cmp(wtns_prime, altBbn128r) != 0) {
39+
throw std::invalid_argument( "different wtns curve" );
40+
}
41+
42+
mpz_clear(altBbn128r);
43+
}
44+
45+
std::string BuildPublicString(AltBn128::FrElement *wtnsData, size_t nPublic)
46+
{
47+
json jsonPublic;
48+
AltBn128::FrElement aux;
49+
for (u_int32_t i=1; i<= nPublic; i++) {
50+
AltBn128::Fr.toMontgomery(aux, wtnsData[i]);
51+
jsonPublic.push_back(AltBn128::Fr.toString(aux));
52+
}
53+
54+
return jsonPublic.dump();
55+
}
56+
57+
int
58+
groth16_prover(const void *zkey_buffer, unsigned long zkey_size,
59+
const void *wtns_buffer, unsigned long wtns_size,
60+
char *proof_buffer, unsigned long *proof_size,
61+
char *public_buffer, unsigned long *public_size,
62+
char *error_msg, unsigned long error_msg_maxsize)
63+
{
3064
try {
3165
BinFileUtils::BinFile zkey(zkey_buffer, zkey_size, "zkey", 1);
32-
3366
auto zkeyHeader = ZKeyUtils::loadHeader(&zkey);
3467

35-
std::string proofStr;
36-
if (mpz_cmp(zkeyHeader->rPrime, altBbn128r) != 0) {
37-
throw std::invalid_argument( "zkey curve not supported" );
38-
}
39-
4068
BinFileUtils::BinFile wtns(wtns_buffer, wtns_size, "wtns", 2);
4169
auto wtnsHeader = WtnsUtils::loadHeader(&wtns);
4270

43-
if (mpz_cmp(wtnsHeader->prime, altBbn128r) != 0) {
44-
throw std::invalid_argument( "different wtns curve" );
71+
size_t proofMinSize = ProofBufferMinSize();
72+
size_t publicMinSize = PublicBufferMinSize(zkeyHeader->nPublic);
73+
74+
if (*proof_size < proofMinSize || *public_size < publicMinSize) {
75+
76+
*proof_size = proofMinSize;
77+
*public_size = publicMinSize;
78+
79+
return PPROVER_ERROR_SHORT_BUFFER;
4580
}
4681

82+
VerifyPrimes(zkeyHeader->rPrime, wtnsHeader->prime);
83+
4784
auto prover = Groth16::makeProver<AltBn128::Engine>(
4885
zkeyHeader->nVars,
4986
zkeyHeader->nPublic,
@@ -65,30 +102,43 @@ groth16_prover(const void *zkey_buffer, unsigned long zkey_size,
65102
auto proof = prover->prove(wtnsData);
66103

67104
std::string stringProof = proof->toJson().dump();
105+
std::string stringPublic = BuildPublicString(wtnsData, zkeyHeader->nPublic);
68106

69-
std::strncpy(proof_buffer, stringProof.data(), proof_size);
107+
size_t stringProofSize = stringProof.length();
108+
size_t stringPublicSize = stringPublic.length();
70109

71-
json jsonPublic;
72-
AltBn128::FrElement aux;
73-
for (u_int32_t i=1; i<=zkeyHeader->nPublic; i++) {
74-
AltBn128::Fr.toMontgomery(aux, wtnsData[i]);
75-
jsonPublic.push_back(AltBn128::Fr.toString(aux));
76-
}
110+
if (*proof_size < stringProofSize || *public_size < stringPublicSize) {
77111

78-
std::string stringPublic = jsonPublic.dump();
112+
*proof_size = stringProofSize;
113+
*public_size = stringPublicSize;
79114

80-
std::strncpy(public_buffer, stringPublic.data(), public_size);
115+
return PPROVER_ERROR_SHORT_BUFFER;
116+
}
117+
118+
std::strncpy(proof_buffer, stringProof.data(), *proof_size);
119+
std::strncpy(public_buffer, stringPublic.data(), *public_size);
81120

82121
} catch (std::exception& e) {
83-
mpz_clear(altBbn128r);
84122

85123
if (error_msg) {
86124
strncpy(error_msg, e.what(), error_msg_maxsize);
87125
}
88126
return PPROVER_ERROR;
89-
}
90127

91-
mpz_clear(altBbn128r);
128+
} catch (std::exception *e) {
129+
130+
if (error_msg) {
131+
strncpy(error_msg, e->what(), error_msg_maxsize);
132+
}
133+
delete e;
134+
return PPROVER_ERROR;
135+
136+
} catch (...) {
137+
if (error_msg) {
138+
strncpy(error_msg, "unknown error", error_msg_maxsize);
139+
}
140+
return PPROVER_ERROR;
141+
}
92142

93143
return PRPOVER_OK;
94144
}

src/prover.h

+8-7
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ extern "C" {
66
#endif
77

88
//Error codes returned by the functions.
9-
#define PRPOVER_OK 0x0
10-
#define PPROVER_ERROR 0x1
9+
#define PRPOVER_OK 0x0
10+
#define PPROVER_ERROR 0x1
11+
#define PPROVER_ERROR_SHORT_BUFFER 0x2
1112

1213

1314
/**
@@ -17,11 +18,11 @@ extern "C" {
1718
*/
1819

1920
int
20-
groth16_prover(const void *zkey_buffer, unsigned long zkey_size,
21-
const void *wtns_buffer, unsigned long wtns_size,
22-
char *proof_buffer, unsigned long proof_size,
23-
char *public_buffer, unsigned long public_size,
24-
char *error_msg, unsigned long error_msg_maxsize);
21+
groth16_prover(const void *zkey_buffer, unsigned long zkey_size,
22+
const void *wtns_buffer, unsigned long wtns_size,
23+
char *proof_buffer, unsigned long *proof_size,
24+
char *public_buffer, unsigned long *public_size,
25+
char *error_msg, unsigned long error_msg_maxsize);
2526

2627
#ifdef __cplusplus
2728
}

0 commit comments

Comments
 (0)