Skip to content

Commit

Permalink
feat: MerkleSumTreeCircuit to MstInclusionCircuit
Browse files Browse the repository at this point in the history
The only difference is that the new `MstInclusion` doesn't store `assets_sum` and neither performs `enforce_less_than`
  • Loading branch information
enricobottazzi committed Jun 16, 2023
1 parent 72c8a58 commit a8a354e
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 228 deletions.
29 changes: 8 additions & 21 deletions zk_prover/benches/full_solvency_flow.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use criterion::{criterion_group, criterion_main, Criterion};
use halo2_proofs::{
halo2curves::bn256::{Bn256, Fr as Fp},
halo2curves::bn256::Bn256,
plonk::{keygen_pk, keygen_vk},
poly::kzg::commitment::ParamsKZG,
};
use snark_verifier_sdk::CircuitExt;
use summa_solvency::{
circuits::merkle_sum_tree::MerkleSumTreeCircuit,
circuits::merkle_sum_tree::MstInclusionCircuit,
circuits::utils::{full_prover, full_verifier, generate_setup_params},
merkle_sum_tree::{MerkleSumTree, MST_WIDTH, N_ASSETS},
};
Expand All @@ -32,7 +32,7 @@ fn verification_key_gen_benchmark(_c: &mut Criterion) {

let params: ParamsKZG<Bn256> = generate_setup_params(11);

let empty_circuit = MerkleSumTreeCircuit::<LEVELS, MST_WIDTH, N_ASSETS>::init_empty();
let empty_circuit = MstInclusionCircuit::<LEVELS, MST_WIDTH, N_ASSETS>::init_empty();

let bench_name = format!("gen verification key for 2 power of {} entries", LEVELS);
criterion.bench_function(&bench_name, |b| {
Expand All @@ -47,7 +47,7 @@ fn proving_key_gen_benchmark(_c: &mut Criterion) {

let params: ParamsKZG<Bn256> = generate_setup_params(11);

let empty_circuit = MerkleSumTreeCircuit::<LEVELS, MST_WIDTH, N_ASSETS>::init_empty();
let empty_circuit = MstInclusionCircuit::<LEVELS, MST_WIDTH, N_ASSETS>::init_empty();

let vk = keygen_vk(&params, &empty_circuit).expect("vk generation should not fail");
let bench_name = format!("gen proving key for 2 power of {} entries", LEVELS);
Expand All @@ -63,19 +63,15 @@ fn generate_zk_proof_benchmark(_c: &mut Criterion) {

let params: ParamsKZG<Bn256> = generate_setup_params(11);

let empty_circuit = MerkleSumTreeCircuit::<LEVELS, MST_WIDTH, N_ASSETS>::init_empty();
let empty_circuit = MstInclusionCircuit::<LEVELS, MST_WIDTH, N_ASSETS>::init_empty();

let vk = keygen_vk(&params, &empty_circuit).expect("vk generation should not fail");
let pk = keygen_pk(&params, vk, &empty_circuit).expect("pk generation should not fail");

let csv_file = format!("benches/csv/entry_2_{}.csv", LEVELS);

let assets_sum = [Fp::from(556863u64), Fp::from(556863u64)]; // greater than liabilities sum (556862)

// Only now we can instantiate the circuit with the actual inputs
let circuit = MerkleSumTreeCircuit::<LEVELS, MST_WIDTH, N_ASSETS>::init_from_assets_and_path(
assets_sum, &csv_file, 0,
);
let circuit = MstInclusionCircuit::<LEVELS, MST_WIDTH, N_ASSETS>::init(&csv_file, 0);

let bench_name = format!("generate zk proof - tree of 2 power of {} entries", LEVELS);
criterion.bench_function(&bench_name, |b| {
Expand All @@ -90,24 +86,15 @@ fn verify_zk_proof_benchmark(_c: &mut Criterion) {

let params: ParamsKZG<Bn256> = generate_setup_params(11);

let empty_circuit = MerkleSumTreeCircuit::<LEVELS, MST_WIDTH, N_ASSETS>::init_empty();
let empty_circuit = MstInclusionCircuit::<LEVELS, MST_WIDTH, N_ASSETS>::init_empty();

let vk = keygen_vk(&params, &empty_circuit).expect("vk generation should not fail");
let pk = keygen_pk(&params, vk.clone(), &empty_circuit).expect("pk generation should not fail");

let csv_file = format!("benches/csv/entry_2_{}.csv", LEVELS);

let assets_sum = [Fp::from(556863u64), Fp::from(556863u64)]; // greater than liabilities sum (556862)

// Only now we can instantiate the circuit with the actual inputs
let circuit = MerkleSumTreeCircuit::<LEVELS, MST_WIDTH, N_ASSETS>::init_from_assets_and_path(
assets_sum, &csv_file, 0,
);

let mut public_input = vec![circuit.leaf_hash];
public_input.extend(&circuit.leaf_balances);
public_input.push(circuit.root_hash);
public_input.extend(&circuit.assets_sum);
let circuit = MstInclusionCircuit::<LEVELS, MST_WIDTH, N_ASSETS>::init(&csv_file, 0);

let proof = full_prover(&params, &pk, circuit.clone(), circuit.instances());

Expand Down
34 changes: 8 additions & 26 deletions zk_prover/src/circuits/merkle_sum_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,29 @@ use halo2_proofs::plonk::{Advice, Circuit, Column, ConstraintSystem, Error};
use snark_verifier_sdk::CircuitExt;

#[derive(Clone)]
pub struct MerkleSumTreeCircuit<const LEVELS: usize, const MST_WIDTH: usize, const N_ASSETS: usize>
{
pub struct MstInclusionCircuit<const LEVELS: usize, const MST_WIDTH: usize, const N_ASSETS: usize> {
pub leaf_hash: Fp,
pub leaf_balances: Vec<Fp>,
pub path_element_hashes: Vec<Fp>,
pub path_element_balances: Vec<[Fp; N_ASSETS]>,
pub path_indices: Vec<Fp>,
pub assets_sum: Vec<Fp>,
pub root_hash: Fp,
}

impl<const LEVELS: usize, const MST_WIDTH: usize, const N_ASSETS: usize> CircuitExt<Fp>
for MerkleSumTreeCircuit<LEVELS, MST_WIDTH, N_ASSETS>
for MstInclusionCircuit<LEVELS, MST_WIDTH, N_ASSETS>
{
fn num_instance(&self) -> Vec<usize> {
vec![2 + N_ASSETS]
vec![2]
}

fn instances(&self) -> Vec<Vec<Fp>> {
let mut instances = vec![self.leaf_hash];
instances.push(self.root_hash);
instances.extend(&self.assets_sum);
vec![instances]
vec![vec![self.leaf_hash, self.root_hash]]
}
}

impl<const LEVELS: usize, const MST_WIDTH: usize, const N_ASSETS: usize>
MerkleSumTreeCircuit<LEVELS, MST_WIDTH, N_ASSETS>
MstInclusionCircuit<LEVELS, MST_WIDTH, N_ASSETS>
{
pub fn init_empty() -> Self {
Self {
Expand All @@ -42,16 +37,11 @@ impl<const LEVELS: usize, const MST_WIDTH: usize, const N_ASSETS: usize>
path_element_hashes: vec![Fp::zero(); LEVELS],
path_element_balances: vec![[Fp::zero(); N_ASSETS]; LEVELS],
path_indices: vec![Fp::zero(); LEVELS],
assets_sum: vec![Fp::zero(); N_ASSETS],
root_hash: Fp::zero(),
}
}

pub fn init_from_assets_and_path(
assets_sum: [Fp; N_ASSETS],
path: &str,
user_index: usize,
) -> Self {
pub fn init(path: &str, user_index: usize) -> Self {
let merkle_sum_tree = MerkleSumTree::new(path).unwrap();

let proof: MerkleProof<N_ASSETS> = merkle_sum_tree.generate_proof(user_index).unwrap();
Expand All @@ -71,14 +61,13 @@ impl<const LEVELS: usize, const MST_WIDTH: usize, const N_ASSETS: usize>
path_element_hashes: proof.sibling_hashes,
path_element_balances: proof.sibling_sums,
path_indices: proof.path_indices,
assets_sum: assets_sum.to_vec(),
root_hash: proof.root_hash,
}
}
}

impl<const LEVELS: usize, const MST_WIDTH: usize, const N_ASSETS: usize> Circuit<Fp>
for MerkleSumTreeCircuit<LEVELS, MST_WIDTH, N_ASSETS>
for MstInclusionCircuit<LEVELS, MST_WIDTH, N_ASSETS>
{
type Config = MerkleSumTreeConfig<MST_WIDTH>;
type FloorPlanner = SimpleFloorPlanner;
Expand Down Expand Up @@ -141,14 +130,7 @@ impl<const LEVELS: usize, const MST_WIDTH: usize, const N_ASSETS: usize> Circuit
)?;
}

// enforce computed sum to be less than the assets sum
chip.enforce_less_than(layouter.namespace(|| "enforce less than"), &next_sum)?;

chip.expose_public(
layouter.namespace(|| "public root"),
&next_hash,
1,
)?;
chip.expose_public(layouter.namespace(|| "public root"), &next_hash, 1)?;
Ok(())
}
}
Loading

0 comments on commit a8a354e

Please sign in to comment.