Skip to content

Commit

Permalink
Create phase number constants.
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed Jun 24, 2024
1 parent 58d9263 commit 6cf38e5
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 39 deletions.
21 changes: 11 additions & 10 deletions crates/prover/src/core/air/air_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::pcs::{CommitmentTreeProver, TreeVec};
use crate::core::poly::circle::SecureCirclePoly;
use crate::core::prover::{BASE_TRACE, INTERACTION_TRACE};
use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher;
use crate::core::vcs::ops::MerkleOps;
use crate::core::{ColumnVec, ComponentVec, InteractionElements, LookupValues};
Expand Down Expand Up @@ -37,8 +38,8 @@ pub trait AirExt: Air {
let mut interaction_component_points = vec![];
for component in self.components() {
let points = component.mask_points(point);
trace_component_points.extend(points[0].clone());
interaction_component_points.extend(points[1].clone());
trace_component_points.extend(points[BASE_TRACE].clone());
interaction_component_points.extend(points[INTERACTION_TRACE].clone());
}
let mut points = TreeVec::new(vec![trace_component_points]);
if !interaction_component_points
Expand Down Expand Up @@ -78,8 +79,8 @@ pub trait AirExt: Air {
let mut interaction_tree = vec![];
self.components().iter().for_each(|component| {
let bounds = component.trace_log_degree_bounds();
trace_tree.extend(bounds[0].clone());
interaction_tree.extend(bounds[1].clone());
trace_tree.extend(bounds[BASE_TRACE].clone());
interaction_tree.extend(bounds[INTERACTION_TRACE].clone());
});
let mut sizes = TreeVec::new(vec![trace_tree]);
if !interaction_tree.is_empty() {
Expand All @@ -92,11 +93,11 @@ pub trait AirExt: Air {
&'a self,
trees: &'a [CommitmentTreeProver<B>],
) -> Vec<ComponentTrace<'_, B>> {
let poly_iter = &mut trees[0].polynomials.iter();
let eval_iter = &mut trees[0].evaluations.iter();
let poly_iter = &mut trees[BASE_TRACE].polynomials.iter();
let eval_iter = &mut trees[BASE_TRACE].evaluations.iter();
let mut component_traces = vec![];
self.components().iter().for_each(|component| {
let n_columns = component.trace_log_degree_bounds()[0].len();
let n_columns = component.trace_log_degree_bounds()[BASE_TRACE].len();
let polys = poly_iter.take(n_columns).collect_vec();
let evals = eval_iter.take(n_columns).collect_vec();

Expand All @@ -107,13 +108,13 @@ pub trait AirExt: Air {
});

if trees.len() > 1 {
let poly_iter = &mut trees[1].polynomials.iter();
let eval_iter = &mut trees[1].evaluations.iter();
let poly_iter = &mut trees[INTERACTION_TRACE].polynomials.iter();
let eval_iter = &mut trees[INTERACTION_TRACE].evaluations.iter();
self.components()
.iter()
.zip_eq(&mut component_traces)
.for_each(|(component, component_trace)| {
let n_columns = component.trace_log_degree_bounds()[1].len();
let n_columns = component.trace_log_degree_bounds()[INTERACTION_TRACE].len();
let polys = poly_iter.take(n_columns).collect_vec();
let evals = eval_iter.take(n_columns).collect_vec();
component_trace.polys.push(polys);
Expand Down
20 changes: 16 additions & 4 deletions crates/prover/src/core/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ pub const LOG_LAST_LAYER_DEGREE_BOUND: u32 = 0;
pub const PROOF_OF_WORK_BITS: u32 = 12;
pub const N_QUERIES: usize = 3;

pub const BASE_TRACE: usize = 0;
pub const INTERACTION_TRACE: usize = 1;

#[derive(Debug)]
pub struct StarkProof {
pub commitments: TreeVec<<ChannelHasher as Hasher>::Hash>,
Expand Down Expand Up @@ -198,11 +201,19 @@ pub fn verify(
// Read trace commitment.
let mut commitment_scheme = CommitmentSchemeVerifier::new();
let column_log_sizes = air.column_log_sizes();
commitment_scheme.commit(proof.commitments[0], &column_log_sizes[0], channel);
commitment_scheme.commit(
proof.commitments[BASE_TRACE],
&column_log_sizes[BASE_TRACE],
channel,
);
let interaction_elements = air.interaction_elements(channel);

if air.n_interaction_phases() == 2 {
commitment_scheme.commit(proof.commitments[1], &column_log_sizes[1], channel);
commitment_scheme.commit(
proof.commitments[INTERACTION_TRACE],
&column_log_sizes[INTERACTION_TRACE],
channel,
);
}

let random_coeff = channel.draw_felt();
Expand Down Expand Up @@ -257,7 +268,7 @@ fn sampled_values_to_mask(
.iter();
let mut trace_oods_values = vec![];
air.components().iter().for_each(|component| {
let n_trace_points = component.mask_points(CirclePoint::zero())[0].len();
let n_trace_points = component.mask_points(CirclePoint::zero())[BASE_TRACE].len();
trace_oods_values.push(
flat_trace_values
.take(n_trace_points)
Expand All @@ -276,7 +287,8 @@ fn sampled_values_to_mask(
.iter()
.zip_eq(&mut trace_oods_values)
.for_each(|(component, values)| {
let n_interaction_points = component.mask_points(CirclePoint::zero())[1].len();
let n_interaction_points =
component.mask_points(CirclePoint::zero())[INTERACTION_TRACE].len();
values.extend(
interaction_values
.take(n_interaction_points)
Expand Down
3 changes: 2 additions & 1 deletion crates/prover/src/examples/fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::core::fields::{ExtensionOf, FieldExpOps};
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::prover::BASE_TRACE;
use crate::core::utils::bit_reverse_index;
use crate::core::{ColumnVec, InteractionElements, LookupValues};

Expand Down Expand Up @@ -145,7 +146,7 @@ impl ComponentProver<CpuBackend> for FibonacciComponent {
_interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
) {
let poly = &trace.polys[0][0];
let poly = &trace.polys[BASE_TRACE][0];
let trace_domain = CanonicCoset::new(self.log_size);
let trace_eval_domain = CanonicCoset::new(self.log_size + 1).circle_domain();
let trace_eval = poly.evaluate(trace_eval_domain).bit_reverse();
Expand Down
9 changes: 5 additions & 4 deletions crates/prover/src/examples/fibonacci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ mod tests {
use crate::core::fields::qm31::SecureField;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::CanonicCoset;
use crate::core::prover::VerificationError;
use crate::core::prover::{VerificationError, BASE_TRACE};
use crate::core::queries::Queries;
use crate::core::utils::bit_reverse;
use crate::core::{InteractionElements, LookupValues};
Expand Down Expand Up @@ -167,7 +167,7 @@ mod tests {
let point = CirclePoint::<SecureField>::get_point(98989892);

let points = fib.air.mask_points(point);
let mask_values = zip(&component_traces[0].polys[0], &points[0])
let mask_values = zip(&component_traces[0].polys[BASE_TRACE], &points[0])
.map(|(poly, points)| {
points
.iter()
Expand Down Expand Up @@ -238,7 +238,8 @@ mod tests {
let fib = Fibonacci::new(FIB_LOG_SIZE, m31!(443693538));

let mut invalid_proof = fib.prove().unwrap();
invalid_proof.commitment_scheme_proof.queried_values.0[0][0][3] += BaseField::one();
invalid_proof.commitment_scheme_proof.queried_values.0[BASE_TRACE][0][3] +=
BaseField::one();

let error = fib.verify(invalid_proof).unwrap_err();
assert_matches!(error, VerificationError::Merkle(_));
Expand Down Expand Up @@ -268,7 +269,7 @@ mod tests {
let fib = Fibonacci::new(FIB_LOG_SIZE, m31!(443693538));

let mut invalid_proof = fib.prove().unwrap();
invalid_proof.commitment_scheme_proof.queried_values.0[0][0].pop();
invalid_proof.commitment_scheme_proof.queried_values.0[BASE_TRACE][0].pop();

let error = fib.verify(invalid_proof).unwrap_err();
assert_matches!(error, VerificationError::Merkle(_));
Expand Down
34 changes: 18 additions & 16 deletions crates/prover/src/examples/wide_fibonacci/constraint_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::core::fields::FieldExpOps;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation, SecureCirclePoly};
use crate::core::poly::BitReversedOrder;
use crate::core::prover::{BASE_TRACE, INTERACTION_TRACE};
use crate::core::utils::{
bit_reverse, previous_bit_reversed_circle_domain_index, shifted_secure_combination,
};
Expand Down Expand Up @@ -88,14 +89,14 @@ impl WideFibComponent {
#[allow(clippy::needless_range_loop)]
for i in 0..trace_eval_domain.size() {
first_point_numerators[i] = accum.random_coeff_powers[self.n_columns() + 4]
* (trace_evals[0][0][i] - lookup_values[LOOKUP_VALUE_0_ID])
* (trace_evals[BASE_TRACE][0][i] - lookup_values[LOOKUP_VALUE_0_ID])
+ accum.random_coeff_powers[self.n_columns() + 3]
* (trace_evals[0][1][i] - lookup_values[LOOKUP_VALUE_1_ID]);
* (trace_evals[BASE_TRACE][1][i] - lookup_values[LOOKUP_VALUE_1_ID]);
last_point_numerators[i] = accum.random_coeff_powers[self.n_columns() + 2]
* (trace_evals[0][self.n_columns() - 2][i]
* (trace_evals[BASE_TRACE][self.n_columns() - 2][i]
- lookup_values[LOOKUP_VALUE_N_MINUS_2_ID])
+ accum.random_coeff_powers[self.n_columns() + 1]
* (trace_evals[0][self.n_columns() - 1][i]
* (trace_evals[BASE_TRACE][self.n_columns() - 1][i]
- lookup_values[LOOKUP_VALUE_N_MINUS_1_ID]);
}
for (i, (num, denom_inverse)) in first_point_numerators
Expand Down Expand Up @@ -135,8 +136,9 @@ impl WideFibComponent {
for i in 0..trace_eval_domain.size() {
for j in 0..self.n_columns() - 2 {
numerators[i] += accum.random_coeff_powers[self.n_columns() - 3 - j]
* (trace_evals[0][j][i].square() + trace_evals[0][j + 1][i].square()
- trace_evals[0][j + 2][i]);
* (trace_evals[BASE_TRACE][j][i].square()
+ trace_evals[BASE_TRACE][j + 1][i].square()
- trace_evals[BASE_TRACE][j + 2][i]);
}
}
for (i, (num, denom_inverse)) in numerators.iter().zip(denom_inverses.iter()).enumerate() {
Expand Down Expand Up @@ -174,20 +176,20 @@ impl WideFibComponent {
for i in 0..trace_eval_domain.size() {
let value =
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|j| {
trace_evals[1][j][i].into()
trace_evals[INTERACTION_TRACE][j][i].into()
}));
first_point_numerators[i] = accum.random_coeff_powers[self.n_columns() - 1]
* ((value
* shifted_secure_combination(
&[
trace_evals[0][self.n_columns() - 2][i],
trace_evals[0][self.n_columns() - 1][i],
trace_evals[BASE_TRACE][self.n_columns() - 2][i],
trace_evals[BASE_TRACE][self.n_columns() - 1][i],
],
alpha,
z,
))
- shifted_secure_combination(
&[trace_evals[0][0][i], trace_evals[0][1][i]],
&[trace_evals[BASE_TRACE][0][i], trace_evals[BASE_TRACE][1][i]],
alpha,
z,
));
Expand Down Expand Up @@ -251,27 +253,27 @@ impl WideFibComponent {
for i in 0..trace_eval_domain.size() {
let value =
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|j| {
trace_evals[1][j][i].into()
trace_evals[INTERACTION_TRACE][j][i].into()
}));
let prev_index =
previous_bit_reversed_circle_domain_index(i, trace_eval_domain.log_size());
let prev_value =
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|j| {
trace_evals[1][j][prev_index].into()
trace_evals[INTERACTION_TRACE][j][prev_index].into()
}));
numerators[i] = accum.random_coeff_powers[self.n_columns()]
* ((value
* shifted_secure_combination(
&[
trace_evals[0][self.n_columns() - 2][i],
trace_evals[0][self.n_columns() - 1][i],
trace_evals[BASE_TRACE][self.n_columns() - 2][i],
trace_evals[BASE_TRACE][self.n_columns() - 1][i],
],
alpha,
z,
))
- (prev_value
* shifted_secure_combination(
&[trace_evals[0][0][i], trace_evals[0][1][i]],
&[trace_evals[BASE_TRACE][0][i], trace_evals[BASE_TRACE][1][i]],
alpha,
z,
)));
Expand Down Expand Up @@ -329,7 +331,7 @@ impl ComponentProver<CpuBackend> for WideFibComponent {

fn lookup_values(&self, trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues {
let domain = CanonicCoset::new(self.log_column_size());
let trace_poly = &trace.polys[0];
let trace_poly = &trace.polys[BASE_TRACE];
let values = BTreeMap::from_iter([
(
LOOKUP_VALUE_0_ID.to_string(),
Expand Down
9 changes: 5 additions & 4 deletions crates/prover/src/examples/wide_fibonacci/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::core::fields::{FieldExpOps, FieldOps};
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::prover::BASE_TRACE;
use crate::core::{ColumnVec, InteractionElements, LookupValues};
use crate::examples::wide_fibonacci::component::{ALPHA_ID, N_COLUMNS, Z_ID};

Expand Down Expand Up @@ -179,7 +180,7 @@ impl ComponentProver<SimdBackend> for SimdWideFibComponent {
_interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
) {
assert_eq!(trace.polys[0].len(), self.n_columns());
assert_eq!(trace.polys[BASE_TRACE].len(), self.n_columns());
// TODO(spapini): Steal evaluation from commitment.
let eval_domain = CanonicCoset::new(self.log_column_size() + 1).circle_domain();
let trace_eval = &trace.evals;
Expand All @@ -204,14 +205,14 @@ impl ComponentProver<SimdBackend> for SimdWideFibComponent {

for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES)) {
// Numerator.
let a = trace_eval[0][0].data[vec_row];
let a = trace_eval[BASE_TRACE][0].data[vec_row];
let mut row_res = PackedSecureField::zero();
let mut a_sq = a.square();
let mut b_sq = trace_eval[0][1].data[vec_row].square();
let mut b_sq = trace_eval[BASE_TRACE][1].data[vec_row].square();
#[allow(clippy::needless_range_loop)]
for i in 0..(self.n_columns() - 2) {
unsafe {
let c = *trace_eval[0]
let c = *trace_eval[BASE_TRACE]
.get_unchecked(i + 2)
.data
.get_unchecked(vec_row);
Expand Down

0 comments on commit 6cf38e5

Please sign in to comment.