Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/skip fixed commit of range table #797

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
4 changes: 2 additions & 2 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
pub fn lk_table_record<NR, N>(
&mut self,
name_fn: N,
table_len: usize,
table_spec: SetTableSpec,
rom_type: ROMType,
record: Vec<Expression<E>>,
multiplicity: Expression<E>,
Expand All @@ -105,7 +105,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
N: FnOnce() -> NR,
{
self.cs
.lk_table_record(name_fn, table_len, rom_type, record, multiplicity)
.lk_table_record(name_fn, table_spec, rom_type, record, multiplicity)
}

pub fn r_table_record<NR, N>(
Expand Down
13 changes: 3 additions & 10 deletions ceno_zkvm/src/circuit_builder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use ceno_emul::Addr;
use itertools::{Itertools, chain};
use std::{collections::HashMap, iter::once, marker::PhantomData};

Expand Down Expand Up @@ -56,13 +55,7 @@ impl NameSpace {
pub struct LogupTableExpression<E: ExtensionField> {
pub multiplicity: Expression<E>,
pub values: Expression<E>,
pub table_len: usize,
}

#[derive(Clone, Debug)]
pub struct DynamicAddr {
pub addr_witin_id: usize,
pub offset: Addr,
pub table_spec: SetTableSpec,
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -297,7 +290,7 @@ impl<E: ExtensionField> ConstraintSystem<E> {
pub fn lk_table_record<NR, N>(
&mut self,
name_fn: N,
table_len: usize,
table_spec: SetTableSpec,
rom_type: ROMType,
record: Vec<Expression<E>>,
multiplicity: Expression<E>,
Expand All @@ -321,7 +314,7 @@ impl<E: ExtensionField> ConstraintSystem<E> {
self.lk_table_expressions.push(LogupTableExpression {
values: rlc_record,
multiplicity,
table_len,
table_spec,
});
let path = self.ns.compute_path(name_fn().into());
self.lk_expressions_namespace_map.push(path);
Expand Down
10 changes: 7 additions & 3 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,7 @@ Hints:
);

let mut wit_mles = HashMap::new();
let mut structural_wit_mles = HashMap::new();
let mut fixed_mles = HashMap::new();
let mut num_instances = HashMap::new();

Expand All @@ -815,15 +816,17 @@ Hints:

if witness.num_instances() == 0 {
wit_mles.insert(circuit_name.clone(), vec![]);
structural_wit_mles.insert(circuit_name.clone(), vec![]);
fixed_mles.insert(circuit_name.clone(), vec![]);
num_instances.insert(circuit_name.clone(), num_rows);
continue;
}
let witness = witness
let mut witness = witness
.into_mles()
.into_iter()
.map(|w| w.into())
.collect_vec();
let structural_witness = witness.split_off(cs.num_witin as usize);
let fixed: Vec<_> = fixed_trace
.circuit_fixed_traces
.remove(circuit_name)
Expand Down Expand Up @@ -876,7 +879,7 @@ Hints:
let lk_table = wit_infer_by_expr(
&fixed,
&witness,
&[],
&structural_witness,
&pi_mles,
&challenges,
&expr.values,
Expand All @@ -887,7 +890,7 @@ Hints:
let multiplicity = wit_infer_by_expr(
&fixed,
&witness,
&[],
&structural_witness,
&pi_mles,
&challenges,
&expr.multiplicity,
Expand All @@ -905,6 +908,7 @@ Hints:
}
}
wit_mles.insert(circuit_name.clone(), witness);
structural_wit_mles.insert(circuit_name.clone(), structural_witness);
fixed_mles.insert(circuit_name.clone(), fixed);
num_instances.insert(circuit_name.clone(), num_rows);
}
Expand Down
9 changes: 3 additions & 6 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -964,12 +964,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
);
exit_span!(tower_span);

// same point sumcheck is optional when all witin + fixed are in same num_vars
let is_skip_same_point_sumcheck = witnesses
.iter()
.chain(fixed.iter())
.map(|v| v.num_vars())
.all_equal();
// In table proof, we always skip same point sumcheck for now
// as tower sumcheck batch product argument/logup in same length
let is_skip_same_point_sumcheck = true;

let (input_open_point, same_r_sumcheck_proofs, rw_in_evals, lk_in_evals) =
if is_skip_same_point_sumcheck {
Expand Down
31 changes: 22 additions & 9 deletions ceno_zkvm/src/scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
let tower_proofs = &proof.tower_proof;

let expected_rounds = cs
// only iterate r set, as read/write set round should match
.r_table_expressions
.iter()
.flat_map(|r| {
Expand All @@ -538,13 +539,24 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
.max()
.unwrap()
});
[num_vars, num_vars]
[num_vars, num_vars] // format: [read_round, write_round]
})
.chain(
cs.lk_table_expressions
.iter()
.map(|l| ceil_log2(l.table_len)),
)
.chain(cs.lk_table_expressions.iter().map(|l| {
// iterate through structural witins and collect max round.
let num_vars = l.table_spec.len.map(ceil_log2).unwrap_or_else(|| {
l.table_spec
.structural_witins
.iter()
.map(|StructuralWitIn { id, max_len, .. }| {
let hint_num_vars = proof.rw_hints_num_vars[*id as usize];
assert!((1 << hint_num_vars) <= *max_len);
hint_num_vars
})
.max()
.unwrap()
});
num_vars
}))
.collect_vec();

for var in proof.rw_hints_num_vars.iter() {
Expand Down Expand Up @@ -693,9 +705,10 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
let structural_witnesses = cs
.r_table_expressions
.iter()
.flat_map(|set_table_expression| {
set_table_expression
.table_spec
.map(|r| &r.table_spec)
.chain(cs.lk_table_expressions.iter().map(|r| &r.table_spec))
.flat_map(|table_spec| {
table_spec
.structural_witins
.iter()
.map(
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl CircuitStats {
})
} else {
let table_len = if !system.lk_table_expressions.is_empty() {
system.lk_table_expressions[0].table_len
system.lk_table_expressions[0].table_spec.len.unwrap_or(0)
} else {
0
};
Expand Down
13 changes: 11 additions & 2 deletions ceno_zkvm/src/tables/ops/ops_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterato
use std::collections::HashMap;

use crate::{
circuit_builder::CircuitBuilder,
circuit_builder::{CircuitBuilder, SetTableSpec},
error::ZKVMError,
expression::{Expression, Fixed, ToExpr, WitIn},
instructions::InstancePaddingStrategy,
Expand Down Expand Up @@ -38,7 +38,16 @@ impl OpTableConfig {

let record_exprs = abc.into_iter().map(|f| Expression::Fixed(f)).collect_vec();

cb.lk_table_record(|| "record", table_len, rom_type, record_exprs, mlt.expr())?;
cb.lk_table_record(
|| "record",
SetTableSpec {
len: Some(table_len),
structural_witins: vec![],
},
rom_type,
record_exprs,
mlt.expr(),
)?;

Ok(Self { abc, mlt })
}
Expand Down
7 changes: 5 additions & 2 deletions ceno_zkvm/src/tables/program.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{collections::HashMap, marker::PhantomData};

use crate::{
circuit_builder::CircuitBuilder,
circuit_builder::{CircuitBuilder, SetTableSpec},
error::ZKVMError,
expression::{Expression, Fixed, ToExpr, WitIn},
instructions::InstancePaddingStrategy,
Expand Down Expand Up @@ -115,7 +115,10 @@ impl<E: ExtensionField> TableCircuit<E> for ProgramTableCircuit<E> {

cb.lk_table_record(
|| "prog table",
cb.params.program_size.next_power_of_two(),
SetTableSpec {
len: Some(cb.params.program_size.next_power_of_two()),
structural_witins: vec![],
},
ROMType::Instruction,
record_exprs,
mlt.expr(),
Expand Down
18 changes: 12 additions & 6 deletions ceno_zkvm/src/tables/range/range_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use super::range_impl::RangeTableConfig;
use std::{collections::HashMap, marker::PhantomData};

use crate::{
circuit_builder::CircuitBuilder, error::ZKVMError, structs::ROMType, tables::TableCircuit,
witness::RowMajorMatrix,
circuit_builder::CircuitBuilder, error::ZKVMError, instructions::InstancePaddingStrategy,
structs::ROMType, tables::TableCircuit, witness::RowMajorMatrix,
};
use ff_ext::ExtensionField;

Expand Down Expand Up @@ -40,11 +40,11 @@ impl<E: ExtensionField, RANGE: RangeTable> TableCircuit<E> for RangeTableCircuit
}

fn generate_fixed_traces(
config: &RangeTableConfig,
num_fixed: usize,
_config: &RangeTableConfig,
_num_fixed: usize,
_input: &(),
) -> RowMajorMatrix<E::BaseField> {
config.generate_fixed_traces(num_fixed, RANGE::content())
RowMajorMatrix::<E::BaseField>::new(0, 0, InstancePaddingStrategy::Default)
}

fn assign_instances(
Expand All @@ -55,6 +55,12 @@ impl<E: ExtensionField, RANGE: RangeTable> TableCircuit<E> for RangeTableCircuit
_input: &(),
) -> Result<RowMajorMatrix<E::BaseField>, ZKVMError> {
let multiplicity = &multiplicity[RANGE::ROM_TYPE as usize];
config.assign_instances(num_witin, num_structural_witin, multiplicity, RANGE::len())
config.assign_instances(
num_witin,
num_structural_witin,
multiplicity,
RANGE::content(),
RANGE::len(),
)
}
}
54 changes: 26 additions & 28 deletions ceno_zkvm/src/tables/range/range_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterato
use std::collections::HashMap;

use crate::{
circuit_builder::CircuitBuilder,
circuit_builder::{CircuitBuilder, SetTableSpec},
error::ZKVMError,
expression::{Expression, Fixed, ToExpr, WitIn},
expression::{StructuralWitIn, ToExpr, WitIn},
instructions::InstancePaddingStrategy,
scheme::constants::MIN_PAR_SIZE,
set_fixed_val, set_val,
set_val,
structs::ROMType,
witness::RowMajorMatrix,
};

#[derive(Clone, Debug)]
pub struct RangeTableConfig {
fixed: Fixed,
range: StructuralWitIn,
mlt: WitIn,
}

Expand All @@ -28,40 +28,31 @@ impl RangeTableConfig {
rom_type: ROMType,
table_len: usize,
) -> Result<Self, ZKVMError> {
let fixed = cb.create_fixed(|| "fixed")?;
let range = cb.create_structural_witin(|| "structural range witin", table_len, 0, 1);
let mlt = cb.create_witin(|| "mlt");

let record_exprs = vec![Expression::Fixed(fixed)];
let record_exprs = vec![range.expr()];

cb.lk_table_record(|| "record", table_len, rom_type, record_exprs, mlt.expr())?;
cb.lk_table_record(
|| "record",
SetTableSpec {
len: Some(table_len),
structural_witins: vec![range],
},
rom_type,
record_exprs,
mlt.expr(),
)?;

Ok(Self { fixed, mlt })
}

pub fn generate_fixed_traces<F: SmallField>(
&self,
num_fixed: usize,
content: Vec<u64>,
) -> RowMajorMatrix<F> {
let mut fixed =
RowMajorMatrix::<F>::new(content.len(), num_fixed, InstancePaddingStrategy::Default);

fixed
.par_iter_mut()
.with_min_len(MIN_PAR_SIZE)
.zip(content.into_par_iter())
.for_each(|(row, i)| {
set_fixed_val!(row, self.fixed, F::from(i));
});

fixed
Ok(Self { range, mlt })
}

pub fn assign_instances<F: SmallField>(
&self,
num_witin: usize,
num_structural_witin: usize,
multiplicity: &HashMap<u64, usize>,
content: Vec<u64>,
length: usize,
) -> Result<RowMajorMatrix<F>, ZKVMError> {
let mut witness = RowMajorMatrix::<F>::new(
Expand All @@ -75,12 +66,19 @@ impl RangeTableConfig {
mlts[*idx as usize] = *mlt;
}

let offset_range = StructuralWitIn {
Copy link
Collaborator

@hero78119 hero78119 Jan 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just indicate some following up improvement: previous it's fixed column, thus fixed column just generate once and cached. Now we save prover key size & no PCS opening anymore on range column, but we need to assign and create same witness polynomial every time. We can improve this by introducing a extra design cache structural witness column for first time and reuse.

id: self.range.id + (num_witin as u16),
..self.range
};

witness
.par_iter_mut()
.with_min_len(MIN_PAR_SIZE)
.zip(mlts.into_par_iter())
.for_each(|(row, mlt)| {
.zip(content.into_par_iter())
.for_each(|((row, mlt), i)| {
set_val!(row, self.mlt, F::from(mlt as u64));
set_val!(row, offset_range, F::from(i));
});

Ok(witness)
Expand Down