From ae26ec47f083202c3e4d5d6847f24ad175d52356 Mon Sep 17 00:00:00 2001 From: Alon Haramati Date: Tue, 28 May 2024 14:11:35 +0300 Subject: [PATCH] Add wide fib lookup step constraint. --- crates/prover/src/core/backend/cpu/circle.rs | 11 +-- crates/prover/src/core/poly/circle/canonic.rs | 4 + crates/prover/src/core/prover/mod.rs | 2 +- crates/prover/src/core/utils.rs | 92 ++++++++++++++++++- .../src/examples/wide_fibonacci/component.rs | 34 +++++-- .../wide_fibonacci/constraint_eval.rs | 72 ++++++++++++++- .../prover/src/examples/wide_fibonacci/mod.rs | 36 ++++---- .../src/examples/wide_fibonacci/trace_gen.rs | 31 ++++++- 8 files changed, 238 insertions(+), 44 deletions(-) diff --git a/crates/prover/src/core/backend/cpu/circle.rs b/crates/prover/src/core/backend/cpu/circle.rs index b3b434b6b..3172d5f82 100644 --- a/crates/prover/src/core/backend/cpu/circle.rs +++ b/crates/prover/src/core/backend/cpu/circle.rs @@ -13,7 +13,7 @@ use crate::core::poly::circle::{ use crate::core::poly::twiddles::TwiddleTree; use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold}; use crate::core::poly::BitReversedOrder; -use crate::core::utils::bit_reverse; +use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; impl PolyOps for CpuBackend { type Twiddles = Vec; @@ -24,14 +24,7 @@ impl PolyOps for CpuBackend { ) -> CircleEvaluation { let domain = coset.circle_domain(); assert_eq!(values.len(), domain.size()); - let mut new_values = Vec::with_capacity(values.len()); - let half_len = 1 << (coset.log_size() - 1); - for i in 0..half_len { - new_values.push(values[i << 1]); - } - for i in 0..half_len { - new_values.push(values[domain.size() - 1 - (i << 1)]); - } + let mut new_values = coset_order_to_circle_domain_order(&values); CpuBackend::bit_reverse_column(&mut new_values); CircleEvaluation::new(domain, new_values) } diff --git a/crates/prover/src/core/poly/circle/canonic.rs b/crates/prover/src/core/poly/circle/canonic.rs index 0e571d51a..da0a5955b 100644 --- a/crates/prover/src/core/poly/circle/canonic.rs +++ b/crates/prover/src/core/poly/circle/canonic.rs @@ -63,6 +63,10 @@ impl CanonicCoset { self.coset.step_size } + pub fn step(&self) -> CirclePoint { + self.coset.step + } + pub fn index_at(&self, index: usize) -> CirclePointIndex { self.coset.index_at(index) } diff --git a/crates/prover/src/core/prover/mod.rs b/crates/prover/src/core/prover/mod.rs index 41cb83146..9f9f0ad47 100644 --- a/crates/prover/src/core/prover/mod.rs +++ b/crates/prover/src/core/prover/mod.rs @@ -57,7 +57,7 @@ pub fn evaluate_and_commit_on_trace>( trace: ColumnVec>, ) -> Result<(CommitmentSchemeProver, InteractionElements), ProvingError> { let span = span!(Level::INFO, "Trace interpolation").entered(); - // TODO(AlonH): Remove clone. + // TODO(AlonH): Clone only the columns needed for interaction. let trace_polys = trace .clone() .into_iter() diff --git a/crates/prover/src/core/utils.rs b/crates/prover/src/core/utils.rs index cb6cc222e..d5176e93d 100644 --- a/crates/prover/src/core/utils.rs +++ b/crates/prover/src/core/utils.rs @@ -60,6 +60,50 @@ pub(crate) fn bit_reverse_index(i: usize, log_size: u32) -> usize { i.reverse_bits() >> (usize::BITS - log_size) } +/// Returns the index of the previous element in a bit reversed +/// [super::poly::circle::CircleEvaluation] of log size `eval_log_size` relative to a domain of +/// size `domain_log_size`. +pub(crate) fn previous_bit_reversed_circle_domain_index( + i: usize, + domain_log_size: u32, + eval_log_size: u32, +) -> usize { + let mut prev_index = bit_reverse_index(i, eval_log_size); + let half_size = 1 << (eval_log_size - 1); + let step_size = (eval_log_size - domain_log_size) as usize; + if prev_index < half_size { + prev_index = (prev_index + half_size - step_size) % half_size; + } else { + prev_index = ((prev_index + step_size) % half_size) + half_size; + } + bit_reverse_index(prev_index, eval_log_size) +} + +// TODO(AlonH): Pair both functions below with bit reverse. Consider removing both and calculating +// the indices instead. +pub(crate) fn circle_domain_order_to_coset_order(values: &[BaseField]) -> Vec { + let n = values.len(); + let mut coset_order = vec![]; + for i in 0..(n / 2) { + coset_order.push(values[i]); + coset_order.push(values[n - 1 - i]); + } + coset_order +} + +pub(crate) fn coset_order_to_circle_domain_order(values: &[BaseField]) -> Vec { + let mut circle_domain_order = Vec::with_capacity(values.len()); + let n = values.len(); + let half_len = n / 2; + for i in 0..half_len { + circle_domain_order.push(values[i << 1]); + } + for i in 0..half_len { + circle_domain_order.push(values[n - 1 - (i << 1)]); + } + circle_domain_order +} + /// Performs a naive bit-reversal permutation inplace. /// /// # Panics @@ -107,12 +151,16 @@ where #[cfg(test)] mod tests { + use itertools::Itertools; use num_traits::One; + use crate::core::backend::cpu::CpuCircleEvaluation; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; + use crate::core::poly::circle::CanonicCoset; + use crate::core::poly::NaturalOrder; use crate::core::utils::bit_reverse; - use crate::qm31; + use crate::{m31, qm31}; #[test] fn bit_reverse_works() { @@ -150,4 +198,46 @@ mod tests { assert_eq!(powers, vec![]); } + + #[test] + fn test_previous_bit_reversed_circle_domain_index() { + let log_size = 3; + let n = 1 << log_size; + let domain = CanonicCoset::new(log_size).circle_domain(); + let values = (0..n).map(|i| m31!(i as u32)).collect_vec(); + let evaluation = CpuCircleEvaluation::<_, NaturalOrder>::new(domain, values.clone()); + let bit_reversed_evaluation = evaluation.clone().bit_reverse(); + + let neighbor_pairs = (0..n) + .map(|i| { + let prev_index = + super::previous_bit_reversed_circle_domain_index(i, log_size - 1, log_size); + ( + bit_reversed_evaluation[i], + bit_reversed_evaluation[prev_index], + ) + }) + .sorted() + .collect_vec(); + // 1 O 7 + // O O + // 6 0 + // O O + // 2 4 + // O O + // 5 O 3 + let mut expected_neighbor_pairs = vec![ + (m31!(0), m31!(3)), + (m31!(7), m31!(4)), + (m31!(1), m31!(0)), + (m31!(6), m31!(7)), + (m31!(2), m31!(1)), + (m31!(5), m31!(6)), + (m31!(3), m31!(2)), + (m31!(4), m31!(5)), + ]; + expected_neighbor_pairs.sort(); + + assert_eq!(neighbor_pairs, expected_neighbor_pairs); + } } diff --git a/crates/prover/src/examples/wide_fibonacci/component.rs b/crates/prover/src/examples/wide_fibonacci/component.rs index 6a0c0e685..af4e4c6b8 100644 --- a/crates/prover/src/examples/wide_fibonacci/component.rs +++ b/crates/prover/src/examples/wide_fibonacci/component.rs @@ -3,9 +3,10 @@ use itertools::Itertools; use crate::core::air::accumulation::PointEvaluationAccumulator; use crate::core::air::mask::fixed_mask_points; use crate::core::air::{Air, Component, ComponentTraceWriter}; +use crate::core::backend::cpu::CpuCircleEvaluation; use crate::core::backend::CpuBackend; use crate::core::circle::CirclePoint; -use crate::core::constraints::{coset_vanishing, point_vanishing}; +use crate::core::constraints::{coset_vanishing, point_excluder, point_vanishing}; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::{SecureColumn, SECURE_EXTENSION_DEGREE}; @@ -60,7 +61,7 @@ impl Air for WideFibAir { impl Component for WideFibComponent { fn n_constraints(&self) -> usize { - self.n_columns() - 1 + self.n_columns() } fn max_constraint_log_degree_bound(&self) -> u32 { @@ -82,9 +83,10 @@ impl Component for WideFibComponent { &self, point: CirclePoint, ) -> TreeVec>>> { + let domain = CanonicCoset::new(self.log_column_size()); TreeVec::new(vec![ fixed_mask_points(&vec![vec![0_usize]; self.n_columns()], point), - vec![vec![point]; SECURE_EXTENSION_DEGREE], + vec![vec![point, point - domain.step().into_ef()]; SECURE_EXTENSION_DEGREE], ]) } @@ -105,15 +107,29 @@ impl Component for WideFibComponent { SecureCirclePoly::::eval_from_partial_evals(std::array::from_fn(|i| { mask[self.n_columns() + i][0] })); - let lookup_numerator = (lookup_value + let lookup_prev_value = + SecureCirclePoly::::eval_from_partial_evals(std::array::from_fn(|i| { + mask[self.n_columns() + i][1] + })); + let lookup_step_numerator = (lookup_value + * shifted_secure_combination( + &[mask[self.n_columns() - 2][0], mask[self.n_columns() - 1][0]], + alpha, + z, + )) + - (lookup_prev_value * shifted_secure_combination(&[mask[0][0], mask[1][0]], alpha, z)); + let lookup_step_denom = coset_vanishing(constraint_zero_domain, point) + / point_excluder(constraint_zero_domain.at(0), point); + evaluation_accumulator.accumulate(lookup_step_numerator / lookup_step_denom); + let lookup_boundary_numerator = (lookup_value * shifted_secure_combination( &[mask[self.n_columns() - 2][0], mask[self.n_columns() - 1][0]], alpha, z, )) - shifted_secure_combination(&[mask[0][0], mask[1][0]], alpha, z); - let lookup_denom = point_vanishing(constraint_zero_domain.at(0), point); - evaluation_accumulator.accumulate(lookup_numerator / lookup_denom); + let lookup_boundary_denom = point_vanishing(constraint_zero_domain.at(0), point); + evaluation_accumulator.accumulate(lookup_boundary_numerator / lookup_boundary_denom); let denom = coset_vanishing(constraint_zero_domain, point); let denom_inverse = denom.inverse(); @@ -139,10 +155,8 @@ impl ComponentTraceWriter for WideFibComponent { .columns .into_iter() .map(|eval| { - CircleEvaluation::::new( - trace[0].domain, - eval, - ) + let coset = CanonicCoset::new(trace[0].domain.log_size()); + CpuCircleEvaluation::new_canonical_ordered(coset, eval) }) .collect_vec() } diff --git a/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs b/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs index fd9fde5bb..08e607e7a 100644 --- a/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs +++ b/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs @@ -13,14 +13,16 @@ use crate::core::air::{ use crate::core::backend::CpuBackend; use crate::core::channel::{Blake2sChannel, Channel}; use crate::core::circle::Coset; -use crate::core::constraints::{coset_vanishing, point_vanishing}; +use crate::core::constraints::{coset_vanishing, point_excluder, point_vanishing}; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; use crate::core::pcs::TreeVec; use crate::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation, SecureCirclePoly}; use crate::core::poly::BitReversedOrder; -use crate::core::utils::{bit_reverse, shifted_secure_combination}; +use crate::core::utils::{ + bit_reverse, previous_bit_reversed_circle_domain_index, shifted_secure_combination, +}; use crate::core::{ColumnVec, InteractionElements}; use crate::examples::wide_fibonacci::component::LOG_N_COLUMNS; @@ -131,6 +133,65 @@ impl WideFibComponent { accum.accumulate(i, *num * *denom_inverse); } } + + // TODO(AlonH): Simplify this function by using utility functions. + fn evaluate_lookup_step_constraints( + &self, + trace_evals: &TreeVec>>, + trace_eval_domain: CircleDomain, + zero_domain: Coset, + accum: &mut ColumnAccumulator<'_, CpuBackend>, + interaction_elements: &InteractionElements, + ) { + let max_constraint_degree = self.max_constraint_log_degree_bound(); + let mut denoms = vec![]; + for point in trace_eval_domain.iter() { + denoms.push( + coset_vanishing(zero_domain, point) / point_excluder(zero_domain.at(0), point), + ); + } + bit_reverse(&mut denoms); + let mut denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)]; + BaseField::batch_inverse(&denoms, &mut denom_inverses); + let mut numerators = vec![SecureField::zero(); 1 << (max_constraint_degree)]; + let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]); + + #[allow(clippy::needless_range_loop)] + for i in 0..trace_eval_domain.size() { + let value = + SecureCirclePoly::::eval_from_partial_evals(std::array::from_fn(|j| { + trace_evals[1][j][i].into() + })); + let prev_index = previous_bit_reversed_circle_domain_index( + i, + zero_domain.log_size, + trace_eval_domain.log_size(), + ); + let prev_value = + SecureCirclePoly::::eval_from_partial_evals(std::array::from_fn(|j| { + trace_evals[1][j][prev_index].into() + })); + numerators[i] = accum.random_coeff_powers[self.n_columns() - 1] + * ((value + * shifted_secure_combination( + &[ + trace_evals[0][self.n_columns() - 2][i], + trace_evals[0][self.n_columns() - 1][i], + ], + alpha, + z, + )) + - (prev_value + * shifted_secure_combination( + &[trace_evals[0][0][i], trace_evals[0][1][i]], + alpha, + z, + ))); + } + for (i, (num, denom_inverse)) in numerators.iter().zip(denom_inverses.iter()).enumerate() { + accum.accumulate(i, *num * *denom_inverse); + } + } } impl ComponentProver for WideFibComponent { @@ -160,6 +221,13 @@ impl ComponentProver for WideFibComponent { &mut accum, interaction_elements, ); + self.evaluate_lookup_step_constraints( + trace_evals, + trace_eval_domain, + zero_domain, + &mut accum, + interaction_elements, + ) } } diff --git a/crates/prover/src/examples/wide_fibonacci/mod.rs b/crates/prover/src/examples/wide_fibonacci/mod.rs index 678e6ca22..4e1ea77c6 100644 --- a/crates/prover/src/examples/wide_fibonacci/mod.rs +++ b/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -22,9 +22,10 @@ mod tests { use crate::core::fields::IntoSlice; use crate::core::pcs::TreeVec; use crate::core::poly::circle::CanonicCoset; - use crate::core::poly::BitReversedOrder; use crate::core::prover::{prove, verify}; - use crate::core::utils::shifted_secure_combination; + use crate::core::utils::{ + bit_reverse, circle_domain_order_to_coset_order, shifted_secure_combination, + }; use crate::core::vcs::blake2_hash::Blake2sHasher; use crate::core::vcs::hasher::Hasher; use crate::core::InteractionElements; @@ -70,10 +71,7 @@ mod tests { assert_eq!( column[column_length - 1] * shifted_secure_combination( - &[ - input_trace[n_columns - 2][column_length - 1], - input_trace[n_columns - 1][column_length - 1] - ], + &[input_trace[n_columns - 2][1], input_trace[n_columns - 1][1]], alpha, z, ), @@ -113,10 +111,17 @@ mod tests { let alpha = qm31!(7, 1, 3, 4); let z = qm31!(11, 1, 2, 3); - let trace = gen_trace(&wide_fib, vec![input]); + let mut trace = gen_trace(&wide_fib, vec![input]); let input_trace = trace.iter().map(|values| &values[..]).collect_vec(); let lookup_column = write_lookup_column(&input_trace, alpha, z); + trace = trace + .iter_mut() + .map(|column| { + bit_reverse(column); + circle_domain_order_to_coset_order(column) + }) + .collect_vec(); assert_constraints_on_lookup_column(&lookup_column, &trace, alpha, z) } @@ -135,16 +140,16 @@ mod tests { let inputs = (0..1 << wide_fib.log_n_instances) .map(|i| Input { a: m31!(1), - b: m31!(i as u32), + b: m31!(i + 1_u32), }) .collect_vec(); - let trace = gen_trace(&wide_fib, inputs); + let trace_values = gen_trace(&wide_fib, inputs); - let trace_domain = CanonicCoset::new(wide_fib.log_column_size()).circle_domain(); - let trace = trace + let trace_domain = CanonicCoset::new(wide_fib.log_column_size()); + let trace = trace_values .into_iter() - .map(|eval| CpuCircleEvaluation::<_, BitReversedOrder>::new(trace_domain, eval)) + .map(|eval| CpuCircleEvaluation::new_canonical_ordered(trace_domain, eval)) .collect_vec(); let trace_polys = trace .clone() @@ -191,8 +196,7 @@ mod tests { let res = acc.finalize(); let poly = res.0[0].clone(); - - for coeff in poly.coeffs[1 << wide_fib.max_constraint_log_degree_bound()..].iter() { + for coeff in poly.coeffs[(1 << wide_fib.max_constraint_log_degree_bound()) - 1..].iter() { assert_eq!(*coeff, BaseField::zero()); } } @@ -216,10 +220,10 @@ mod tests { .collect(); let trace = gen_trace(&component, private_input); - let trace_domain = CanonicCoset::new(component.log_column_size()).circle_domain(); + let trace_domain = CanonicCoset::new(component.log_column_size()); let trace = trace .into_iter() - .map(|eval| CpuCircleEvaluation::<_, BitReversedOrder>::new(trace_domain, eval)) + .map(|eval| CpuCircleEvaluation::new_canonical_ordered(trace_domain, eval)) .collect_vec(); let air = WideFibAir { component }; let prover_channel = diff --git a/crates/prover/src/examples/wide_fibonacci/trace_gen.rs b/crates/prover/src/examples/wide_fibonacci/trace_gen.rs index 9f6e9e50b..788fc842c 100644 --- a/crates/prover/src/examples/wide_fibonacci/trace_gen.rs +++ b/crates/prover/src/examples/wide_fibonacci/trace_gen.rs @@ -1,10 +1,13 @@ +use itertools::Itertools; use num_traits::One; use super::component::Input; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; -use crate::core::utils::shifted_secure_combination; +use crate::core::utils::{ + bit_reverse, circle_domain_order_to_coset_order, shifted_secure_combination, +}; /// Writes the trace row for the wide Fibonacci example to dst, given a private input. Returns the /// last two elements of the row in case the sequence is continued. @@ -34,12 +37,30 @@ pub fn write_lookup_column( let n_rows = input_trace[0].len(); let n_columns = input_trace.len(); let mut prev_value = SecureField::one(); + let mut input_trace = input_trace + .iter() + .map(|column| column.to_vec()) + .collect_vec(); + let natural_ordered_trace = input_trace + .iter_mut() + .map(|column| { + bit_reverse(column); + circle_domain_order_to_coset_order(column) + }) + .collect_vec(); + (0..n_rows) .map(|i| { - let numerator = - shifted_secure_combination(&[input_trace[0][i], input_trace[1][i]], alpha, z); + let numerator = shifted_secure_combination( + &[natural_ordered_trace[0][i], natural_ordered_trace[1][i]], + alpha, + z, + ); let denominator = shifted_secure_combination( - &[input_trace[n_columns - 2][i], input_trace[n_columns - 1][i]], + &[ + natural_ordered_trace[n_columns - 2][i], + natural_ordered_trace[n_columns - 1][i], + ], alpha, z, ); @@ -48,5 +69,5 @@ pub fn write_lookup_column( prev_value = cell; cell }) - .collect() + .collect_vec() }