Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/server #11

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,25 @@

This is the code needed to calculate the witness by a circuit compiled with [circom](https://github.com/iden3/circom).

## Server
To compile server implementation:
- Compile Server
```
g++ -O3 -std=c++17 -DSERVER_ENABLE -fopenmp -pthread calcwit.cpp main.cpp utils.cpp fr.cpp fr.o socket.cpp circuit-1960-32-256-64.cpp -o circuit-1960-32-256-64 -lgmp
```
- Compile Client
```
g++ -std=c++17 client.cpp socket.cpp -o client
```
- Launch Server
```
./circuit-1960-32-256-64
```
- Launch Client
```
./client input-1960-32-256-64.json circuit-1960-32-256-64_w.wshm
```


## License

Expand Down
51 changes: 51 additions & 0 deletions c/client.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <string>
#include <iostream>
#include "socket.hpp"

//main driver program
int main(int argc, char *argv[]) {
if (argc!=3) {
std::string cl = argv[0];
std::string base_filename = cl.substr(cl.find_last_of("/\\") + 1);
std::cout << "Usage: " << base_filename << " <input.<bin|json>> <output.<wtns|json|wshm>>\n";
} else {
int hSocket, read_size, server_reply;
struct sockaddr_in server;
t_witness_msg message;
strcpy(message.inputFile, argv[1]);
strcpy(message.outputFile, argv[2]);

//Create socket
hSocket = SocketCreate();
if(hSocket == -1) {
printf("Could not create socket\n");
return 1;
}
//Connect to remote server
if (SocketConnect(hSocket) < 0) {
perror("connect failed.\n");
return 1;
}
//Send data to the server, and retry until file created
SocketSend(hSocket, (void *) &message, sizeof(t_witness_msg));

while (1) {
if (access(message.outputFile, F_OK) == 0) {
break;
}
sleep(1);
}

close(hSocket);
shutdown(hSocket,0);
shutdown(hSocket,1);
shutdown(hSocket,2);
return 0;
}
}
193 changes: 132 additions & 61 deletions c/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,21 @@ using json = nlohmann::json;
#include "circom.hpp"
#include "utils.hpp"

#ifdef SERVER_ENABLE
#include "socket.hpp"
#endif

Circom_Circuit *circuit;


#define handle_error(msg) \
do { perror(msg); exit(EXIT_FAILURE); } while (0)

#define SHMEM_WITNESS_KEY (123456)
#define FAST_LOG2(x) (sizeof(unsigned long)*8 - 1 - __builtin_clzl((unsigned long)(x)))
#define FAST_LOG2_UP(x) (((x) - (1 << FAST_LOG2(x))) ? FAST_LOG2(x) + 1 : FAST_LOG2(x))



// assumptions
// 1) There is only one key assigned for shared memory. This means
Expand Down Expand Up @@ -82,18 +90,19 @@ void writeOutShmem(Circom_CalcWit *ctx, std::string filename) {
u64 idSection2length = n8*circuit->NVars;
fwrite(&idSection2length, 8, 1, write_ptr);

u64 nElems = (1 << (FAST_LOG2_UP(nVars)+1)) + 8;

// generate key
key_t key = SHMEM_WITNESS_KEY;
fwrite(&key, sizeof(key_t), 1, write_ptr);

// Setup shared memory
if ((shmid = shmget(key, circuit->NVars * Fr_N64 * sizeof(u64), IPC_CREAT | 0666)) < 0) {
if ((shmid = shmget(key, nElems * Fr_N64 * sizeof(u64), IPC_CREAT | 0666)) < 0) {
// preallocated shared memory segment is too small => Retrieve id by accesing old segment
// Delete old segment and create new with corret size
shmid = shmget(key, 4, IPC_CREAT | 0666);
shmctl(shmid, IPC_RMID, NULL);
if ((shmid = shmget(key, circuit->NVars * Fr_N64 * sizeof(u64), IPC_CREAT | 0666)) < 0){
if ((shmid = shmget(key, nElems * Fr_N64 * sizeof(u64), IPC_CREAT | 0666)) < 0){
status = -1;
fwrite(&status, sizeof(status), 1, write_ptr);
fclose(write_ptr);
Expand Down Expand Up @@ -344,78 +353,140 @@ Circom_Circuit *loadCircuit(std::string const &datFileName) {
return circuit;
}


void computeWitness(char *inputFile, char *outputFile) {
struct timeval begin, end;
long seconds, microseconds;
double elapsed;

gettimeofday(&begin,0);
Circom_CalcWit *ctx = new Circom_CalcWit(circuit);

std::string infilename = inputFile;
gettimeofday(&end,0);
seconds = end.tv_sec - begin.tv_sec;
microseconds = end.tv_usec - begin.tv_usec;
elapsed = seconds + microseconds*1e-6;

printf("Up to loadJson %.20f\n", elapsed);

if (hasEnding(infilename, std::string(".bin"))) {
loadBin(ctx, infilename);
} else if (hasEnding(infilename, std::string(".json"))) {
loadJson(ctx, infilename);
} else {
handle_error("Invalid input extension (.bin / .json)");
}

ctx->join();

// printf("Finished!\n");

std::string outfilename = outputFile;

if (hasEnding(outfilename, std::string(".wtns"))) {
gettimeofday(&end,0);
seconds = end.tv_sec - begin.tv_sec;
microseconds = end.tv_usec - begin.tv_usec;
elapsed = seconds + microseconds*1e-6;

printf("Up to WriteWtns %.20f\n", elapsed);
writeOutBin(ctx, outfilename);
} else if (hasEnding(outfilename, std::string(".json"))) {
writeOutJson(ctx, outfilename);
} else if (hasEnding(outfilename, std::string(".wshm"))) {
gettimeofday(&end,0);
seconds = end.tv_sec - begin.tv_sec;
microseconds = end.tv_usec - begin.tv_usec;
elapsed = seconds + microseconds*1e-6;

printf("Up to WriteShmem %.20f\n", elapsed);
writeOutShmem(ctx, outfilename);
} else {
handle_error("Invalid output extension (.bin / .json)");
}

delete ctx;
gettimeofday(&end,0);
seconds = end.tv_sec - begin.tv_sec;
microseconds = end.tv_usec - begin.tv_usec;
elapsed = seconds + microseconds*1e-6;

printf("Total %.20f\n", elapsed);
#ifndef SERVER_ENABLE
exit(EXIT_SUCCESS);
#endif
}

int main(int argc, char *argv[]) {
struct timeval begin, end;
long seconds, microseconds;
double elapsed;

gettimeofday(&begin,0);
#ifndef SERVER_ENABLE
if (argc!=3) {
std::string cl = argv[0];
std::string base_filename = cl.substr(cl.find_last_of("/\\") + 1);
std::cout << "Usage: " << base_filename << " <input.<bin|json>> <output.<wtns|json|wshm>>\n";
} else {

struct timeval begin, end;
long seconds, microseconds;
double elapsed;

gettimeofday(&begin,0);

std::string datFileName = argv[0];
datFileName += ".dat";

circuit = loadCircuit(datFileName);

// open output
Circom_CalcWit *ctx = new Circom_CalcWit(circuit);

std::string infilename = argv[1];
gettimeofday(&end,0);
seconds = end.tv_sec - begin.tv_sec;
microseconds = end.tv_usec - begin.tv_usec;
elapsed = seconds + microseconds*1e-6;

printf("Up to loadJson %.20f\n", elapsed);

if (hasEnding(infilename, std::string(".bin"))) {
loadBin(ctx, infilename);
} else if (hasEnding(infilename, std::string(".json"))) {
loadJson(ctx, infilename);
} else {
handle_error("Invalid input extension (.bin / .json)");
}

ctx->join();

// printf("Finished!\n");

std::string outfilename = argv[2];
gettimeofday(&end,0);
seconds = end.tv_sec - begin.tv_sec;
microseconds = end.tv_usec - begin.tv_usec;
elapsed = seconds + microseconds*1e-6;

if (hasEnding(outfilename, std::string(".wtns"))) {
gettimeofday(&end,0);
seconds = end.tv_sec - begin.tv_sec;
microseconds = end.tv_usec - begin.tv_usec;
elapsed = seconds + microseconds*1e-6;

printf("Up to WriteWtns %.20f\n", elapsed);
writeOutBin(ctx, outfilename);
} else if (hasEnding(outfilename, std::string(".json"))) {
writeOutJson(ctx, outfilename);
} else if (hasEnding(outfilename, std::string(".wshm"))) {
gettimeofday(&end,0);
seconds = end.tv_sec - begin.tv_sec;
microseconds = end.tv_usec - begin.tv_usec;
elapsed = seconds + microseconds*1e-6;

printf("Up to WriteShmem %.20f\n", elapsed);
writeOutShmem(ctx, outfilename);
} else {
handle_error("Invalid output extension (.bin / .json)");
}
printf("Up to computeWitness %.20f\n", elapsed);
// open output
computeWitness(argv[1], argv[2]);

delete ctx;
gettimeofday(&end,0);
seconds = end.tv_sec - begin.tv_sec;
microseconds = end.tv_usec - begin.tv_usec;
elapsed = seconds + microseconds*1e-6;
#else
{
std::string datFileName = argv[0];
datFileName += ".dat";

printf("Total %.20f\n", elapsed);
exit(EXIT_SUCCESS);
int circuitInit=0;
t_witness_msg message;

if (!ServerInit()) {
exit(1);
}

while(1) {
int sock;
sock = ReceiveMsg((void *) &message, sizeof(t_witness_msg));
if (!sock) {
continue;
}
std::cout << " Output file " << message.outputFile << "\n";
std::cout << " Input file " << message.inputFile << "\n";

if (!circuitInit) {
std::cout << " Load Circuit " << datFileName << "\n";
circuit = loadCircuit(datFileName);
circuitInit=1;

gettimeofday(&end,0);
seconds = end.tv_sec - begin.tv_sec;
microseconds = end.tv_usec - begin.tv_usec;
elapsed = seconds + microseconds*1e-6;

printf("Up to computeWitness %.20f\n", elapsed);
}

if (circuitInit) {
std::cout << " Compute Witness " << message.outputFile << "\n";
computeWitness(message.inputFile, message.outputFile);
}

SocketClose(sock);
sleep(1);
}
#endif
}
}

Loading