From 2460efe324bbbf000e97e2c254990bb6321b84cd Mon Sep 17 00:00:00 2001 From: Alex Kuzmin Date: Mon, 13 May 2024 19:27:54 +0800 Subject: [PATCH] Apply sumcheck to nonzero constraints directly --- .vscode/settings.json | 1 + backend/Cargo.lock | 2 +- prover/Cargo.lock | 2 +- prover/Cargo.toml | 2 +- prover/benches/proof_of_liabilities.rs | 25 +- prover/src/chips/range/range_check.rs | 28 +- prover/src/chips/range/tests.rs | 9 +- prover/src/circuits/config/circuit_config.rs | 119 ++++++ prover/src/circuits/config/mod.rs | 3 + .../circuits/config/no_range_check_config.rs | 75 ++++ .../src/circuits/config/range_check_config.rs | 130 +++++++ prover/src/circuits/mod.rs | 1 + prover/src/circuits/summa_circuit.rs | 355 ++++++------------ prover/src/circuits/tests.rs | 110 +++++- 14 files changed, 576 insertions(+), 286 deletions(-) create mode 100644 prover/src/circuits/config/circuit_config.rs create mode 100644 prover/src/circuits/config/mod.rs create mode 100644 prover/src/circuits/config/no_range_check_config.rs create mode 100644 prover/src/circuits/config/range_check_config.rs diff --git a/.vscode/settings.json b/.vscode/settings.json index dffd4e77..ddd9eed2 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -4,4 +4,5 @@ "editor.formatOnSave": true, "editor.formatOnSaveMode": "file" }, + "cSpell.words": ["hyperplonk", "plonkish", "layouter", "sumcheck"] } diff --git a/backend/Cargo.lock b/backend/Cargo.lock index 687679d1..95318bca 100644 --- a/backend/Cargo.lock +++ b/backend/Cargo.lock @@ -2427,7 +2427,7 @@ checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" [[package]] name = "plonkish_backend" version = "0.1.0" -source = "git+https://github.com/summa-dev/plonkish?branch=summa-changes#50093af18280f4e1efd79ac258ae9c65b9401999" +source = "git+https://github.com/summa-dev/plonkish?branch=nonzero-constraints#e37ba53dcc8f8bd6e7add4e479d0e3295ee80661" dependencies = [ "bincode", "bitvec 1.0.1", diff --git a/prover/Cargo.lock b/prover/Cargo.lock index 9ddf20f5..610ff69a 100644 --- a/prover/Cargo.lock +++ b/prover/Cargo.lock @@ -989,7 +989,7 @@ checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" [[package]] name = "plonkish_backend" version = "0.1.0" -source = "git+https://github.com/summa-dev/plonkish?branch=summa-changes#50093af18280f4e1efd79ac258ae9c65b9401999" +source = "git+https://github.com/summa-dev/plonkish?branch=nonzero-constraints#e37ba53dcc8f8bd6e7add4e479d0e3295ee80661" dependencies = [ "bincode", "bitvec", diff --git a/prover/Cargo.toml b/prover/Cargo.toml index f44bd0d1..a897e375 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -13,7 +13,7 @@ parallel = ["dep:rayon"] frontend-halo2 = ["dep:halo2_proofs"] [dependencies] -plonkish_backend = { git = "https://github.com/summa-dev/plonkish", branch="summa-changes", package = "plonkish_backend", features= ["frontend-halo2", "benchmark"] } +plonkish_backend = { git = "https://github.com/summa-dev/plonkish", branch="nonzero-constraints", package = "plonkish_backend", features= ["frontend-halo2", "benchmark"] } plotters = { version = "0.3.4", optional = true } rand = "0.8" csv = "1.1" diff --git a/prover/benches/proof_of_liabilities.rs b/prover/benches/proof_of_liabilities.rs index 5e866eaf..0ac4a505 100644 --- a/prover/benches/proof_of_liabilities.rs +++ b/prover/benches/proof_of_liabilities.rs @@ -14,7 +14,7 @@ use rand::{ CryptoRng, Rng, RngCore, SeedableRng, }; use summa_hyperplonk::{ - circuits::summa_circuit::summa_hyperplonk::SummaHyperplonk, + circuits::{config::range_check_config::RangeCheckConfig, summa_circuit::SummaHyperplonk}, utils::{big_uint_to_fp, generate_dummy_entries, uni_to_multivar_binary_index}, }; @@ -30,12 +30,15 @@ fn bench_summa() type ProvingBackend = HyperPlonk>; let entries = generate_dummy_entries::().unwrap(); - let halo2_circuit = SummaHyperplonk::::init(entries.to_vec()); + let halo2_circuit = + SummaHyperplonk::>::init( + entries.to_vec(), + ); - let circuit = Halo2Circuit::>::new::( - K as usize, - halo2_circuit.clone(), - ); + let circuit = Halo2Circuit::< + Fp, + SummaHyperplonk>, + >::new::(K as usize, halo2_circuit.clone()); let circuit_info: PlonkishCircuitInfo<_> = circuit.circuit_info().unwrap(); let instances = circuit.instances(); @@ -46,10 +49,10 @@ fn bench_summa() c.bench_function(&grand_sum_proof_bench_name, |b| { b.iter_batched( || { - Halo2Circuit::>::new::( - K as usize, - halo2_circuit.clone(), - ) + Halo2Circuit::< + Fp, + SummaHyperplonk>, + >::new::(K as usize, halo2_circuit.clone()) }, |circuit| { let mut transcript = Keccak256Transcript::default(); @@ -189,7 +192,7 @@ fn bench_summa() } fn criterion_benchmark(_c: &mut Criterion) { - const N_CURRENCIES: usize = 1; + const N_CURRENCIES: usize = 100; { const K: u32 = 17; diff --git a/prover/src/chips/range/range_check.rs b/prover/src/chips/range/range_check.rs index 26598aa4..76c84fcb 100644 --- a/prover/src/chips/range/range_check.rs +++ b/prover/src/chips/range/range_check.rs @@ -2,7 +2,7 @@ use crate::chips::range::utils::decompose_fp_to_byte_pairs; use halo2_proofs::arithmetic::Field; use halo2_proofs::circuit::{AssignedCell, Region, Value}; use halo2_proofs::halo2curves::bn256::Fr as Fp; -use halo2_proofs::plonk::{Advice, Column, ConstraintSystem, Error, Expression, Fixed}; +use halo2_proofs::plonk::{Advice, Column, ConstraintSystem, Error, Expression, Fixed, Selector}; use halo2_proofs::poly::Rotation; use std::fmt::Debug; @@ -27,8 +27,7 @@ use std::fmt::Debug; /// # Fields /// /// * `zs`: Four advice columns - contain the truncated right-shifted values of the element to be checked -/// * `z0`: An advice column - for storing the zero value from the instance column -/// * `instance`: An instance column - zero value provided to the circuit +/// * `selector`: Selector used to enable the range check /// /// # Assumptions /// @@ -36,8 +35,9 @@ use std::fmt::Debug; /// /// Patterned after [halo2_gadgets](https://github.com/privacy-scaling-explorations/halo2/blob/main/halo2_gadgets/src/utilities/decompose_running_sum.rs) #[derive(Debug, Copy, Clone)] -pub struct RangeCheckU64Config { +pub struct RangeCheckChipConfig { zs: [Column; 4], + selector: Selector, } /// Helper chip that verifies that the element witnessed in a given cell lies within the u64 range. @@ -72,22 +72,23 @@ pub struct RangeCheckU64Config { /// zs[3] == z0 #[derive(Debug, Clone)] pub struct RangeCheckU64Chip { - config: RangeCheckU64Config, + config: RangeCheckChipConfig, } impl RangeCheckU64Chip { - pub fn construct(config: RangeCheckU64Config) -> Self { + pub fn construct(config: RangeCheckChipConfig) -> Self { Self { config } } /// Configures the Range Chip - /// Note: the lookup table should be loaded with values from `0` to `2^16 - 1` otherwise the range check will fail. + /// Note: the lookup table should be loaded with values from `0` to `2^16 - 1`, otherwise the range check will fail. pub fn configure( meta: &mut ConstraintSystem, z: Column, zs: [Column; 4], range_u16: Column, - ) -> RangeCheckU64Config { + range_check_enabled: Selector, + ) -> RangeCheckChipConfig { // Constraint that the difference between the element to be checked and the 0-th truncated right-shifted value of the element to be within the range. // z - 2^16⋅zs[0] = ks[0] ∈ range_u16 meta.lookup_any( @@ -99,7 +100,9 @@ impl RangeCheckU64Chip { let range_u16 = meta.query_fixed(range_u16, Rotation::cur()); - let diff = element - zero_truncation * Expression::Constant(Fp::from(1 << 16)); + let s = meta.query_selector(range_check_enabled); + + let diff = s * (element - zero_truncation * Expression::Constant(Fp::from(1 << 16))); vec![(diff, range_u16)] }, @@ -123,7 +126,10 @@ impl RangeCheckU64Chip { ); } - RangeCheckU64Config { zs } + RangeCheckChipConfig { + zs, + selector: range_check_enabled, + } } /// Assign the truncated right-shifted values of the element to be checked to the corresponding columns zs at offset 0 starting from the element to be checked. @@ -163,6 +169,8 @@ impl RangeCheckU64Chip { zs.push(z.clone()); } + self.config.selector.enable(region, 0)?; + Ok(()) } } diff --git a/prover/src/chips/range/tests.rs b/prover/src/chips/range/tests.rs index f9425443..3647c9d6 100644 --- a/prover/src/chips/range/tests.rs +++ b/prover/src/chips/range/tests.rs @@ -1,4 +1,4 @@ -use crate::chips::range::range_check::{RangeCheckU64Chip, RangeCheckU64Config}; +use crate::chips::range::range_check::{RangeCheckChipConfig, RangeCheckU64Chip}; use halo2_proofs::{ circuit::{AssignedCell, Layouter, SimpleFloorPlanner, Value}, halo2curves::bn256::Fr as Fp, @@ -87,7 +87,7 @@ impl AddChip { #[derive(Debug, Clone)] pub struct TestConfig { pub addchip_config: AddConfig, - pub range_check_config: RangeCheckU64Config, + pub range_check_config: RangeCheckChipConfig, pub range_u16: Column, pub instance: Column, } @@ -134,7 +134,10 @@ impl Circuit for TestCircuit { let instance = meta.instance_column(); meta.enable_equality(instance); - let range_check_config = RangeCheckU64Chip::configure(meta, c, zs, range_u16); + let range_check_selector = meta.complex_selector(); + + let range_check_config = + RangeCheckU64Chip::configure(meta, c, zs, range_u16, range_check_selector); let addchip_config = AddChip::configure(meta, a, b, c, add_selector); diff --git a/prover/src/circuits/config/circuit_config.rs b/prover/src/circuits/config/circuit_config.rs new file mode 100644 index 00000000..51382773 --- /dev/null +++ b/prover/src/circuits/config/circuit_config.rs @@ -0,0 +1,119 @@ +use halo2_proofs::{ + circuit::{Layouter, Value}, + plonk::{Advice, Column, ConstraintSystem, Error, Instance}, +}; + +use crate::{entry::Entry, utils::big_uint_to_fp}; + +use crate::chips::range::range_check::RangeCheckU64Chip; +use halo2_proofs::halo2curves::bn256::Fr as Fp; + +/// The abstract configuration of the circuit. +/// The default implementation assigns the entries and grand total to the circuit, and constrains +/// grand total to the instance values. +/// +/// The specific implementations have to provide the range check logic. +pub trait CircuitConfig: Clone { + /// Configures the circuit + fn configure( + meta: &mut ConstraintSystem, + username: Column, + balances: [Column; N_CURRENCIES], + instance: Column, + ) -> Self; + + fn get_username(&self) -> Column; + + fn get_balances(&self) -> [Column; N_CURRENCIES]; + + fn get_instance(&self) -> Column; + + /// Assigns the entries to the circuit, constrains the grand total to the instance values. + fn synthesize( + &self, + mut layouter: impl Layouter, + entries: &[Entry], + grand_total: &[Fp], + ) -> Result<(), Error> { + // Initiate the range check chips + let range_check_chips = self.initialize_range_check_chips(); + + for (i, entry) in entries.iter().enumerate() { + let last_decompositions = layouter.assign_region( + || format!("assign entry {} to the table", i), + |mut region| { + region.assign_advice( + || "username", + self.get_username(), + 0, + || Value::known(big_uint_to_fp::(entry.username_as_big_uint())), + )?; + + let mut last_decompositions = vec![]; + + for (j, balance) in entry.balances().iter().enumerate() { + let assigned_balance = region.assign_advice( + || format!("balance {}", j), + self.get_balances()[j], + 0, + || Value::known(big_uint_to_fp(balance)), + )?; + + let mut zs = Vec::with_capacity(4); + + if !range_check_chips.is_empty() { + range_check_chips[j].assign(&mut region, &mut zs, &assigned_balance)?; + + last_decompositions.push(zs[3].clone()); + } + } + + Ok(last_decompositions) + }, + )?; + + self.constrain_decompositions(last_decompositions, &mut layouter)?; + } + + let assigned_total = layouter.assign_region( + || "assign total".to_string(), + |mut region| { + let mut assigned_total = vec![]; + + for (j, total) in grand_total.iter().enumerate() { + let balance_total = region.assign_advice( + || format!("total {}", j), + self.get_balances()[j], + 0, + || Value::known(total.neg()), + )?; + + assigned_total.push(balance_total); + } + + Ok(assigned_total) + }, + )?; + + for (j, total) in assigned_total.iter().enumerate() { + layouter.constrain_instance(total.cell(), self.get_instance(), 1 + j)?; + } + + self.load_lookup_table(layouter)?; + + Ok(()) + } + + /// Initializes the range check chips + fn initialize_range_check_chips(&self) -> Vec; + + /// Loads the lookup table + fn load_lookup_table(&self, layouter: impl Layouter) -> Result<(), Error>; + + /// Constrains the last decompositions of the balances to be zero + fn constrain_decompositions( + &self, + last_decompositions: Vec>, + layouter: &mut impl Layouter, + ) -> Result<(), Error>; +} diff --git a/prover/src/circuits/config/mod.rs b/prover/src/circuits/config/mod.rs new file mode 100644 index 00000000..361692f8 --- /dev/null +++ b/prover/src/circuits/config/mod.rs @@ -0,0 +1,3 @@ +pub mod circuit_config; +pub mod no_range_check_config; +pub mod range_check_config; diff --git a/prover/src/circuits/config/no_range_check_config.rs b/prover/src/circuits/config/no_range_check_config.rs new file mode 100644 index 00000000..5d6d4ada --- /dev/null +++ b/prover/src/circuits/config/no_range_check_config.rs @@ -0,0 +1,75 @@ +use halo2_proofs::{ + circuit::Layouter, + plonk::{Advice, Column, ConstraintSystem, Error, Instance}, +}; + +use crate::chips::range::range_check::RangeCheckU64Chip; +use halo2_proofs::halo2curves::bn256::Fr as Fp; + +use super::circuit_config::CircuitConfig; + +/// Configuration that does not perform range checks. Warning: not for use in production! +/// The circuit without range checks can use a lower K value (9+) than the full circuit (convenient for prototyping and testing). +/// +/// # Type Parameters +/// +/// * `N_CURRENCIES`: The number of currencies for which the solvency is verified. +/// * `N_USERS`: The number of users for which the solvency is verified. +/// +/// # Fields +/// +/// * `username`: Advice column used to store the usernames of the users +/// * `balances`: Advice columns used to store the balances of the users +#[derive(Clone)] +pub struct NoRangeCheckConfig { + username: Column, + balances: [Column; N_CURRENCIES], + instance: Column, +} + +impl CircuitConfig + for NoRangeCheckConfig +{ + fn configure( + _: &mut ConstraintSystem, + username: Column, + balances: [Column; N_CURRENCIES], + instance: Column, + ) -> NoRangeCheckConfig { + Self { + username, + balances, + instance, + } + } + + fn get_username(&self) -> Column { + self.username + } + + fn get_balances(&self) -> [Column; N_CURRENCIES] { + self.balances + } + + fn get_instance(&self) -> Column { + self.instance + } + + // The following methods are not implemented for NoRangeCheckConfig + + fn initialize_range_check_chips(&self) -> Vec { + vec![] + } + + fn load_lookup_table(&self, _: impl Layouter) -> Result<(), Error> { + Ok(()) + } + + fn constrain_decompositions( + &self, + _: Vec>, + _: &mut impl Layouter, + ) -> Result<(), Error> { + Ok(()) + } +} diff --git a/prover/src/circuits/config/range_check_config.rs b/prover/src/circuits/config/range_check_config.rs new file mode 100644 index 00000000..7a7c0e8a --- /dev/null +++ b/prover/src/circuits/config/range_check_config.rs @@ -0,0 +1,130 @@ +use halo2_proofs::{ + circuit::{Layouter, Value}, + plonk::{Advice, Column, ConstraintSystem, Error, Fixed, Instance}, +}; + +use crate::chips::range::range_check::{RangeCheckChipConfig, RangeCheckU64Chip}; +use halo2_proofs::halo2curves::bn256::Fr as Fp; + +use super::circuit_config::CircuitConfig; + +/// Configuration that performs range checks. +/// +/// # Type Parameters +/// +/// * `N_CURRENCIES`: The number of currencies for which the solvency is verified. +/// * `N_USERS`: The number of users for which the solvency is verified. +/// +/// # Fields +/// +/// * `username`: Advice column used to store the usernames of the users +/// * `balances`: Advice columns used to store the balances of the users +/// * `range_check_configs`: Range check chip configurations +/// * `range_u16`: Fixed column used to store the lookup table +/// * `instance`: Instance column used to constrain the last balance decomposition +#[derive(Clone)] +pub struct RangeCheckConfig { + username: Column, + balances: [Column; N_CURRENCIES], + range_check_configs: [RangeCheckChipConfig; N_CURRENCIES], + range_u16: Column, + instance: Column, +} + +impl CircuitConfig + for RangeCheckConfig +{ + fn configure( + meta: &mut ConstraintSystem, + username: Column, + balances: [Column; N_CURRENCIES], + instance: Column, + ) -> Self { + let range_u16 = meta.fixed_column(); + + meta.enable_constant(range_u16); + + meta.annotate_lookup_any_column(range_u16, || "LOOKUP_MAXBITS_RANGE"); + + let range_check_selector = meta.complex_selector(); + + // Create an empty array of range check configs + let mut range_check_configs = Vec::with_capacity(N_CURRENCIES); + + for balance_column in balances.iter() { + let z = *balance_column; + // Create 4 advice columns for each range check chip + let zs = [(); 4].map(|_| meta.advice_column()); + + for column in &zs { + meta.enable_equality(*column); + } + + let range_check_config = + RangeCheckU64Chip::configure(meta, z, zs, range_u16, range_check_selector); + + range_check_configs.push(range_check_config); + } + + Self { + username, + balances, + range_check_configs: range_check_configs.try_into().unwrap(), + range_u16, + instance, + } + } + + fn get_username(&self) -> Column { + self.username + } + + fn get_balances(&self) -> [Column; N_CURRENCIES] { + self.balances + } + + fn get_instance(&self) -> Column { + self.instance + } + + fn initialize_range_check_chips(&self) -> Vec { + self.range_check_configs + .iter() + .map(|config| RangeCheckU64Chip::construct(*config)) + .collect::>() + } + + fn load_lookup_table(&self, mut layouter: impl Layouter) -> Result<(), Error> { + // Load lookup table for range check u64 chip + let range = 1 << 16; + + layouter.assign_region( + || "load range check table of 16 bits".to_string(), + |mut region| { + for i in 0..range { + region.assign_fixed( + || "assign cell in fixed column", + self.range_u16, + i, + || Value::known(Fp::from(i as u64)), + )?; + } + Ok(()) + }, + )?; + + Ok(()) + } + + /// Constrains the last decompositions of each balance to a zero value (necessary for range checks) + fn constrain_decompositions( + &self, + last_decompositions: Vec>, + layouter: &mut impl Layouter, + ) -> Result<(), Error> { + for last_decomposition in last_decompositions { + layouter.constrain_instance(last_decomposition.cell(), self.instance, 0)?; + } + Ok(()) + } +} diff --git a/prover/src/circuits/mod.rs b/prover/src/circuits/mod.rs index 1f44e02c..283abc4a 100644 --- a/prover/src/circuits/mod.rs +++ b/prover/src/circuits/mod.rs @@ -1,3 +1,4 @@ +pub mod config; pub mod summa_circuit; #[cfg(test)] mod tests; diff --git a/prover/src/circuits/summa_circuit.rs b/prover/src/circuits/summa_circuit.rs index ea6e8bdb..e8762d45 100644 --- a/prover/src/circuits/summa_circuit.rs +++ b/prover/src/circuits/summa_circuit.rs @@ -1,274 +1,129 @@ -pub mod summa_hyperplonk { - - use crate::chips::range::range_check::{RangeCheckU64Chip, RangeCheckU64Config}; - use crate::entry::Entry; - use crate::utils::big_uint_to_fp; - use halo2_proofs::arithmetic::Field; - use halo2_proofs::halo2curves::bn256::Fr as Fp; - use halo2_proofs::plonk::{Expression, Selector}; - use halo2_proofs::poly::Rotation; - use halo2_proofs::{ - circuit::{Layouter, SimpleFloorPlanner, Value}, - plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, - }; - use num_bigint::BigUint; - use plonkish_backend::frontend::halo2::CircuitExt; - use rand::RngCore; - - #[derive(Clone)] - pub struct SummaConfig { - username: Column, - balances: [Column; N_CURRENCIES], - running_sums: [Column; N_CURRENCIES], - range_check_configs: [RangeCheckU64Config; N_CURRENCIES], - range_u16: Column, - instance: Column, - selector: Selector, - } - - impl SummaConfig { - fn configure(meta: &mut ConstraintSystem, running_sum_selector: &Selector) -> Self { - let username = meta.advice_column(); - - let balances = [(); N_CURRENCIES].map(|_| meta.advice_column()); - let running_sums = [(); N_CURRENCIES].map(|_| meta.advice_column()); - - for column in &running_sums { - meta.enable_equality(*column); - } - - let range_u16 = meta.fixed_column(); - - meta.enable_constant(range_u16); - - meta.annotate_lookup_any_column(range_u16, || "LOOKUP_MAXBITS_RANGE"); - - // Create an empty array of range check configs - let mut range_check_configs = Vec::with_capacity(N_CURRENCIES); - - let instance = meta.instance_column(); - meta.enable_equality(instance); - - for item in balances.iter().take(N_CURRENCIES) { - let z = *item; - // Create 4 advice columns for each range check chip - let zs = [(); 4].map(|_| meta.advice_column()); - - for column in &zs { - meta.enable_equality(*column); - } - - let range_check_config = RangeCheckU64Chip::configure(meta, z, zs, range_u16); - - range_check_configs.push(range_check_config); - } - - meta.create_gate("Running sum gate", |meta| { - let mut running_sum_constraint = vec![]; - let s = meta.query_selector(*running_sum_selector); - for j in 0..N_CURRENCIES { - let prev_running_sum = meta.query_advice(running_sums[j], Rotation::prev()); - let curr_running_sum = meta.query_advice(running_sums[j], Rotation::cur()); - let curr_balance = meta.query_advice(balances[j], Rotation::cur()); - running_sum_constraint.push( - s.clone() - * (curr_running_sum.clone() - prev_running_sum - curr_balance.clone()) - + (Expression::Constant(Fp::ONE) - s.clone()) - * (curr_running_sum - curr_balance), - ) - } - running_sum_constraint - }); +use std::marker::PhantomData; + +use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner}, + plonk::{Circuit, ConstraintSystem, Error}, + poly::Rotation, +}; + +use crate::{entry::Entry, utils::big_uint_to_fp}; + +use halo2_proofs::arithmetic::Field; +use halo2_proofs::halo2curves::bn256::Fr as Fp; +use plonkish_backend::frontend::halo2::CircuitExt; +use rand::RngCore; + +use super::config::circuit_config::CircuitConfig; + +#[derive(Clone, Default)] +pub struct SummaHyperplonk< + const N_USERS: usize, + const N_CURRENCIES: usize, + CONFIG: CircuitConfig, +> { + pub entries: Vec>, + pub grand_total: Vec, + _marker: PhantomData, +} - Self { - username, - balances, - running_sums, - range_check_configs: range_check_configs.try_into().unwrap(), - range_u16, - instance, - selector: *running_sum_selector, +impl< + const N_USERS: usize, + const N_CURRENCIES: usize, + CONFIG: CircuitConfig, + > SummaHyperplonk +{ + pub fn init(user_entries: Vec>) -> Self { + let mut grand_total = vec![Fp::ZERO; N_CURRENCIES]; + for entry in user_entries.iter() { + for (i, balance) in entry.balances().iter().enumerate() { + grand_total[i] += big_uint_to_fp::(balance); } } - } - #[derive(Clone, Default)] - pub struct SummaHyperplonk { - pub entries: Vec>, - pub grand_total: Vec, - } - - impl SummaHyperplonk { - pub fn init(user_entries: Vec>) -> Self { - let mut grand_total = vec![BigUint::from(0u64); N_CURRENCIES]; - for entry in user_entries.iter() { - for (i, balance) in entry.balances().iter().enumerate() { - grand_total[i] += balance; - } - } - - Self { - entries: user_entries, - grand_total, - } + Self { + entries: user_entries, + grand_total, + _marker: PhantomData, } } - impl Circuit - for SummaHyperplonk - { - type Config = SummaConfig; - type FloorPlanner = SimpleFloorPlanner; + /// Initialize the circuit with an invalid grand total + /// (for testing purposes only). + #[cfg(test)] + pub fn init_invalid_grand_total(user_entries: Vec>) -> Self { + use plonkish_backend::util::test::seeded_std_rng; - fn without_witnesses(&self) -> Self { - unimplemented!() + let mut grand_total = vec![Fp::ZERO; N_CURRENCIES]; + for i in 0..N_CURRENCIES { + grand_total[i] = Fp::random(seeded_std_rng()); } - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - meta.set_minimum_degree(4); - let running_sum_selector = &meta.complex_selector(); - SummaConfig::configure(meta, running_sum_selector) + Self { + entries: user_entries, + grand_total, + _marker: PhantomData, } + } +} - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - // Assign entries - let (assigned_balances, last_running_sums) = layouter - .assign_region( - || "assign user entries", - |mut region| { - // create a bidimensional vector to store the assigned balances. The first dimension is N_USERS, the second dimension is N_CURRENCIES - let mut assigned_balances = vec![]; - - let mut running_sum_values = vec![vec![]]; - let mut last_assigned_running_sums = vec![]; - - for i in 0..N_USERS { - running_sum_values.push(vec![]); - - region.assign_advice( - || format!("username {}", i), - config.username, - i, - || { - Value::known(big_uint_to_fp::( - self.entries[i].username_as_big_uint(), - )) - }, - )?; - - let mut assigned_balances_row = vec![]; - - for (j, balance) in self.entries[i].balances().iter().enumerate() { - let balance_value: Value = - Value::known(big_uint_to_fp(balance)); - - let assigned_balance = region.assign_advice( - || format!("balance {}", j), - config.balances[j], - i, - || balance_value, - )?; - - assigned_balances_row.push(assigned_balance); - - let prev_running_sum_value = if i == 0 { - Value::known(Fp::ZERO) - } else { - running_sum_values[i - 1][j] - }; - - running_sum_values[i].push(prev_running_sum_value + balance_value); - - let assigned_running_sum = region.assign_advice( - || format!("running sum {}", j), - config.running_sums[j], - i, - || running_sum_values[i][j], - )?; - - if i == N_USERS - 1 { - last_assigned_running_sums.push(assigned_running_sum); - } - } - - if i > 0 { - config.selector.enable(&mut region, i)?; - } - - assigned_balances.push(assigned_balances_row); - } - - Ok((assigned_balances, last_assigned_running_sums)) - }, - ) - .unwrap(); - - // Initialize the range check chips - let range_check_chips = config - .range_check_configs - .iter() - .map(|config| RangeCheckU64Chip::construct(*config)) - .collect::>(); - - // Load lookup table for range check u64 chip - let range = 1 << 16; +impl< + const N_USERS: usize, + const N_CURRENCIES: usize, + CONFIG: CircuitConfig, + > Circuit for SummaHyperplonk +{ + type Config = CONFIG; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + unimplemented!() + } - layouter.assign_region( - || "load range check table of 16 bits".to_string(), - |mut region| { - for i in 0..range { - region.assign_fixed( - || "assign cell in fixed column", - config.range_u16, - i, - || Value::known(Fp::from(i as u64)), - )?; - } - Ok(()) - }, - )?; + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + meta.set_minimum_degree(4); - // Perform range check on the assigned balances - for (i, user_balances) in assigned_balances.iter().enumerate().take(N_USERS) { - for (j, balance) in user_balances.iter().enumerate() { - let mut zs = Vec::with_capacity(4); + let username = meta.advice_column(); - layouter.assign_region( - || format!("Perform range check on balance {} of user {}", j, i), - |mut region| { - range_check_chips[j].assign(&mut region, &mut zs, balance)?; - Ok(()) - }, - )?; + let balances = [(); N_CURRENCIES].map(|_| meta.advice_column()); + for column in &balances { + meta.enable_equality(*column); + } - layouter.constrain_instance(zs[3].cell(), config.instance, 0)?; - } + meta.create_gate("Balance sumcheck gate", |meta| { + let mut nonzero_constraint = vec![]; + for balance in balances { + let current_balance = meta.query_advice(balance, Rotation::cur()); + nonzero_constraint.push(current_balance.clone()); } + nonzero_constraint + }); - for (i, last_running_sum) in last_running_sums.iter().enumerate().take(N_CURRENCIES) { - layouter.constrain_instance(last_running_sum.cell(), config.instance, 1 + i)?; - } + let instance = meta.instance_column(); + meta.enable_equality(instance); - Ok(()) - } + CONFIG::configure(meta, username, balances, instance) } - impl CircuitExt - for SummaHyperplonk - { - fn rand(_: usize, _: impl RngCore) -> Self { - unimplemented!() - } + fn synthesize(&self, config: Self::Config, layouter: impl Layouter) -> Result<(), Error> { + CONFIG::synthesize(&config, layouter, &self.entries, &self.grand_total) + } +} - fn instances(&self) -> Vec> { - // The last decomposition of each range check chip should be zero - let mut instances = vec![Fp::ZERO]; - instances.extend(self.grand_total.iter().map(big_uint_to_fp::)); - vec![instances] - } +impl< + const N_USERS: usize, + const N_CURRENCIES: usize, + CONFIG: CircuitConfig, + > CircuitExt for SummaHyperplonk +{ + fn rand(_: usize, _: impl RngCore) -> Self { + unimplemented!() + } + + fn instances(&self) -> Vec> { + // The 1st element is zero because the last decomposition of each range check chip should be zero + vec![vec![Fp::ZERO] + .into_iter() + .chain(self.grand_total.iter().map(|x| x.neg())) + .collect::>()] } } diff --git a/prover/src/circuits/tests.rs b/prover/src/circuits/tests.rs index 58c518e5..88bbb4af 100644 --- a/prover/src/circuits/tests.rs +++ b/prover/src/circuits/tests.rs @@ -1,7 +1,8 @@ use halo2_proofs::arithmetic::Field; +use plonkish_backend::Error::InvalidSnark; use plonkish_backend::{ backend::{hyperplonk::HyperPlonk, PlonkishBackend, PlonkishCircuit}, - frontend::halo2::Halo2Circuit, + frontend::halo2::{CircuitExt, Halo2Circuit}, halo2_curves::bn256::{Bn256, Fr as Fp}, pcs::{multilinear::MultilinearKzg, Evaluation, PolynomialCommitmentScheme}, util::{ @@ -12,14 +13,16 @@ use plonkish_backend::{ }, Error::InvalidSumcheck, }; - use rand::{ rngs::{OsRng, StdRng}, CryptoRng, Rng, RngCore, SeedableRng, }; use crate::{ - circuits::summa_circuit::summa_hyperplonk::SummaHyperplonk, + circuits::{ + config::{no_range_check_config::NoRangeCheckConfig, range_check_config::RangeCheckConfig}, + summa_circuit::SummaHyperplonk, + }, utils::{ big_uint_to_fp, fp_to_big_uint, generate_dummy_entries, uni_to_multivar_binary_index, MultilinearAsUnivariate, @@ -27,23 +30,47 @@ use crate::{ }; const K: u32 = 17; const N_CURRENCIES: usize = 2; -const N_USERS: usize = 1 << 16; +// One row is reserved for the grand total. +// TODO find out what occupies one extra row +const N_USERS: usize = (1 << K) - 2; pub fn seeded_std_rng() -> impl RngCore + CryptoRng { StdRng::seed_from_u64(OsRng.next_u64()) } #[test] -fn test_summa_hyperplonk() { +fn test_summa_hyperplonk_e2e() { type ProvingBackend = HyperPlonk>; let entries = generate_dummy_entries::().unwrap(); - let circuit = SummaHyperplonk::::init(entries.to_vec()); + + let halo2_circuit = + SummaHyperplonk::>::init( + entries.to_vec(), + ); + + let neg_grand_total = halo2_circuit + .grand_total + .iter() + .fold(Fp::ZERO, |acc, f| acc + f) + .neg(); + + // We're putting the negated grand total at the end of each balance column, + // so the sumcheck over such balance column would yield zero (given the special gate, + // see the circuit). + assert!( + neg_grand_total + == halo2_circuit.instances()[0] + .iter() + .fold(Fp::ZERO, |acc, instance| { acc + instance }) + ); + let num_vars = K; let circuit_fn = |num_vars| { - let circuit = Halo2Circuit::>::new::< - ProvingBackend, - >(num_vars, circuit.clone()); + let circuit = Halo2Circuit::< + Fp, + SummaHyperplonk>, + >::new::(num_vars, halo2_circuit.clone()); (circuit.circuit_info().unwrap(), circuit) }; @@ -241,6 +268,71 @@ fn test_summa_hyperplonk() { .unwrap(); } +/// Test the sumcheck failure case +/// The grand total is set to a random value, which will cause the sumcheck to fail +/// because the sum of all valid balances is not equal to the negated random grand total +#[test] +fn test_sumcheck_fail() { + type ProvingBackend = HyperPlonk>; + let entries = generate_dummy_entries::().unwrap(); + + let halo2_circuit = SummaHyperplonk::< + N_USERS, + N_CURRENCIES, + NoRangeCheckConfig, + >::init_invalid_grand_total(entries.to_vec()); + + let num_vars = K; + + let circuit_fn = |num_vars| { + let circuit = Halo2Circuit::< + Fp, + SummaHyperplonk>, + >::new::(num_vars, halo2_circuit.clone()); + (circuit.circuit_info().unwrap(), circuit) + }; + + let (circuit_info, circuit) = circuit_fn(num_vars as usize); + let instances = circuit.instances(); + + let param = ProvingBackend::setup(&circuit_info, seeded_std_rng()).unwrap(); + + let (prover_parameters, verifier_parameters) = + ProvingBackend::preprocess(¶m, &circuit_info).unwrap(); + + let (_, proof_transcript) = { + let mut proof_transcript = Keccak256Transcript::new(()); + + let witness_polys = ProvingBackend::prove( + &prover_parameters, + &circuit, + &mut proof_transcript, + seeded_std_rng(), + ) + .unwrap(); + (witness_polys, proof_transcript) + }; + + let proof = proof_transcript.into_proof(); + + let mut transcript; + let result: Result<(), plonkish_backend::Error> = { + transcript = Keccak256Transcript::from_proof((), proof.as_slice()); + ProvingBackend::verify( + &verifier_parameters, + instances, + &mut transcript, + seeded_std_rng(), + ) + }; + assert_eq!( + result, + Err(InvalidSnark( + "Unmatched between sum_check output and query evaluation".to_string() + )) + ); +} + #[cfg(feature = "dev-graph")] #[test] fn print_univariate_grand_sum_circuit() {