From 7919da29ff5b20d868c8ac4f74dcbab3fcc5739a Mon Sep 17 00:00:00 2001 From: even <35983442+10to4@users.noreply.github.com> Date: Mon, 13 Jan 2025 16:39:23 +0800 Subject: [PATCH] Feat/skip fixed commit of range table (#797) Work for #789 --------- Co-authored-by: sm.wu --- ceno_zkvm/src/chip_handler/general.rs | 4 +- ceno_zkvm/src/circuit_builder.rs | 13 ++--- ceno_zkvm/src/scheme/mock_prover.rs | 10 ++-- ceno_zkvm/src/scheme/prover.rs | 9 ++-- ceno_zkvm/src/scheme/verifier.rs | 31 ++++++++---- ceno_zkvm/src/stats.rs | 2 +- ceno_zkvm/src/tables/ops/ops_impl.rs | 13 ++++- ceno_zkvm/src/tables/program.rs | 7 ++- ceno_zkvm/src/tables/range/range_circuit.rs | 18 ++++--- ceno_zkvm/src/tables/range/range_impl.rs | 54 ++++++++++----------- 10 files changed, 92 insertions(+), 69 deletions(-) diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 4fe85f3d9..6692a8cb2 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -95,7 +95,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub fn lk_table_record( &mut self, name_fn: N, - table_len: usize, + table_spec: SetTableSpec, rom_type: ROMType, record: Vec>, multiplicity: Expression, @@ -105,7 +105,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { N: FnOnce() -> NR, { self.cs - .lk_table_record(name_fn, table_len, rom_type, record, multiplicity) + .lk_table_record(name_fn, table_spec, rom_type, record, multiplicity) } pub fn r_table_record( diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 5c3e2dc2c..6fc2b6c43 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -1,4 +1,3 @@ -use ceno_emul::Addr; use itertools::{Itertools, chain}; use std::{collections::HashMap, iter::once, marker::PhantomData}; @@ -56,13 +55,7 @@ impl NameSpace { pub struct LogupTableExpression { pub multiplicity: Expression, pub values: Expression, - pub table_len: usize, -} - -#[derive(Clone, Debug)] -pub struct DynamicAddr { - pub addr_witin_id: usize, - pub offset: Addr, + pub table_spec: SetTableSpec, } #[derive(Clone, Debug)] @@ -297,7 +290,7 @@ impl ConstraintSystem { pub fn lk_table_record( &mut self, name_fn: N, - table_len: usize, + table_spec: SetTableSpec, rom_type: ROMType, record: Vec>, multiplicity: Expression, @@ -321,7 +314,7 @@ impl ConstraintSystem { self.lk_table_expressions.push(LogupTableExpression { values: rlc_record, multiplicity, - table_len, + table_spec, }); let path = self.ns.compute_path(name_fn().into()); self.lk_expressions_namespace_map.push(path); diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index e63ae4ee7..9ea7791c3 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -791,6 +791,7 @@ Hints: ); let mut wit_mles = HashMap::new(); + let mut structural_wit_mles = HashMap::new(); let mut fixed_mles = HashMap::new(); let mut num_instances = HashMap::new(); @@ -815,15 +816,17 @@ Hints: if witness.num_instances() == 0 { wit_mles.insert(circuit_name.clone(), vec![]); + structural_wit_mles.insert(circuit_name.clone(), vec![]); fixed_mles.insert(circuit_name.clone(), vec![]); num_instances.insert(circuit_name.clone(), num_rows); continue; } - let witness = witness + let mut witness = witness .into_mles() .into_iter() .map(|w| w.into()) .collect_vec(); + let structural_witness = witness.split_off(cs.num_witin as usize); let fixed: Vec<_> = fixed_trace .circuit_fixed_traces .remove(circuit_name) @@ -876,7 +879,7 @@ Hints: let lk_table = wit_infer_by_expr( &fixed, &witness, - &[], + &structural_witness, &pi_mles, &challenges, &expr.values, @@ -887,7 +890,7 @@ Hints: let multiplicity = wit_infer_by_expr( &fixed, &witness, - &[], + &structural_witness, &pi_mles, &challenges, &expr.multiplicity, @@ -905,6 +908,7 @@ Hints: } } wit_mles.insert(circuit_name.clone(), witness); + structural_wit_mles.insert(circuit_name.clone(), structural_witness); fixed_mles.insert(circuit_name.clone(), fixed); num_instances.insert(circuit_name.clone(), num_rows); } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index a7042e5ca..06eda546b 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -964,12 +964,9 @@ impl> ZKVMProver { ); exit_span!(tower_span); - // same point sumcheck is optional when all witin + fixed are in same num_vars - let is_skip_same_point_sumcheck = witnesses - .iter() - .chain(fixed.iter()) - .map(|v| v.num_vars()) - .all_equal(); + // In table proof, we always skip same point sumcheck for now + // as tower sumcheck batch product argument/logup in same length + let is_skip_same_point_sumcheck = true; let (input_open_point, same_r_sumcheck_proofs, rw_in_evals, lk_in_evals) = if is_skip_same_point_sumcheck { diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 03ef32e64..e9cc8fa6e 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -522,6 +522,7 @@ impl> ZKVMVerifier let tower_proofs = &proof.tower_proof; let expected_rounds = cs + // only iterate r set, as read/write set round should match .r_table_expressions .iter() .flat_map(|r| { @@ -538,13 +539,24 @@ impl> ZKVMVerifier .max() .unwrap() }); - [num_vars, num_vars] + [num_vars, num_vars] // format: [read_round, write_round] }) - .chain( - cs.lk_table_expressions - .iter() - .map(|l| ceil_log2(l.table_len)), - ) + .chain(cs.lk_table_expressions.iter().map(|l| { + // iterate through structural witins and collect max round. + let num_vars = l.table_spec.len.map(ceil_log2).unwrap_or_else(|| { + l.table_spec + .structural_witins + .iter() + .map(|StructuralWitIn { id, max_len, .. }| { + let hint_num_vars = proof.rw_hints_num_vars[*id as usize]; + assert!((1 << hint_num_vars) <= *max_len); + hint_num_vars + }) + .max() + .unwrap() + }); + num_vars + })) .collect_vec(); for var in proof.rw_hints_num_vars.iter() { @@ -693,9 +705,10 @@ impl> ZKVMVerifier let structural_witnesses = cs .r_table_expressions .iter() - .flat_map(|set_table_expression| { - set_table_expression - .table_spec + .map(|r| &r.table_spec) + .chain(cs.lk_table_expressions.iter().map(|r| &r.table_spec)) + .flat_map(|table_spec| { + table_spec .structural_witins .iter() .map( diff --git a/ceno_zkvm/src/stats.rs b/ceno_zkvm/src/stats.rs index 7643d0c12..89271277a 100644 --- a/ceno_zkvm/src/stats.rs +++ b/ceno_zkvm/src/stats.rs @@ -107,7 +107,7 @@ impl CircuitStats { }) } else { let table_len = if !system.lk_table_expressions.is_empty() { - system.lk_table_expressions[0].table_len + system.lk_table_expressions[0].table_spec.len.unwrap_or(0) } else { 0 }; diff --git a/ceno_zkvm/src/tables/ops/ops_impl.rs b/ceno_zkvm/src/tables/ops/ops_impl.rs index efe7c4de9..8ece6a077 100644 --- a/ceno_zkvm/src/tables/ops/ops_impl.rs +++ b/ceno_zkvm/src/tables/ops/ops_impl.rs @@ -7,7 +7,7 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterato use std::collections::HashMap; use crate::{ - circuit_builder::CircuitBuilder, + circuit_builder::{CircuitBuilder, SetTableSpec}, error::ZKVMError, expression::{Expression, Fixed, ToExpr, WitIn}, instructions::InstancePaddingStrategy, @@ -38,7 +38,16 @@ impl OpTableConfig { let record_exprs = abc.into_iter().map(|f| Expression::Fixed(f)).collect_vec(); - cb.lk_table_record(|| "record", table_len, rom_type, record_exprs, mlt.expr())?; + cb.lk_table_record( + || "record", + SetTableSpec { + len: Some(table_len), + structural_witins: vec![], + }, + rom_type, + record_exprs, + mlt.expr(), + )?; Ok(Self { abc, mlt }) } diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 5a43af187..5f3773878 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, marker::PhantomData}; use crate::{ - circuit_builder::CircuitBuilder, + circuit_builder::{CircuitBuilder, SetTableSpec}, error::ZKVMError, expression::{Expression, Fixed, ToExpr, WitIn}, instructions::InstancePaddingStrategy, @@ -115,7 +115,10 @@ impl TableCircuit for ProgramTableCircuit { cb.lk_table_record( || "prog table", - cb.params.program_size.next_power_of_two(), + SetTableSpec { + len: Some(cb.params.program_size.next_power_of_two()), + structural_witins: vec![], + }, ROMType::Instruction, record_exprs, mlt.expr(), diff --git a/ceno_zkvm/src/tables/range/range_circuit.rs b/ceno_zkvm/src/tables/range/range_circuit.rs index 83d8da017..2ecb6c6f3 100644 --- a/ceno_zkvm/src/tables/range/range_circuit.rs +++ b/ceno_zkvm/src/tables/range/range_circuit.rs @@ -5,8 +5,8 @@ use super::range_impl::RangeTableConfig; use std::{collections::HashMap, marker::PhantomData}; use crate::{ - circuit_builder::CircuitBuilder, error::ZKVMError, structs::ROMType, tables::TableCircuit, - witness::RowMajorMatrix, + circuit_builder::CircuitBuilder, error::ZKVMError, instructions::InstancePaddingStrategy, + structs::ROMType, tables::TableCircuit, witness::RowMajorMatrix, }; use ff_ext::ExtensionField; @@ -40,11 +40,11 @@ impl TableCircuit for RangeTableCircuit } fn generate_fixed_traces( - config: &RangeTableConfig, - num_fixed: usize, + _config: &RangeTableConfig, + _num_fixed: usize, _input: &(), ) -> RowMajorMatrix { - config.generate_fixed_traces(num_fixed, RANGE::content()) + RowMajorMatrix::::new(0, 0, InstancePaddingStrategy::Default) } fn assign_instances( @@ -55,6 +55,12 @@ impl TableCircuit for RangeTableCircuit _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[RANGE::ROM_TYPE as usize]; - config.assign_instances(num_witin, num_structural_witin, multiplicity, RANGE::len()) + config.assign_instances( + num_witin, + num_structural_witin, + multiplicity, + RANGE::content(), + RANGE::len(), + ) } } diff --git a/ceno_zkvm/src/tables/range/range_impl.rs b/ceno_zkvm/src/tables/range/range_impl.rs index 6e7ebaee4..2832892dd 100644 --- a/ceno_zkvm/src/tables/range/range_impl.rs +++ b/ceno_zkvm/src/tables/range/range_impl.rs @@ -6,19 +6,19 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterato use std::collections::HashMap; use crate::{ - circuit_builder::CircuitBuilder, + circuit_builder::{CircuitBuilder, SetTableSpec}, error::ZKVMError, - expression::{Expression, Fixed, ToExpr, WitIn}, + expression::{StructuralWitIn, ToExpr, WitIn}, instructions::InstancePaddingStrategy, scheme::constants::MIN_PAR_SIZE, - set_fixed_val, set_val, + set_val, structs::ROMType, witness::RowMajorMatrix, }; #[derive(Clone, Debug)] pub struct RangeTableConfig { - fixed: Fixed, + range: StructuralWitIn, mlt: WitIn, } @@ -28,33 +28,23 @@ impl RangeTableConfig { rom_type: ROMType, table_len: usize, ) -> Result { - let fixed = cb.create_fixed(|| "fixed")?; + let range = cb.create_structural_witin(|| "structural range witin", table_len, 0, 1); let mlt = cb.create_witin(|| "mlt"); - let record_exprs = vec![Expression::Fixed(fixed)]; + let record_exprs = vec![range.expr()]; - cb.lk_table_record(|| "record", table_len, rom_type, record_exprs, mlt.expr())?; + cb.lk_table_record( + || "record", + SetTableSpec { + len: Some(table_len), + structural_witins: vec![range], + }, + rom_type, + record_exprs, + mlt.expr(), + )?; - Ok(Self { fixed, mlt }) - } - - pub fn generate_fixed_traces( - &self, - num_fixed: usize, - content: Vec, - ) -> RowMajorMatrix { - let mut fixed = - RowMajorMatrix::::new(content.len(), num_fixed, InstancePaddingStrategy::Default); - - fixed - .par_iter_mut() - .with_min_len(MIN_PAR_SIZE) - .zip(content.into_par_iter()) - .for_each(|(row, i)| { - set_fixed_val!(row, self.fixed, F::from(i)); - }); - - fixed + Ok(Self { range, mlt }) } pub fn assign_instances( @@ -62,6 +52,7 @@ impl RangeTableConfig { num_witin: usize, num_structural_witin: usize, multiplicity: &HashMap, + content: Vec, length: usize, ) -> Result, ZKVMError> { let mut witness = RowMajorMatrix::::new( @@ -75,12 +66,19 @@ impl RangeTableConfig { mlts[*idx as usize] = *mlt; } + let offset_range = StructuralWitIn { + id: self.range.id + (num_witin as u16), + ..self.range + }; + witness .par_iter_mut() .with_min_len(MIN_PAR_SIZE) .zip(mlts.into_par_iter()) - .for_each(|(row, mlt)| { + .zip(content.into_par_iter()) + .for_each(|((row, mlt), i)| { set_val!(row, self.mlt, F::from(mlt as u64)); + set_val!(row, offset_range, F::from(i)); }); Ok(witness)