Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wide fib lookup step constraint. #646

Merged
merged 1 commit into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions crates/prover/src/core/backend/cpu/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::core::poly::circle::{
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::bit_reverse;
use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order};

impl PolyOps for CpuBackend {
type Twiddles = Vec<BaseField>;
Expand All @@ -24,14 +24,7 @@ impl PolyOps for CpuBackend {
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
let domain = coset.circle_domain();
assert_eq!(values.len(), domain.size());
let mut new_values = Vec::with_capacity(values.len());
let half_len = 1 << (coset.log_size() - 1);
for i in 0..half_len {
new_values.push(values[i << 1]);
}
for i in 0..half_len {
new_values.push(values[domain.size() - 1 - (i << 1)]);
}
let mut new_values = coset_order_to_circle_domain_order(&values);
CpuBackend::bit_reverse_column(&mut new_values);
CircleEvaluation::new(domain, new_values)
}
Expand Down
4 changes: 4 additions & 0 deletions crates/prover/src/core/poly/circle/canonic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ impl CanonicCoset {
self.coset.step_size
}

pub fn step(&self) -> CirclePoint<BaseField> {
self.coset.step
}

pub fn index_at(&self, index: usize) -> CirclePointIndex {
self.coset.index_at(index)
}
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pub fn evaluate_and_commit_on_trace<B: Backend + MerkleOps<MerkleHasher>>(
trace: ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
) -> Result<(CommitmentSchemeProver<B>, InteractionElements), ProvingError> {
let span = span!(Level::INFO, "Trace interpolation").entered();
// TODO(AlonH): Remove clone.
// TODO(AlonH): Clone only the columns needed for interaction.
let trace_polys = trace
.clone()
.into_iter()
Expand Down
92 changes: 91 additions & 1 deletion crates/prover/src/core/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,50 @@ pub(crate) fn bit_reverse_index(i: usize, log_size: u32) -> usize {
i.reverse_bits() >> (usize::BITS - log_size)
}

/// Returns the index of the previous element in a bit reversed
/// [super::poly::circle::CircleEvaluation] of log size `eval_log_size` relative to a domain of
/// size `domain_log_size`.
pub(crate) fn previous_bit_reversed_circle_domain_index(
i: usize,
domain_log_size: u32,
eval_log_size: u32,
) -> usize {
let mut prev_index = bit_reverse_index(i, eval_log_size);
let half_size = 1 << (eval_log_size - 1);
let step_size = (eval_log_size - domain_log_size) as usize;
if prev_index < half_size {
prev_index = (prev_index + half_size - step_size) % half_size;
} else {
prev_index = ((prev_index + step_size) % half_size) + half_size;
}
bit_reverse_index(prev_index, eval_log_size)
}

// TODO(AlonH): Pair both functions below with bit reverse. Consider removing both and calculating
// the indices instead.
pub(crate) fn circle_domain_order_to_coset_order(values: &[BaseField]) -> Vec<BaseField> {
let n = values.len();
let mut coset_order = vec![];
for i in 0..(n / 2) {
coset_order.push(values[i]);
coset_order.push(values[n - 1 - i]);
}
coset_order
}

pub(crate) fn coset_order_to_circle_domain_order(values: &[BaseField]) -> Vec<BaseField> {
let mut circle_domain_order = Vec::with_capacity(values.len());
let n = values.len();
let half_len = n / 2;
for i in 0..half_len {
circle_domain_order.push(values[i << 1]);
}
for i in 0..half_len {
circle_domain_order.push(values[n - 1 - (i << 1)]);
}
circle_domain_order
}

/// Performs a naive bit-reversal permutation inplace.
///
/// # Panics
Expand Down Expand Up @@ -107,12 +151,16 @@ where

#[cfg(test)]
mod tests {
use itertools::Itertools;
use num_traits::One;

use crate::core::backend::cpu::CpuCircleEvaluation;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::poly::circle::CanonicCoset;
use crate::core::poly::NaturalOrder;
use crate::core::utils::bit_reverse;
use crate::qm31;
use crate::{m31, qm31};

#[test]
fn bit_reverse_works() {
Expand Down Expand Up @@ -150,4 +198,46 @@ mod tests {

assert_eq!(powers, vec![]);
}

#[test]
fn test_previous_bit_reversed_circle_domain_index() {
let log_size = 3;
let n = 1 << log_size;
let domain = CanonicCoset::new(log_size).circle_domain();
let values = (0..n).map(|i| m31!(i as u32)).collect_vec();
let evaluation = CpuCircleEvaluation::<_, NaturalOrder>::new(domain, values.clone());
let bit_reversed_evaluation = evaluation.clone().bit_reverse();

let neighbor_pairs = (0..n)
.map(|i| {
let prev_index =
super::previous_bit_reversed_circle_domain_index(i, log_size - 1, log_size);
(
bit_reversed_evaluation[i],
bit_reversed_evaluation[prev_index],
)
})
.sorted()
.collect_vec();
// 1 O 7
// O O
// 6 0
// O O
// 2 4
// O O
// 5 O 3
let mut expected_neighbor_pairs = vec![
(m31!(0), m31!(3)),
(m31!(7), m31!(4)),
(m31!(1), m31!(0)),
(m31!(6), m31!(7)),
(m31!(2), m31!(1)),
(m31!(5), m31!(6)),
(m31!(3), m31!(2)),
(m31!(4), m31!(5)),
];
expected_neighbor_pairs.sort();

assert_eq!(neighbor_pairs, expected_neighbor_pairs);
}
}
34 changes: 24 additions & 10 deletions crates/prover/src/examples/wide_fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ use itertools::Itertools;
use crate::core::air::accumulation::PointEvaluationAccumulator;
use crate::core::air::mask::fixed_mask_points;
use crate::core::air::{Air, Component, ComponentTraceWriter};
use crate::core::backend::cpu::CpuCircleEvaluation;
use crate::core::backend::CpuBackend;
use crate::core::circle::CirclePoint;
use crate::core::constraints::{coset_vanishing, point_vanishing};
use crate::core::constraints::{coset_vanishing, point_excluder, point_vanishing};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::{SecureColumn, SECURE_EXTENSION_DEGREE};
Expand Down Expand Up @@ -60,7 +61,7 @@ impl Air for WideFibAir {

impl Component for WideFibComponent {
fn n_constraints(&self) -> usize {
self.n_columns() - 1
self.n_columns()
}

fn max_constraint_log_degree_bound(&self) -> u32 {
Expand All @@ -82,9 +83,10 @@ impl Component for WideFibComponent {
&self,
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
let domain = CanonicCoset::new(self.log_column_size());
TreeVec::new(vec![
fixed_mask_points(&vec![vec![0_usize]; self.n_columns()], point),
vec![vec![point]; SECURE_EXTENSION_DEGREE],
vec![vec![point, point - domain.step().into_ef()]; SECURE_EXTENSION_DEGREE],
])
}

Expand All @@ -105,15 +107,29 @@ impl Component for WideFibComponent {
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|i| {
mask[self.n_columns() + i][0]
}));
let lookup_numerator = (lookup_value
let lookup_prev_value =
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|i| {
mask[self.n_columns() + i][1]
}));
let lookup_step_numerator = (lookup_value
* shifted_secure_combination(
&[mask[self.n_columns() - 2][0], mask[self.n_columns() - 1][0]],
alpha,
z,
))
- (lookup_prev_value * shifted_secure_combination(&[mask[0][0], mask[1][0]], alpha, z));
let lookup_step_denom = coset_vanishing(constraint_zero_domain, point)
/ point_excluder(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(lookup_step_numerator / lookup_step_denom);
let lookup_boundary_numerator = (lookup_value
* shifted_secure_combination(
&[mask[self.n_columns() - 2][0], mask[self.n_columns() - 1][0]],
alpha,
z,
))
- shifted_secure_combination(&[mask[0][0], mask[1][0]], alpha, z);
let lookup_denom = point_vanishing(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(lookup_numerator / lookup_denom);
let lookup_boundary_denom = point_vanishing(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(lookup_boundary_numerator / lookup_boundary_denom);

let denom = coset_vanishing(constraint_zero_domain, point);
let denom_inverse = denom.inverse();
Expand All @@ -139,10 +155,8 @@ impl ComponentTraceWriter<CpuBackend> for WideFibComponent {
.columns
.into_iter()
.map(|eval| {
CircleEvaluation::<CpuBackend, BaseField, BitReversedOrder>::new(
trace[0].domain,
eval,
)
let coset = CanonicCoset::new(trace[0].domain.log_size());
CpuCircleEvaluation::new_canonical_ordered(coset, eval)
})
.collect_vec()
}
Expand Down
72 changes: 70 additions & 2 deletions crates/prover/src/examples/wide_fibonacci/constraint_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@ use crate::core::air::{
use crate::core::backend::CpuBackend;
use crate::core::channel::{Blake2sChannel, Channel};
use crate::core::circle::Coset;
use crate::core::constraints::{coset_vanishing, point_vanishing};
use crate::core::constraints::{coset_vanishing, point_excluder, point_vanishing};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation, SecureCirclePoly};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse, shifted_secure_combination};
use crate::core::utils::{
bit_reverse, previous_bit_reversed_circle_domain_index, shifted_secure_combination,
};
use crate::core::{ColumnVec, InteractionElements};
use crate::examples::wide_fibonacci::component::LOG_N_COLUMNS;

Expand Down Expand Up @@ -131,6 +133,65 @@ impl WideFibComponent {
accum.accumulate(i, *num * *denom_inverse);
}
}

// TODO(AlonH): Simplify this function by using utility functions.
fn evaluate_lookup_step_constraints(
&self,
trace_evals: &TreeVec<Vec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>>,
trace_eval_domain: CircleDomain,
zero_domain: Coset,
accum: &mut ColumnAccumulator<'_, CpuBackend>,
interaction_elements: &InteractionElements,
) {
let max_constraint_degree = self.max_constraint_log_degree_bound();
let mut denoms = vec![];
for point in trace_eval_domain.iter() {
denoms.push(
coset_vanishing(zero_domain, point) / point_excluder(zero_domain.at(0), point),
);
}
bit_reverse(&mut denoms);
let mut denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
BaseField::batch_inverse(&denoms, &mut denom_inverses);
let mut numerators = vec![SecureField::zero(); 1 << (max_constraint_degree)];
let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]);

#[allow(clippy::needless_range_loop)]
for i in 0..trace_eval_domain.size() {
let value =
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|j| {
trace_evals[1][j][i].into()
}));
let prev_index = previous_bit_reversed_circle_domain_index(
i,
zero_domain.log_size,
trace_eval_domain.log_size(),
);
let prev_value =
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|j| {
trace_evals[1][j][prev_index].into()
}));
numerators[i] = accum.random_coeff_powers[self.n_columns() - 1]
* ((value
* shifted_secure_combination(
&[
trace_evals[0][self.n_columns() - 2][i],
trace_evals[0][self.n_columns() - 1][i],
],
alpha,
z,
))
- (prev_value
* shifted_secure_combination(
&[trace_evals[0][0][i], trace_evals[0][1][i]],
alpha,
z,
)));
}
for (i, (num, denom_inverse)) in numerators.iter().zip(denom_inverses.iter()).enumerate() {
accum.accumulate(i, *num * *denom_inverse);
}
}
}

impl ComponentProver<CpuBackend> for WideFibComponent {
Expand Down Expand Up @@ -160,6 +221,13 @@ impl ComponentProver<CpuBackend> for WideFibComponent {
&mut accum,
interaction_elements,
);
self.evaluate_lookup_step_constraints(
trace_evals,
trace_eval_domain,
zero_domain,
&mut accum,
interaction_elements,
)
}
}

Expand Down
Loading
Loading